DQN Plays 2048 | Part 2: Designing an ε-Greedy Policy Network and Running a Simple Training

Source: https://blog.csdn.net/weixin_38428827/article/details/146302928

Video link:

DQN Plays 2048 in Practice | Part 2: Design an ε-Greedy Policy Network and Run a Quick Training!

Code repository: LitchiCheng/DRL-learning: Deep Reinforcement Learning

Concepts:

In DQN (Deep Q-Network), the "Q" stands for "Quality", and the corresponding full term is the state-action value function, written Q(s, a).

Definition: Q(s, a) is the expected cumulative future reward the agent receives after taking action a in state s (i.e., the "quality" of the long-term return).
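Written out under the standard discounted-return formulation (a sketch; γ is the discount factor, which the code below sets to gamma = 0.99):

Q(s, a) = \mathbb{E}\left[\sum_{k=0}^{\infty} \gamma^{k} r_{t+k+1} \,\middle|\, s_t = s,\; a_t = a\right]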

Role:

The Q value is the core basis for decision-making in reinforcement learning: the agent compares the Q values of all possible actions in the current state and picks the action with the largest Q value (the "optimal action") so as to maximize cumulative reward.
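In symbols, this greedy decision rule is simply:

\pi(s) = \arg\max_{a} Q(s, a)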

The design has three parts:

  1. Deep Q-network definition: define a neural network in PyTorch to approximate the Q-value function.
  2. Experience replay: implement a replay buffer that stores the agent's transitions and samples them at random for training.
  3. Epsilon-greedy policy: a classic strategy for balancing exploration and exploitation, it tackles the problem of the agent relying only on the currently known best action and thereby missing potentially better strategies (a minimal sketch follows this list).
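As a minimal sketch of the ε-greedy selection (using the epsilon, model, state, and output_size names from the full code below, where the same logic appears inside the training loop):

if random.random() < epsilon:
    # Explore: take a uniformly random action
    action = random.randint(0, output_size - 1)
else:
    # Exploit: take the action with the highest predicted Q value
    q_values = model(state)
    action = torch.argmax(q_values, dim=1).item()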

Here is the code:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import random
from collections import deque
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.table import Table


# 2048 game environment
class Game2048:
    def __init__(self):
        self.board = np.zeros((4, 4), dtype=int)
        self.add_random_tile()
        self.add_random_tile()

    def add_random_tile(self):
        # Place a 2 (90%) or a 4 (10%) on a random empty cell
        empty_cells = np.argwhere(self.board == 0)
        if len(empty_cells) > 0:
            index = random.choice(empty_cells)
            self.board[index[0], index[1]] = 2 if random.random() < 0.9 else 4

    def move_left(self):
        reward = 0
        new_board = np.copy(self.board)
        for row in range(4):
            line = new_board[row]
            non_zero = line[line != 0]
            merged = []
            i = 0
            while i < len(non_zero):
                if i + 1 < len(non_zero) and non_zero[i] == non_zero[i + 1]:
                    # Merge two equal tiles; the merged value is the reward
                    merged.append(2 * non_zero[i])
                    reward += 2 * non_zero[i]
                    i += 2
                else:
                    merged.append(non_zero[i])
                    i += 1
            new_board[row] = np.pad(merged, (0, 4 - len(merged)), 'constant')
        if not np.array_equal(new_board, self.board):
            self.board = new_board
            self.add_random_tile()
        return reward

    def move_right(self):
        self.board = np.fliplr(self.board)
        reward = self.move_left()
        self.board = np.fliplr(self.board)
        return reward

    def move_up(self):
        self.board = self.board.T
        reward = self.move_left()
        self.board = self.board.T
        return reward

    def move_down(self):
        self.board = self.board.T
        reward = self.move_right()
        self.board = self.board.T
        return reward

    def step(self, action):
        if action == 0:
            reward = self.move_left()
        elif action == 1:
            reward = self.move_right()
        elif action == 2:
            reward = self.move_up()
        elif action == 3:
            reward = self.move_down()
        # The game is over when there are no empty cells and no adjacent equal tiles
        done = not np.any(self.board == 0) and all(
            [np.all(self.board[:, i] != self.board[:, i + 1]) for i in range(3)]
        ) and all(
            [np.all(self.board[i, :] != self.board[i + 1, :]) for i in range(3)]
        )
        state = self.board.flatten()
        return state, reward, done

    def reset(self):
        self.board = np.zeros((4, 4), dtype=int)
        self.add_random_tile()
        self.add_random_tile()
        return self.board.flatten()


# Deep Q-network
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, output_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)


# Experience replay buffer
class ReplayBuffer:
    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)
        return np.array(states), np.array(actions), np.array(rewards), np.array(next_states), np.array(dones)

    def __len__(self):
        return len(self.buffer)


# Visualization helper
def visualize_board(board, ax):
    ax.clear()
    table = Table(ax, bbox=[0, 0, 1, 1])
    nrows, ncols = board.shape
    width, height = 1.0 / ncols, 1.0 / nrows
    # Color map for tile values
    cmap = mcolors.LinearSegmentedColormap.from_list("", ["white", "yellow", "orange", "red"])
    for (i, j), val in np.ndenumerate(board):
        color = cmap(np.log2(val + 1) / np.log2(2048 + 1)) if val > 0 else "white"
        table.add_cell(i, j, width, height, text=val if val > 0 else "",
                       loc='center', facecolor=color)
    ax.add_table(table)
    ax.set_axis_off()
    plt.draw()
    plt.pause(0.1)


# Training loop
def train():
    env = Game2048()
    input_size = 16
    output_size = 4
    model = DQN(input_size, output_size)
    target_model = DQN(input_size, output_size)
    target_model.load_state_dict(model.state_dict())
    target_model.eval()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.MSELoss()
    replay_buffer = ReplayBuffer(capacity=10000)
    batch_size = 32
    gamma = 0.99
    epsilon = 1.0
    epsilon_decay = 0.995
    epsilon_min = 0.01
    update_target_freq = 10
    num_episodes = 1000
    fig, ax = plt.subplots()
    for episode in range(num_episodes):
        state = env.reset()
        state = torch.FloatTensor(state).unsqueeze(0)
        done = False
        total_reward = 0
        while not done:
            visualize_board(env.board, ax)
            # Epsilon-greedy action selection
            if random.random() < epsilon:
                action = random.randint(0, output_size - 1)
            else:
                q_values = model(state)
                action = torch.argmax(q_values, dim=1).item()
            next_state, reward, done = env.step(action)
            next_state = torch.FloatTensor(next_state).unsqueeze(0)
            replay_buffer.add(state.squeeze(0).numpy(), action, reward, next_state.squeeze(0).numpy(), done)
            if len(replay_buffer) >= batch_size:
                states, actions, rewards, next_states, dones = replay_buffer.sample(batch_size)
                states = torch.FloatTensor(states)
                actions = torch.LongTensor(actions)
                rewards = torch.FloatTensor(rewards)
                next_states = torch.FloatTensor(next_states)
                dones = torch.FloatTensor(dones)
                q_values = model(states)
                # Q values of the actions actually taken in each state
                q_values = q_values.gather(1, actions.unsqueeze(1)).squeeze(1)
                next_q_values = target_model(next_states)
                # Largest Q value in the next state
                next_q_values = next_q_values.max(1)[0]
                # Target Q values
                target_q_values = rewards + gamma * (1 - dones) * next_q_values
                loss = criterion(q_values, target_q_values)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            state = next_state
            total_reward += reward
        if episode % update_target_freq == 0:
            target_model.load_state_dict(model.state_dict())
        epsilon = max(epsilon * epsilon_decay, epsilon_min)
        print(f"Episode {episode}: Total Reward = {total_reward}, Epsilon = {epsilon}")
    plt.close()


if __name__ == "__main__":
    train()
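For reference, the update inside the training loop minimizes the squared TD error; in the notation of the code (gamma, done, and the target network), the target and loss are:

y = r + \gamma\,(1 - \text{done})\,\max_{a'} Q_{\text{target}}(s', a'), \qquad \text{loss} = \big(Q(s, a) - y\big)^{2}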

Run it and a matplotlib window will visualize the 2048 moves as training proceeds, while the console prints the current episode number and other training information.
