当前位置: 首页 > news >正文

中国免费网站服务器下载电子商务网站建设功能

中国免费网站服务器下载,电子商务网站建设功能,威海市住房和城乡建设局网站,网络营销方案格式最近想尝试下使用GNN A2C 进行强化学习#xff0c;GNN 可以充当一个特征提取器#xff0c;这样可以增加强化学状态空间因为张量长度受限泛化能力不足的缺点#xff0c;之前做强化学习的时候受限于需要在环境里提取每个对手的特征#xff0c;在每个不同场景下因为对手的数量…最近想尝试下使用GNN A2C 进行强化学习GNN 可以充当一个特征提取器这样可以增加强化学状态空间因为张量长度受限泛化能力不足的缺点之前做强化学习的时候受限于需要在环境里提取每个对手的特征在每个不同场景下因为对手的数量是变化的对应的状态空间也得一一对应每个场景的训练都是定制化的。 提高泛化能力一直是模型训练推理的一个课题。 恰好最近在看图神经网络相关的内容这里贴点废话回头找工作面试官问起来可以喷一些 图神经网络GNN是一种专门用于处理图结构数据的深度学习方法具有以下优点 优点 处理非欧几里得数据GNN 能够有效处理图结构数据而传统的神经网络主要处理的是欧几里得数据如图像和文本。 捕捉节点间关系GNN 可以利用节点之间的连接关系来捕捉复杂的结构信息从而更好地理解数据的上下文。 灵活性GNN 可以应用于各种类型的图包括无向图、有向图和加权图适应性强。 共享参数GNN 的参数可以在图的不同部分共享这减少了模型的复杂性和训练时间。 强大的特征学习能力GNN 能够自动学习节点的特征表示并通过聚合邻居节点的信息来更新自身的特征。 适用于多种任务GNN 可以用于节点分类、边预测、图分类等多种任务。 图神经网络通过以下方式扩展了特征 邻居信息聚合GNN 通过聚合邻居节点的特征将局部结构信息融入到节点的特征中从而生成更丰富的特征表示。 多层堆叠通过多层堆叠的网络结构GNN 能够逐步捕捉更高阶的邻接信息使得节点的特征不仅反映自身的信息还能反映其邻居的特征和关系。 动态更新节点特征在每一层中不断更新使得特征能够随着图的结构变化而变化从而增强了模型的表达能力。 图神经网络GNN与强化学习RL的结合形成了图强化学习Graph Reinforcement Learning这种结合具有多种优点 结构化数据处理GNN 能够有效处理图结构数据使得 RL 能够在复杂的环境中如社交网络、交通网络等做出更好的决策。 信息传递GNN 通过节点间的信息传递将邻居节点的状态和特征引入到决策过程中提高了智能体对环境的理解。 特征学习GNN 可以自动学习图中节点的特征表示帮助强化学习算法更好地估计状态值和动作值提升策略的性能。 上下文感知结合 GNN 的强化学习能够更好地捕捉环境的动态变化适应不同的上下文从而提高决策的灵活性和准确性。 这么说可能很多人还不是清楚优点具体是什么下面我用个试验的例子来说明 首先制作CS架构的粒子干扰避障的游戏这里就不细讲了直接上代码 服务端代码 # server.py from flask import Flask, request, jsonify import threading import randomapp Flask(__name__)# 存储主球的位置和粒子 clients {} particles [] particles_number 30# 初始化粒子 def generate_particles():while len(particles) particles_number: # 生成初始粒子particles.append({x: random.randint(0, 500),y: random.randint(0, 500),vx: random.choice([-6, -3, -1, 1, 3, 6]),vy: random.choice([-6, -3, -1, 1, 3, 6])})def update_particles():while True:for particle in particles:# 更新粒子位置particle[x] particle[vx]particle[y] particle[vy]# 碰撞边界处理if particle[x] 0 or particle[x] 500:particle[vx] * -1if particle[y] 0 or particle[y] 500:particle[vy] * -1threading.Event().wait(0.1)app.route(/register, methods[POST]) def register_client():client_id request.json.get(id)clients[client_id] {position: {x: 250, y: 250}} # 初始化主球位置return jsonify(successTrue)width, height 500, 500 ball_radius 15app.route(/move/client_id, methods[POST]) def move(client_id):direction request.json.get(direction)if client_id in clients:# 获取当前球的位置position clients[client_id][position]if direction up:new_y position[y] - 10if new_y 0: # 确保不超出上边界position[y] new_yelif direction down:new_y position[y] 10if new_y height - ball_radius: # 确保不超出下边界 (减去半径)position[y] new_yelif direction left:new_x position[x] - 10if new_x 0: # 确保不超出左边界position[x] new_xelif direction right:new_x position[x] 10if new_x width - ball_radius: # 确保不超出右边界 (减去半径)position[x] new_xreturn jsonify(clients[client_id][position])app.route(/position/client_id, methods[GET]) def get_position(client_id):if client_id in clients:return jsonify(clients[client_id][position])else:return jsonify({error: Client not found}), 404app.route(/particles, methods[GET]) def get_particles():return jsonify(particles)def run_server():app.run(host0.0.0.0, port5000, threadedTrue)if __name__ __main__:threading.Thread(targetgenerate_particles, daemonTrue).start()threading.Thread(targetupdate_particles, daemonTrue).start()run_server() 客户端代码 import pygame import requests# 初始化pygame pygame.init()# 设置窗口大小 width, height 500, 500 window pygame.display.set_mode((width, height)) pygame.display.set_caption(Particle Avoidance Game)# 颜色 BLACK (0, 0, 0) # 黑色 BLUE (0, 0, 255) RED (255, 0, 0)# 主球初始位置 ball_radius 15 client_id client1 # 确保每个客户端使用不同的 ID# 注册客户端 register_response requests.post(http://127.0.0.1:5000/register, json{id: client_id}) if register_response.status_code ! 200:print(Failed to register client:, register_response.text)def get_ball_position():response requests.get(fhttp://127.0.0.1:5000/position/{client_id})if response.status_code 200:ball_pos response.json()print(Ball position response:, ball_pos) # 打印响应if x in ball_pos and y in ball_pos: # 确保包含 x 和 yreturn ball_poselse:print(Ball position does not contain x and y:, ball_pos) # 额外调试信息else:print(Failed to get ball position:, response.text)return {x: 250, y: 250} # 默认值def move_ball(direction):requests.post(fhttp://127.0.0.1:5000/move/{client_id}, json{direction: direction})def get_particles():response requests.get(http://127.0.0.1:5000/particles)if response.status_code 200:particles response.json()print(Particles response:, particles) # 打印响应return particleselse:print(Failed to get particles:, response.text)return [] # 返回空列表def check_collision(ball_pos, particle_pos):distance ((ball_pos[x] - particle_pos[x]) ** 2 (ball_pos[y] - particle_pos[y]) ** 2) ** 0.5return distance (ball_radius 5) # 粒子的半径为5running True while running:for event in pygame.event.get():if event.type pygame.QUIT:running Falsekeys pygame.key.get_pressed()if keys[pygame.K_UP]:move_ball(up)if keys[pygame.K_DOWN]:move_ball(down)if keys[pygame.K_LEFT]:move_ball(left)if keys[pygame.K_RIGHT]:move_ball(right)# 更新球的位置ball_position get_ball_position()particles get_particles()# 检查碰撞if x in ball_position and y in ball_position: # 确保球有有效位置for particle in particles:if x in particle and y in particle: # 确保粒子有有效位置if check_collision(ball_position, particle):print(Game Over! You collided with a particle.)running Falseelse:print(Invalid ball position:, ball_position) # 调试信息# 渲染window.fill(BLACK) # 将背景填充为黑色if x in ball_position and y in ball_position: # 确保球有有效位置pygame.draw.circle(window, BLUE, (ball_position[x], ball_position[y]), ball_radius)else:print(Ball position is invalid, not drawing.) # 调试信息for particle in particles:if x in particle and y in particle: # 确保粒子有有效位置pygame.draw.circle(window, RED, (particle[x], particle[y]), 5) # 画粒子pygame.display.flip()pygame.time.delay(100)pygame.quit() 启动程序调试了下基本可以充当强化学习的环境 后续直接在客户端上添加强化模型的训练代码把server端代码部署到k8s 上做为训练环境 要设计一个基于图神经网络GNN和优势演员-评论家A2C算法的强化学习模型以训练 main_ball 在环境中左右移动我们可以遵循以下步骤 1. 确定问题 目标控制 main_ball 左右移动以避免与粒子碰撞。状态空间包含 main_ball 和粒子的状态信息包括位置和速度。动作空间定义 main_ball 的动作左右上下或保持不动。 2. 数据结构和环境设计 首先构建一个环境来模拟 main_ball 和粒子的动态行为。 class Environment:def __init__(self, main_ball, particles):self.main_ball main_ballself.particles particlesdef reset(self):# 重置环境返回初始状态return self.get_state()def get_state(self):# 获取当前状态state {main_ball: self.main_ball,particles: self.particles}return statedef step(self, action):# 根据动作更新环境状态按照上右下左顺序if not action_space[action]:move_ball(action_space[action])# 定义奖励和终止标志reward 0done Falsenew_main_ball get_ball_position()particles get_particles()COLLISION_THRESHOLD 20for particle in self.particles:distance ((self.main_ball[x] - particle[x]) ** 2 (self.main_ball[y] - particle[y]) ** 2) ** 0.5if distance COLLISION_THRESHOLD: # 定义一个阈值判断碰撞reward -1 # 碰撞时给予负奖励break# 检查是否重新开始if x in new_main_ball and y in new_main_ball: # 确保球有有效位置for particle in particles:if x in particle and y in particle: # 确保粒子有有效位置if check_collision(new_main_ball, particle):print(Game Over! You collided with a particle.)done True # 结束游戏breakif not done:reward 5return self.get_state(), reward, done, {} 3. GNN 构建 使用 DGL 构建图结构以表示 main_ball 和粒子之间的关系。 import dgl import torchdef create_graph(main_ball, particles):num_particles len(particles)G dgl.graph(([], []), num_nodesnum_particles 1)# 添加主球节点G.ndata[pos] torch.zeros(num_particles 1, 2)G.ndata[pos][0] torch.tensor([main_ball[x], main_ball[y]])# 添加粒子节点for i, particle in enumerate(particles):G.ndata[pos][i 1] torch.tensor([particle[x], particle[y]])# 添加边和距离权重edges []distances []for i in range(num_particles):for j in range(i 1, num_particles):edges.append((i 1, j 1)) # 粒子之间的边# 计算距离并存储distance torch.norm(G.ndata[pos][i 1] - G.ndata[pos][j 1])distances.append(distance.item())edges.append((0, i 1)) # 主球与粒子之间的边# 计算距离并存储distance torch.norm(G.ndata[pos][0] - G.ndata[pos][i 1])distances.append(distance.item())G.add_edges(*zip(*edges))# 将距离作为边的特征G.edata[distance] torch.tensor(distances)# 添加自环G dgl.add_self_loop(G)return G 4. A2C 模型设计 使用 PyTorch 设计 A2C 模型包括 Actor 和 Critic。 在强化学习中Actor演员和Critic评论家是两个关键的角色或组件。 Actor演员 Actor 是强化学习中的一个组件通常用于确定在给定状态下应该采取的动作。它负责根据当前状态选择动作并将其发送给环境。 Actor 的目标是学习一个策略即从状态到动作的映射函数以便在与环境的交互中获得高回报。策略可以是确定性的直接选择最优动作或概率性的选择动作的概率分布。Actor 的训练目标是最大化预期回报通常使用梯度上升方法如策略梯度法进行优化。 Critic评论家 Critic 是强化学习中的另一个组件用于评估 Actor 的动作选择。它通过对当前状态和采取的动作进行评估 提供一个值函数或者动作值函数来估计在给定策略下获得的长期回报。Critic 的目标是学习一个值函数 用于评估不同状态-动作对的价值并提供即时的反馈信号。Critic 的训练目标是最小化值函数的预测误差 通常使用时序差分学习如 Q-learning 或 TD-learning或函数逼近方法如神经网络进行优化。 在某些强化学习算法中Actor 和 Critic 可以是分离的组件各自独立进行训练。 Actor 使用 Critic 提供的价值信息来指导动作选择 而 Critic 使用 Actor 选择的动作进行评估和训练。它们通过交互和相互反馈来改善策略和值函数的性能。 class ActorCritic(nn.Module):def __init__(self, n_devices, action_space_dim):super(ActorCritic, self).__init__()self.conv1 dgl.nn.GraphConv(2, 128) # 输入特征为 2位置self.conv2 dgl.nn.GraphConv(128, (1 action_space_dim)) # 隐藏层self.commonCov nn.Linear((len(particles) 1 ) * (1 action_space_dim), 128)self.actor nn.Linear(128, action_space_dim) # 行动空间self.critic nn.Linear(128, 1)def forward(self, g):g dgl.add_self_loop(g)x g.ndata[pos]x self.conv1(g, x)x F.relu(x)x self.conv2(g, x)x F.relu(x)x x.reshape(-1)x torch.relu(self.commonCov(x))actor self.actor(x)critic self.critic(x)return actor, critic 5. 训练循环 设计训练循环使用强化学习算法更新模型的参数。 # 初始化环境和模型 main_ball get_ball_position() particles get_particles() env Environment(main_ball, particles) device_id 1 if torch.cuda.is_available():device_id torch.cuda.current_device() model ActorCritic(n_devicesdevice_id, action_space_dimlen(action_space)).to(device) # 检查模型文件是否存在 model_path actor_critic_model.pth if os.path.exists(model_path):# 加载保存的模型状态model.load_state_dict(torch.load(model_path))print(Model loaded successfully.)model.train() # 切换到训练模式 optimizer optim.Adam(model.parameters(), lr0.01)# 训练循环 for episode in range(num_episodes):state env.reset()done Falsewhile not done:# 创建图g create_graph(state[main_ball], state[particles])# 前向传播actor_logits, critic_value model(g)# 选择动作使用 softmaxaction_prob F.softmax(actor_logits, dim-1)# 使用 torch.multinomial 选择一个动作action torch.multinomial(action_prob, num_samples1).item()# 执行动作并获取下一个状态和奖励next_state, reward, done, _ env.step(action)# 这里需要存储轨迹并计算损失# 更新模型参数optimizer.zero_grad()_, next_value model(g)td_target reward 0.95 * next_valuedelta td_target - critic_value这是actor网络的损失函数。目标是最大化选择当前动作的对数概率乘以TD误差。乘以delta.detach()是为了使actor网络的更新不会影响critic网络的预测。actor_loss -action_prob[action] * delta.detach()critic_loss delta.pow(2)loss actor_loss critic_lossprint(epoch %d, device_id %d, epoch_loss %lf % (episode, device_id, loss.item()))loss.backward()optimizer.step()if done:restart_game()state next_state# 保存模型状态torch.save(model.state_dict(), model_path)print(fModel saved after episode {episode}.) 通过以上步骤你可以构建一个 GNN A2C 的强化学习模型来训练 main_ball 的左右移动策略。 启动训练 num_episodes 1000 强化学习在对接仿真有没有gpu加速速度并不明显因为受限于仿真state-action-next_state 这套流程的处理速度我使用在本机上跑该训练训练结束后保存模型文件到actor_critic_model.pth 6. 模型推理 使用 actor_critic_model.pth接入到之前的客户端代码上用模型决策来替换键盘操作 import os import pygame import requests import random import string import torch import torch.nn.functional as F from g1_client_train import ActorCritic, create_graph, action_space# 初始化pygame pygame.init()# 设置窗口大小 width, height 500, 500 window pygame.display.set_mode((width, height)) pygame.display.set_caption(Particle Avoidance Game)# 颜色 BLACK (0, 0, 0) # 黑色 BLUE (0, 0, 255) RED (255, 0, 0)# 主球初始位置 ball_radius 15 if torch.cuda.is_available():# 使用 CUDA 设备device torch.device(cuda) else:# 使用 CPU 设备device torch.device(cpu) def generate_random_string(length5):# 定义可用字符包括数字和大小写字母characters string.ascii_letters string.digits# 随机选择字符并生成字符串random_string .join(random.choice(characters) for _ in range(length))return random_string# 生成并打印随机字符串 random_string generate_random_string() print(random_string) client_idrandom_stringhost os.getenv(SERVER_HOST, http://192.168.110.126:31007)def register_client():register_response requests.post(f{host}/register, json{id: client_id})if register_response.status_code ! 200:print(Failed to register client:, register_response.text)return Falsereturn Truedef get_ball_position():response requests.get(f{host}/position/{client_id})if response.status_code 200:ball_pos response.json()print(Ball position response:, ball_pos) # 打印响应if x in ball_pos and y in ball_pos: # 确保包含 x 和 yreturn ball_posprint(Failed to get ball position:, response.text)return None # 返回 None 表示失败def move_ball(direction):requests.post(f{host}/move/{client_id}, json{direction: direction})def get_particles():response requests.get(f{host}/particles)if response.status_code 200:particles response.json()print(Particles response:, particles) # 打印响应return particlesprint(Failed to get particles:, response.text)return [] # 返回空列表def check_collision(ball_pos, particle_pos):distance ((ball_pos[x] - particle_pos[x]) ** 2 (ball_pos[y] - particle_pos[y]) ** 2) ** 0.5return distance (ball_radius 5) # 粒子的半径为5def restart_game():global runningprint(Restarting game...)running Trueregister_client() # 重新注册客户端if __name__ __main__:# 注册客户端if not register_client():exit()device_id 1if torch.cuda.is_available():device_id torch.cuda.current_device()# 更新球的位置main_ball get_ball_position()particles get_particles()model ActorCritic(n_devicesdevice_id,particles_dimlen(particles),action_space_dimlen(action_space)).to(device)model_path actor_critic_model.pthmodel.load_state_dict(torch.load(model_path))running Truewhile running:for event in pygame.event.get():if event.type pygame.QUIT:running Falsemain_ball get_ball_position()particles get_particles()g create_graph(main_ballmain_ball, particlesparticles)actor_logits, _ model(g)action_prob F.softmax(actor_logits, dim-1)action torch.multinomial(action_prob, num_samples1).item()move_ball(action_space[action])# 检查碰撞if main_ball: # 确保球有有效位置for particle in particles:if x in particle and y in particle: # 确保粒子有有效位置if check_collision(main_ball, particle):print(Game Over! You collided with a particle.)restart_game()breakelse:print(Invalid ball position, restarting game.)restart_game()# 渲染window.fill(BLACK) # 将背景填充为黑色if main_ball: # 确保球有有效位置pygame.draw.circle(window, BLUE, (main_ball[x], main_ball[y]), ball_radius)for particle in particles:if x in particle and y in particle: # 确保粒子有有效位置pygame.draw.circle(window, RED, (particle[x], particle[y]), 5) # 画粒子pygame.display.flip()pygame.time.delay(100)pygame.quit() 录个动图看看效果 有必要再优化下奖励函数每一步避险操作后的next_state应该是优于之前state 训练的好的模型奖惩函数规则都挺细的 附上完整的训练代码文件 import osimport requests import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim import dgl import dgl.nn as dglnn import os import random import stringfrom g1_server import speed_choicesball_radius 15 width, height 500, 500 num_episodes 1000min_vx min(speed_choices) max_vx max(speed_choices) min_vy min(speed_choices) max_vy max(speed_choices)def register_client():register_response requests.post(f{host}/register, json{id: client_id})if register_response.status_code ! 200:print(Failed to register client:, register_response.text)return Falsereturn Truedef restart_game():global runningprint(Restarting game...)running Trueregister_client() # 重新注册客户端def generate_random_string(length5):# 定义可用字符包括数字和大小写字母characters string.ascii_letters string.digits# 随机选择字符并生成字符串random_string .join(random.choice(characters) for _ in range(length))return random_string# 生成并打印随机字符串 random_string generate_random_string() print(random_string) client_id random_stringhost os.getenv(SERVER_HOST, http://192.168.110.126:31007) # 注册客户端 register_response requests.post(f{host}/register, json{id: client_id}) if register_response.status_code ! 200:print(Failed to register client:, register_response.text)if torch.cuda.is_available():# 使用 CUDA 设备device torch.device(cuda) else:# 使用 CPU 设备device torch.device(cpu)def get_ball_position():response requests.get(f{host}/position/{client_id})if response.status_code 200:ball_pos response.json()print(Ball position response:, ball_pos) # 打印响应if x in ball_pos and y in ball_pos: # 确保包含 x 和 yreturn ball_poselse:print(Ball position does not contain x and y:, ball_pos) # 额外调试信息else:print(Failed to get ball position:, response.text)return {x: 250, y: 250} # 默认值def move_ball(direction):requests.post(f{host}/move/{client_id}, json{direction: direction})def get_particles():response requests.get(f{host}/particles)if response.status_code 200:particles response.json()print(Particles response:, particles) # 打印响应return particleselse:print(Failed to get particles:, response.text)return [] # 返回空列表def check_collision(ball_pos, particle_pos):distance ((ball_pos[x] - particle_pos[x]) ** 2 (ball_pos[y] - particle_pos[y]) ** 2) ** 0.5return distance (ball_radius 5) # 粒子的半径为5def create_graph(main_ball, particles):num_particles len(particles)G dgl.graph(([], []), num_nodesnum_particles 1)# 添加主球节点G.ndata[pos] torch.zeros(num_particles 1, 2)G.ndata[pos][0] torch.tensor([main_ball[x], main_ball[y]])# 添加粒子节点for i, particle in enumerate(particles):G.ndata[pos][i 1] torch.tensor([particle[x], particle[y]])# 添加边和距离权重edges []distances []for i in range(num_particles):for j in range(i 1, num_particles):edges.append((i 1, j 1)) # 粒子之间的边# 计算距离并存储distance torch.norm(G.ndata[pos][i 1] - G.ndata[pos][j 1])distances.append(distance.item())edges.append((0, i 1)) # 主球与粒子之间的边# 计算距离并存储distance torch.norm(G.ndata[pos][0] - G.ndata[pos][i 1])distances.append(distance.item())G.add_edges(*zip(*edges))# 将距离作为边的特征G.edata[distance] torch.tensor(distances)# 添加自环G dgl.add_self_loop(G)return Gdef compute_state(main_ball, particles):flow []normalized_main_ball_x main_ball[x] / widthnormalized_main_ball_y main_ball[y] / heightflow.append([0, 0, normalized_main_ball_x, normalized_main_ball_y])for particle in particles:# 归一化粒子的速度和位置normalized_vx (particle[vx] - min_vx) / (max_vx - min_vx) # 根据你的数据范围进行归一化normalized_vy (particle[vy] - min_vy) / (max_vy - min_vy)normalized_x particle[x] / widthnormalized_y particle[y] / heightflow.append([normalized_vx, normalized_vy, normalized_x, normalized_y])return torch.tensor(flow).float() 在强化学习中Actor演员和Critic评论家是两个关键的角色或组件。Actor演员 Actor 是强化学习中的一个组件通常用于确定在给定状态下应该采取的动作。它负责根据当前状态选择动作并将其发送给环境。 Actor 的目标是学习一个策略即从状态到动作的映射函数以便在与环境的交互中获得高回报。策略可以是确定性的直接选择最优动作或概率性的选择动作的概率分布。Actor 的训练目标是最大化预期回报通常使用梯度上升方法如策略梯度法进行优化。Critic评论家 Critic 是强化学习中的另一个组件用于评估 Actor 的动作选择。它通过对当前状态和采取的动作进行评估 提供一个值函数或者动作值函数来估计在给定策略下获得的长期回报。Critic 的目标是学习一个值函数 用于评估不同状态-动作对的价值并提供即时的反馈信号。Critic 的训练目标是最小化值函数的预测误差 通常使用时序差分学习如 Q-learning 或 TD-learning或函数逼近方法如神经网络进行优化。在某些强化学习算法中Actor 和 Critic 可以是分离的组件各自独立进行训练。 Actor 使用 Critic 提供的价值信息来指导动作选择 而 Critic 使用 Actor 选择的动作进行评估和训练。它们通过交互和相互反馈来改善策略和值函数的性能。class Environment:def __init__(self, main_ball, particles):self.main_ball main_ballself.particles particlesdef reset(self):# 重置环境返回初始状态return self.get_state()def get_state(self):# 获取当前状态state {main_ball: self.main_ball,particles: self.particles}return statedef step(self, action):# 根据动作更新环境状态按照上右下左顺序if not action_space[action]:move_ball(action_space[action])# 定义奖励和终止标志reward 0done Falsenew_main_ball get_ball_position()particles get_particles()COLLISION_THRESHOLD 20for particle in self.particles:distance ((self.main_ball[x] - particle[x]) ** 2 (self.main_ball[y] - particle[y]) ** 2) ** 0.5if distance COLLISION_THRESHOLD: # 定义一个阈值判断碰撞reward -1 # 碰撞时给予负奖励break# 检查是否重新开始if x in new_main_ball and y in new_main_ball: # 确保球有有效位置for particle in particles:if x in particle and y in particle: # 确保粒子有有效位置if check_collision(new_main_ball, particle):print(Game Over! You collided with a particle.)done True # 结束游戏breakif not done:reward 5return self.get_state(), reward, done, {}class ActorCritic(nn.Module):def __init__(self, n_devices, action_space_dim):super(ActorCritic, self).__init__()self.conv1 dgl.nn.GraphConv(2, 128) # 输入特征为 2位置self.conv2 dgl.nn.GraphConv(128, (1 action_space_dim)) # 隐藏层self.commonCov nn.Linear((len(particles) 1 ) * (1 action_space_dim), 128)self.actor nn.Linear(128, action_space_dim) # 行动空间self.critic nn.Linear(128, 1)def forward(self, g):g dgl.add_self_loop(g)x g.ndata[pos]x self.conv1(g, x)x F.relu(x)x self.conv2(g, x)x F.relu(x)x x.reshape(-1)x torch.relu(self.commonCov(x))actor self.actor(x)critic self.critic(x)return actor, criticaction_space [up, down, left, right, None]if not register_client():print(not register)exit()# 初始化环境和模型 main_ball get_ball_position() particles get_particles() env Environment(main_ball, particles) device_id 1 if torch.cuda.is_available():device_id torch.cuda.current_device() model ActorCritic(n_devicesdevice_id, action_space_dimlen(action_space)).to(device) # 检查模型文件是否存在 model_path actor_critic_model.pth if os.path.exists(model_path):# 加载保存的模型状态model.load_state_dict(torch.load(model_path))print(Model loaded successfully.)model.train() # 切换到训练模式 optimizer optim.Adam(model.parameters(), lr0.01)# 训练循环 for episode in range(num_episodes):state env.reset()done Falsewhile not done:# 创建图g create_graph(state[main_ball], state[particles])# 前向传播actor_logits, critic_value model(g)# 选择动作使用 softmaxaction_prob F.softmax(actor_logits, dim-1)# 使用 torch.multinomial 选择一个动作action torch.multinomial(action_prob, num_samples1).item()# 执行动作并获取下一个状态和奖励next_state, reward, done, _ env.step(action)# 这里需要存储轨迹并计算损失# 更新模型参数optimizer.zero_grad()_, next_value model(g)td_target reward 0.95 * next_valuedelta td_target - critic_value这是actor网络的损失函数。目标是最大化选择当前动作的对数概率乘以TD误差。乘以delta.detach()是为了使actor网络的更新不会影响critic网络的预测。actor_loss -action_prob[action] * delta.detach()critic_loss delta.pow(2)loss actor_loss critic_lossprint(epoch %d, device_id %d, epoch_loss %lf % (episode, device_id, loss.item()))loss.backward()optimizer.step()if done:restart_game()state next_state# 保存模型状态torch.save(model.state_dict(), model_path)print(fModel saved after episode {episode}.)
http://www.zqtcl.cn/news/563607/

相关文章:

  • 《网站开发实践》 实训报告广告策划书案例完整版
  • 一级 爰做片免费网站做中学学中做网站
  • 网站排名如何提升网络营销的有哪些特点
  • 巨腾外贸网站建设个人主页网站模板免费
  • 有哪些网站免费做推广淄博网站电子商城平台建设
  • 网站建设的技术支持论文做网站买什么品牌笔记本好
  • 凡科网站后台在哪里.工程与建设
  • 静态网站源文件下载建设手机网站价格
  • 苏州做网站优化的网站开发邮件
  • 做网站怎么搭建环境阿里云大学 网站建设
  • 网站改版业务嵌入式培训推荐
  • 腾讯云 怎样建设网站网站开发 报价
  • 网络科技公司门户网站免费人脉推广官方软件
  • 建和做网站网络营销推广可以理解为
  • 太原市网站建设网站人防工程做资料的网站
  • 怎么做免费推广网站做网站第一部
  • 橙色网站后台模板WordPress的SEO插件安装失败
  • 做网站好还是做微信小程序好外包加工网外放加工活
  • 中国建设银行网站查征信电子商务网站建设及推广
  • 扫描网站漏洞的软件php网站后台验证码不显示
  • 诸城哪里有做网站的做网站的尺寸
  • 网站开发参考书目做网站推广赚钱吗
  • 九度网站建设网站做ppt模板
  • 浙江做公司网站多少钱评论回复网站怎么做
  • 江门网络建站模板虚拟主机价格一般多少钱
  • 网站建设公司云南深圳手机商城网站设计费用
  • 汇泽网站建设网页版快手
  • 手机销售培训网站wordpress案例插件
  • 滨江道做网站公司wordpress 花瓣网
  • 如何建网站快捷方式软件开发做平台