State-Space Modeling for Reinforcement Learning: An Introduction to RSSM with a PyTorch Implementation

Recurrent State Space Models (RSSM) were first introduced by Danijar Hafner et al. in the paper "Learning Latent Dynamics for Planning from Pixels". The model plays a key role in modern model-based reinforcement learning (MBRL): its main goal is to build a reliable predictive model of the environment dynamics. With such a learned model, an agent can simulate future trajectories and plan its behavior ahead of time.

Below, we walk through RSSM with a practical, hands-on example.

Environment Setup

Setting up the environment is the first step of the implementation. We use the familiar Gym API and design a few modular wrappers that initialize parameters and bring observations into the required format.

InitialWrapper allows a fixed number of steps to be taken without executing any meaningful action, and also supports repeating the same action several times before returning an observation. This is particularly useful for environments whose responses have significant latency.

The PreprocessFrame wrapper converts observations into the right format (the raw frames arrive as numpy arrays and are returned as normalized torch tensors) and optionally converts them to grayscale.

from typing import Sequence, Tuple

import cv2 as cv
import gym
import numpy as np
import torch
from gym.core import ActType, ObsType


class InitialWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, no_ops: int = 0, repeat: int = 1):
        super(InitialWrapper, self).__init__(env)
        self.repeat = repeat
        self.no_ops = no_ops

        self.op_counter = 0

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
        # Burn through the configured number of no-op steps first
        if self.op_counter < self.no_ops:
            obs, reward, terminated, truncated, info = self.env.step(0)
            self.op_counter += 1

        # Repeat the chosen action and accumulate the reward
        total_reward = 0.0
        terminated = truncated = False
        for _ in range(self.repeat):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break

        return obs, total_reward, terminated, truncated, info
 
 
class PreprocessFrame(gym.ObservationWrapper):
    def __init__(self, env: gym.Env, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool = False):
        super(PreprocessFrame, self).__init__(env)
        self.shape = new_shape
        self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=self.shape, dtype=np.float32)
        self.grayscale = grayscale

        if self.grayscale:
            self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(*self.shape[:-1], 1), dtype=np.float32)

    def observation(self, obs: np.ndarray) -> torch.Tensor:
        # Resize the raw frame and optionally convert it to grayscale
        obs = obs.astype(np.uint8)
        new_frame = cv.resize(obs, self.shape[:-1], interpolation=cv.INTER_AREA)
        if self.grayscale:
            new_frame = cv.cvtColor(new_frame, cv.COLOR_RGB2GRAY)
            new_frame = np.expand_dims(new_frame, -1)

        # Return a float tensor normalized to [0, 1]
        torch_frame = torch.from_numpy(new_frame).float()
        torch_frame = torch_frame / 255.0

        return torch_frame


def make_env(env_name: str, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool = True, **kwargs):
    env = gym.make(env_name, **kwargs)
    env = PreprocessFrame(env, new_shape, grayscale=grayscale)
    return env

The make_env function creates an environment instance with the given configuration.
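A quick usage sketch (it assumes a Gym/Gymnasium version whose `reset` returns an `(observation, info)` tuple, and uses CarRacing-v2 purely as an example):

```python
env = make_env("CarRacing-v2", new_shape=(128, 128, 3), grayscale=True,
               render_mode="rgb_array", continuous=False)
obs, info = env.reset()
print(type(obs), obs.shape)   # <class 'torch.Tensor'> torch.Size([128, 128, 1])
```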

Model Architecture

The RSSM implementation relies on several key model components. Specifically, we need the following four core modules:

  • Encoder for raw observations
  • Dynamics model: models the temporal dependencies of the encoded observations via a deterministic state h and a stochastic state s
  • Decoder: maps the stochastic and deterministic states back to the original observation space
  • Reward model: maps the stochastic and deterministic states to a reward value

Figure: structure of the RSSM model components. The model contains the stochastic state s and the deterministic state h.

Encoder Implementation

The encoder is a simple convolutional neural network (CNN) that reduces the input image to a one-dimensional embedding. BatchNorm is used to improve training stability.

class EncoderCNN(nn.Module):  
     def __init__(self, in_channels: int, embedding_dim: int = 2048, input_shape: Tuple[int, int] = (128, 128)):  
         super(EncoderCNN, self).__init__()  
         # Convolutional feature extractor
         self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)  
         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  
         self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)  
         self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)  
   
         self.fc1 = nn.Linear(self._compute_conv_output((in_channels, input_shape[0], input_shape[1])), embedding_dim)  
   
         # Batch normalization layers
         self.bn1 = nn.BatchNorm2d(32)  
         self.bn2 = nn.BatchNorm2d(64)  
         self.bn3 = nn.BatchNorm2d(128)  
         self.bn4 = nn.BatchNorm2d(256)  
   
     def _compute_conv_output(self, shape: Tuple[int, int, int]):  
         with torch.no_grad():  
             x = torch.randn(1, shape[0], shape[1], shape[2])  
             x = self.conv1(x)  
             x = self.conv2(x)  
             x = self.conv3(x)  
             x = self.conv4(x)  
   
             return x.shape[1] * x.shape[2] * x.shape[3]  
 
     def forward(self, x):  
         x = torch.relu(self.conv1(x))  
         x = self.bn1(x)  
         x = torch.relu(self.conv2(x))  
         x = self.bn2(x)  
   
         x = torch.relu(self.conv3(x))  
         x = self.bn3(x)  
   
         x = self.conv4(x)  
         x = self.bn4(x)  
   
         x = x.view(x.size(0), -1)  
         x = self.fc1(x)  
   
         return x

Decoder Implementation

The decoder follows a conventional autoencoder-style design: it maps the latent states back into the original observation space.

class DecoderCNN(nn.Module):  
     def __init__(self, hidden_size: int, state_size: int,  embedding_size: int,  
                  use_bn: bool = True, output_shape: Tuple[int, int] = (3, 128, 128)):  
         super(DecoderCNN, self).__init__()  
   
         self.output_shape = output_shape  
   
         self.embedding_size = embedding_size  
         # Fully connected layers for feature expansion
         self.fc1 = nn.Linear(hidden_size + state_size, embedding_size)  
         self.fc2 = nn.Linear(embedding_size, 256 * (output_shape[1] // 16) * (output_shape[2] // 16))  
   
         # Transposed convolutions for upsampling
         self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  
         self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  
         self.conv3 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  
         self.conv4 = nn.ConvTranspose2d(32, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1)  
   
         # Batch normalization layers
         self.bn1 = nn.BatchNorm2d(128)  
         self.bn2 = nn.BatchNorm2d(64)  
         self.bn3 = nn.BatchNorm2d(32)  
   
         self.use_bn = use_bn  
 
     def forward(self, h: torch.Tensor, s: torch.Tensor):  
         x = torch.cat([h, s], dim=-1)  
         x = self.fc1(x)  
         x = torch.relu(x)  
         x = self.fc2(x)  
   
         x = x.view(-1, 256, self.output_shape[1] // 16, self.output_shape[2] // 16)  
   
         if self.use_bn:  
             x = torch.relu(self.bn1(self.conv1(x)))  
             x = torch.relu(self.bn2(self.conv2(x)))  
             x = torch.relu(self.bn3(self.conv3(x)))  
   
         else:  
             x = torch.relu(self.conv1(x))  
             x = torch.relu(self.conv2(x))  
             x = torch.relu(self.conv3(x))  
   
         x = self.conv4(x)  
   
         return x

Reward Model Implementation

The reward model is a three-layer feed-forward network that maps the stochastic state s and the deterministic state h to the parameters of a Normal distribution, from which a reward prediction is sampled.

class RewardModel(nn.Module):  
     def __init__(self, hidden_dim: int, state_dim: int):  
         super(RewardModel, self).__init__()  
   
         self.fc1 = nn.Linear(hidden_dim + state_dim, hidden_dim)  
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)  
         self.fc3 = nn.Linear(hidden_dim, 2)  
   
     def forward(self, h: torch.Tensor, s: torch.Tensor):  
         x = torch.cat([h, s], dim=-1)  
         x = torch.relu(self.fc1(x))  
         x = torch.relu(self.fc2(x))  
         x = self.fc3(x)  
   
         return x

Dynamics Model Implementation

The dynamics model is the most complex component of the RSSM architecture, because it has to handle both the prior and the posterior transition models:

  1. Posterior transition model: used when the true observation is available (mainly during training); it approximates the posterior distribution over the stochastic state given the observation and the past state.
  2. Prior transition model: approximates the prior state distribution from the previous state alone, without access to the observation. It is used at inference time, when no observation is available for the posterior.

Both models are parameterized by single-layer feed-forward networks that output the mean and log-variance of their respective Normal distributions, from which the stochastic state s is sampled. The networks are deliberately simple here, but can be extended to more sophisticated architectures if needed.
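In the notation of the PlaNet paper, the latent transition consists of three parts:

$$h_t = f(h_{t-1}, s_{t-1}, a_{t-1}), \qquad s_t \sim p(s_t \mid h_t) \ \text{(prior)}, \qquad s_t \sim q(s_t \mid h_t, o_t) \ \text{(posterior)}$$

where $f$ is the GRU described next, and both the prior and the posterior are diagonal Gaussians whose mean and log-variance come from the corresponding linear heads (in the implementation below, the prior head additionally conditions on the current action).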

The deterministic state is implemented with a gated recurrent unit (GRU). Its inputs are:

  • the previous hidden state
  • the one-hot encoded action
  • the previous stochastic state s (posterior or prior, depending on whether an observation is available)

These inputs give the model enough information about the action history and the system state. The implementation follows:

class DynamicsModel(nn.Module):
    def __init__(self, hidden_dim: int, action_dim: int, state_dim: int, embedding_dim: int, rnn_layer: int = 1):
        super(DynamicsModel, self).__init__()

        self.hidden_dim = hidden_dim
        self.state_dim = state_dim
        self.embedding_dim = embedding_dim

        # Recurrent core; supports stacking multiple GRU cells
        self.rnn = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(rnn_layer)])

        # Projection of the previous stochastic state and action
        self.project_state_action = nn.Linear(action_dim + state_dim, hidden_dim)

        # Prior network: outputs the parameters of a Normal distribution
        self.prior = nn.Linear(hidden_dim, state_dim * 2)
        self.project_hidden_action = nn.Linear(hidden_dim + action_dim, hidden_dim)

        # Posterior network: outputs the parameters of a Normal distribution
        self.posterior = nn.Linear(hidden_dim, state_dim * 2)
        self.project_hidden_obs = nn.Linear(hidden_dim + embedding_dim, hidden_dim)

        self.act_fn = nn.ReLU()
   
    def forward(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor, actions: torch.Tensor,
                obs: torch.Tensor = None, dones: torch.Tensor = None):
        """
        Forward pass of the dynamics model.

        Args:
            prev_hidden: previous RNN hidden state, shape (batch_size, hidden_dim)
            prev_state: previous stochastic state, shape (batch_size, state_dim)
            actions: one-hot encoded action sequence, shape (batch_size, sequence_length, action_dim)
            obs: observation embeddings from the encoder, shape (batch_size, sequence_length, embedding_dim)
            dones: episode-termination flags
        """
        B, T, _ = actions.size()  # batch size and sequence length

        # Lists that accumulate the per-step results
        hiddens_list = []
        posterior_means_list = []
        posterior_logvars_list = []
        prior_means_list = []
        prior_logvars_list = []
        prior_states_list = []
        posterior_states_list = []

        # Store the initial states
        hiddens_list.append(prev_hidden.unsqueeze(1))
        prior_states_list.append(prev_state.unsqueeze(1))
        posterior_states_list.append(prev_state.unsqueeze(1))

        # Unroll over the sequence
        for t in range(T - 1):
            # Current action, observation embedding and state
            action_t = actions[:, t, :]
            obs_t = obs[:, t, :] if obs is not None else torch.zeros(B, self.embedding_dim, device=actions.device)
            state_t = posterior_states_list[-1][:, 0, :] if obs is not None else prior_states_list[-1][:, 0, :]
            state_t = state_t if dones is None else state_t * (1 - dones[:, t, :])
            hidden_t = hiddens_list[-1][:, 0, :]

            # Combine state and action
            state_action = torch.cat([state_t, action_t], dim=-1)
            state_action = self.act_fn(self.project_state_action(state_action))

            # Deterministic (RNN) state update
            for i in range(len(self.rnn)):
                hidden_t = self.rnn[i](state_action, hidden_t)

            # Prior distribution
            hidden_action = torch.cat([hidden_t, action_t], dim=-1)
            hidden_action = self.act_fn(self.project_hidden_action(hidden_action))
            prior_params = self.prior(hidden_action)
            prior_mean, prior_logvar = torch.chunk(prior_params, 2, dim=-1)

            # Sample from the prior
            prior_dist = torch.distributions.Normal(prior_mean, torch.exp(F.softplus(prior_logvar)))
            prior_state_t = prior_dist.rsample()

            # Posterior distribution (falls back to the prior when no observation is given)
            if obs is None:
                posterior_mean = prior_mean
                posterior_logvar = prior_logvar
            else:
                hidden_obs = torch.cat([hidden_t, obs_t], dim=-1)
                hidden_obs = self.act_fn(self.project_hidden_obs(hidden_obs))
                posterior_params = self.posterior(hidden_obs)
                posterior_mean, posterior_logvar = torch.chunk(posterior_params, 2, dim=-1)

            # Sample from the posterior
            posterior_dist = torch.distributions.Normal(posterior_mean, torch.exp(F.softplus(posterior_logvar)))
            posterior_state_t = posterior_dist.rsample()

            # Store the results of this step
            posterior_means_list.append(posterior_mean.unsqueeze(1))
            posterior_logvars_list.append(posterior_logvar.unsqueeze(1))
            prior_means_list.append(prior_mean.unsqueeze(1))
            prior_logvars_list.append(prior_logvar.unsqueeze(1))
            prior_states_list.append(prior_state_t.unsqueeze(1))
            posterior_states_list.append(posterior_state_t.unsqueeze(1))
            hiddens_list.append(hidden_t.unsqueeze(1))

        # Concatenate along the time dimension
        hiddens = torch.cat(hiddens_list, dim=1)
        prior_states = torch.cat(prior_states_list, dim=1)
        posterior_states = torch.cat(posterior_states_list, dim=1)
        prior_means = torch.cat(prior_means_list, dim=1)
        prior_logvars = torch.cat(prior_logvars_list, dim=1)
        posterior_means = torch.cat(posterior_means_list, dim=1)
        posterior_logvars = torch.cat(posterior_logvars_list, dim=1)

        return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars

Note that the observation fed into the dynamics model is not the raw observation, but the embedding produced by the encoder. This keeps the computational cost down and helps the model generalize.
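A minimal shape-flow sketch (with made-up batch size and sequence length) of what the dynamics model expects:

```python
import torch

B, T = 8, 20                                     # hypothetical batch size and sequence length
frames = torch.randn(B, T, 1, 128, 128)          # preprocessed grayscale frames
encoder = EncoderCNN(in_channels=1, embedding_dim=1024)

embeddings = encoder(frames.reshape(B * T, 1, 128, 128))    # (B*T, 1024)
embeddings = embeddings.reshape(B, T, -1)                    # (B, T, 1024), ready for DynamicsModel
```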

The Full RSSM

We now assemble the components above into the complete RSSM model. Its core is the generate_rollout method, which calls the dynamics model and produces a sequence of latent representations of the environment dynamics. If no previous latent state is available (typically at the start of a trajectory), the method initializes one. The full implementation:

class RSSM:
    def __init__(self,
                 encoder: EncoderCNN,
                 decoder: DecoderCNN,
                 reward_model: RewardModel,
                 dynamics_model: nn.Module,
                 hidden_dim: int,
                 state_dim: int,
                 action_dim: int,
                 embedding_dim: int,
                 device: str = "mps"):
        """
        Recurrent State Space Model (RSSM).

        Args:
            encoder: encoder for raw observations
            decoder: decoder that reconstructs observations
            reward_model: reward prediction model
            dynamics_model: latent dynamics model
            hidden_dim: dimension of the RNN hidden state
            state_dim: dimension of the stochastic state
            action_dim: dimension of the action space
            embedding_dim: dimension of the observation embedding
            device: compute device
        """
        super(RSSM, self).__init__()

        # Model components
        self.dynamics = dynamics_model
        self.encoder = encoder
        self.decoder = decoder
        self.reward_model = reward_model

        # Dimensions
        self.hidden_dim = hidden_dim
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.embedding_dim = embedding_dim

        # Move all components to the target device
        self.dynamics.to(device)
        self.encoder.to(device)
        self.decoder.to(device)
        self.reward_model.to(device)

    def generate_rollout(self, actions: torch.Tensor, hiddens: torch.Tensor = None, states: torch.Tensor = None,
                         obs: torch.Tensor = None, dones: torch.Tensor = None):
        """
        Generate a latent rollout.

        Args:
            actions: action sequence
            hiddens: initial hidden state (optional)
            states: initial stochastic state (optional)
            obs: observation embeddings (optional)
            dones: episode-termination flags

        Returns:
            The full rollout of latent states and distribution parameters.
        """
        # Initialize the latent states if none are provided
        if hiddens is None:
            hiddens = torch.zeros(actions.size(0), self.hidden_dim).to(actions.device)

        if states is None:
            states = torch.zeros(actions.size(0), self.state_dim).to(actions.device)

        # Unroll the dynamics model
        dynamics_result = self.dynamics(hiddens, states, actions, obs, dones)
        hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars = dynamics_result

        return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars

    def train(self):
        """Switch all components to training mode."""
        self.dynamics.train()
        self.encoder.train()
        self.decoder.train()
        self.reward_model.train()

    def eval(self):
        """Switch all components to evaluation mode."""
        self.dynamics.eval()
        self.encoder.eval()
        self.decoder.eval()
        self.reward_model.eval()

    def encode(self, obs: torch.Tensor):
        """Encode an observation."""
        return self.encoder(obs)

    def decode(self, h: torch.Tensor, s: torch.Tensor):
        """Decode the latent states back into an observation."""
        return self.decoder(h, s)

    def predict_reward(self, h: torch.Tensor, s: torch.Tensor):
        """Predict the reward."""
        return self.reward_model(h, s)

    def parameters(self):
        """Return all trainable parameters."""
        return list(self.dynamics.parameters()) + list(self.encoder.parameters()) + \
               list(self.decoder.parameters()) + list(self.reward_model.parameters())

    def save(self, path: str):
        """Save the model state."""
        torch.save({
            "dynamics": self.dynamics.state_dict(),
            "encoder": self.encoder.state_dict(),
            "decoder": self.decoder.state_dict(),
            "reward_model": self.reward_model.state_dict()
        }, path)

    def load(self, path: str):
        """Load the model state."""
        checkpoint = torch.load(path)
        self.dynamics.load_state_dict(checkpoint["dynamics"])
        self.encoder.load_state_dict(checkpoint["encoder"])
        self.decoder.load_state_dict(checkpoint["decoder"])
        self.reward_model.load_state_dict(checkpoint["reward_model"])

This gives us a complete RSSM wrapper with the basic functionality we need: training/evaluation mode switching, encoding, decoding, reward prediction, and saving/loading of model state. It can serve as a base that is extended and tuned for specific applications.
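As a quick smoke test, a rollout can be generated from zero-initialized latents and random one-hot actions. This is only a sketch: the sizes are illustrative, it assumes an `rssm` instance built as in the example at the end of the article, and since no observations are passed, only the prior transition is exercised.

```python
import torch
import torch.nn.functional as F

B, T, action_dim = 4, 10, 5                      # illustrative sizes
actions = F.one_hot(torch.randint(0, action_dim, (B, T)), action_dim).float()
actions = actions.to(next(rssm.dynamics.parameters()).device)   # match the model's device

hiddens, prior_states, posterior_states, *dist_params = rssm.generate_rollout(actions)
print(hiddens.shape, prior_states.shape)         # (B, T, hidden_dim), (B, T, state_dim)
```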

Training System Design

The RSSM training system has two core components: an experience replay buffer and an agent. The buffer stores past experience for training, while the agent acts as the interface between the environment and the RSSM and implements the data-collection policy.

Experience Replay Buffer

The buffer is a circular queue that stores and manages observations, actions, rewards, and termination flags. The sample method draws random training sequences from it.

class Buffer:
    def __init__(self, buffer_size: int, obs_shape: tuple, action_shape: tuple, device: torch.device):
        """
        Experience replay buffer.

        Args:
            buffer_size: buffer capacity
            obs_shape: shape of an observation
            action_shape: shape of an action
            device: compute device
        """
        self.buffer_size = buffer_size
        self.obs_buffer = np.zeros((buffer_size, *obs_shape), dtype=np.float32)
        self.action_buffer = np.zeros((buffer_size, *action_shape), dtype=np.int32)
        self.reward_buffer = np.zeros((buffer_size, 1), dtype=np.float32)
        self.done_buffer = np.zeros((buffer_size, 1), dtype=np.bool_)

        self.device = device
        self.idx = 0

    def add(self, obs: torch.Tensor, action: int, reward: float, done: bool):
        """Add a single transition."""
        self.obs_buffer[self.idx] = obs
        self.action_buffer[self.idx] = action
        self.reward_buffer[self.idx] = reward
        self.done_buffer[self.idx] = done
        self.idx = (self.idx + 1) % self.buffer_size

    def sample(self, batch_size: int, sequence_length: int):
        """
        Sample random experience sequences.

        Args:
            batch_size: number of sequences
            sequence_length: length of each sequence

        Returns:
            A tuple (observations, actions, rewards, dones).
        """
        # Pick random starting positions for the sequences
        starting_idxs = np.random.randint(0, (self.idx % self.buffer_size) - sequence_length, (batch_size,))

        # Build the full index array for each sequence
        index_tensor = np.stack([np.arange(start, start + sequence_length) for start in starting_idxs])

        # Gather the data
        obs_sequence = self.obs_buffer[index_tensor]
        action_sequence = self.action_buffer[index_tensor]
        reward_sequence = self.reward_buffer[index_tensor]
        done_sequence = self.done_buffer[index_tensor]

        return obs_sequence, action_sequence, reward_sequence, done_sequence

    def save(self, path: str):
        """Save the buffer contents."""
        np.savez(path, obs_buffer=self.obs_buffer, action_buffer=self.action_buffer,
                 reward_buffer=self.reward_buffer, done_buffer=self.done_buffer, idx=self.idx)

    def load(self, path: str):
        """Load the buffer contents."""
        data = np.load(path)
        self.obs_buffer = data["obs_buffer"]
        self.action_buffer = data["action_buffer"]
        self.reward_buffer = data["reward_buffer"]
        self.done_buffer = data["done_buffer"]
        self.idx = data["idx"]
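A minimal usage sketch of the buffer (the shapes are illustrative and match the grayscale 128×128 setup used later; a Discrete action space has an empty `action_shape`):

```python
import numpy as np

buffer = Buffer(buffer_size=10_000, obs_shape=(128, 128, 1), action_shape=(), device="cpu")

dummy_obs = np.zeros((128, 128, 1), dtype=np.float32)
for _ in range(100):
    buffer.add(dummy_obs, action=0, reward=0.0, done=False)

obs_seq, act_seq, rew_seq, done_seq = buffer.sample(batch_size=8, sequence_length=20)
print(obs_seq.shape)   # (8, 20, 128, 128, 1)
```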

Agent Design

The agent implements data collection and planning. The current implementation collects data with a simple random policy, but the framework makes it easy to plug in more sophisticated policies (a sketch of a custom policy follows at the end of this section).

from abc import ABC, abstractmethod

from gym import Env
from tqdm import tqdm


class Policy(ABC):
    """Abstract policy interface."""
    @abstractmethod
    def __call__(self, obs):
        pass


class RandomPolicy(Policy):
    """Uniform random policy."""
    def __init__(self, env: Env):
        self.env = env

    def __call__(self, obs):
        return self.env.action_space.sample()


class Agent:
    def __init__(self, env: Env, rssm: RSSM, buffer_size: int = 100000,
                 collection_policy: str = "random", device="mps"):
        """
        Agent that collects data and plans with the RSSM.

        Args:
            env: environment instance
            rssm: RSSM model instance
            buffer_size: capacity of the experience buffer
            collection_policy: type of data-collection policy
            device: compute device
        """
        self.env = env
        # Select the data-collection policy
        match collection_policy:
            case "random":
                self.rollout_policy = RandomPolicy(env)
            case _:
                raise ValueError("Invalid rollout policy")

        self.buffer = Buffer(buffer_size, env.observation_space.shape,
                             env.action_space.shape, device=device)
        self.rssm = rssm

    def data_collection_action(self, obs):
        """Pick an action with the data-collection policy."""
        return self.rollout_policy(obs)

    def collect_data(self, num_steps: int):
        """
        Collect training data.

        Args:
            num_steps: number of environment steps to collect
        """
        obs, _ = self.env.reset()
        done = False

        iterator = tqdm(range(num_steps), desc="Data Collection")
        for _ in iterator:
            action = self.data_collection_action(obs)
            next_obs, reward, done, _, _ = self.env.step(action)
            self.buffer.add(next_obs, action, reward, done)
            obs = next_obs
            if done:
                obs, _ = self.env.reset()

    def imagine_rollout(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor,
                        actions: torch.Tensor):
        """
        Imagine a rollout purely in latent space.

        Args:
            prev_hidden: previous hidden state
            prev_state: previous stochastic state
            actions: action sequence

        Returns:
            The full model output: hidden states, prior and posterior states,
            distribution parameters, and predicted rewards.
        """
        hiddens, prior_states, posterior_states, prior_means, prior_logvars, \
        posterior_means, posterior_logvars = self.rssm.generate_rollout(
            actions, prev_hidden, prev_state)

        # During imagination, rewards are predicted from the prior states
        rewards = self.rssm.predict_reward(hiddens, prior_states)

        return hiddens, prior_states, posterior_states, prior_means, \
               prior_logvars, posterior_means, posterior_logvars, rewards

    def plan(self, num_steps: int, prev_hidden: torch.Tensor,
             prev_state: torch.Tensor, actions: torch.Tensor):
        """
        Plan by chaining imagined rollouts.

        Args:
            num_steps: number of planning steps
            prev_hidden: initial hidden state
            prev_state: initial stochastic state
            actions: action sequence

        Returns:
            The hidden-state and prior-state sequences produced by planning.
        """
        hidden_states = []
        prior_states = []

        hiddens = prev_hidden
        states = prev_state

        for _ in range(num_steps):
            rollout_hiddens, rollout_states, *_ = self.imagine_rollout(
                hiddens, states, actions)
            hidden_states.append(rollout_hiddens)
            prior_states.append(rollout_states)
            # Chain the next imagined rollout from the last step of this one
            hiddens = rollout_hiddens[:, -1]
            states = rollout_states[:, -1]

        hidden_states = torch.stack(hidden_states)
        prior_states = torch.stack(prior_states)

        return hidden_states, prior_states

This gives us the complete data-management and agent-interaction framework. The replay buffer stores and reuses past experience efficiently, and the abstract policy interface makes it easy to swap in different data-collection strategies. The agent also implements model-based imagination rollouts and planning, which form the basis for later decision making.
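As an example of how the policy interface can be extended, here is a sketch of a hypothetical sticky-action policy (the class name and behavior are illustrative, not part of the original code); to use it, one would add a corresponding case to the `match` statement in `Agent.__init__`:

```python
import numpy as np
from gym import Env


class StickyRandomPolicy(Policy):
    """Repeats the previous action with probability p, otherwise samples a new random one."""

    def __init__(self, env: Env, p: float = 0.7):
        self.env = env
        self.p = p
        self.prev_action = None

    def __call__(self, obs):
        if self.prev_action is not None and np.random.rand() < self.p:
            return self.prev_action
        self.prev_action = self.env.action_space.sample()
        return self.prev_action
```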

Trainer Implementation and Experiments

Trainer Design

The trainer is the last key component of the RSSM implementation and orchestrates the training process. It receives the RSSM model, the agent, and the optimizer, and implements the actual training logic.

import logging
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(),                        # console output
        logging.FileHandler("training.log", mode="w")   # file output
    ]
)

logger = logging.getLogger(__name__)


class Trainer:
    def __init__(self, rssm: RSSM, agent: Agent, optimizer: torch.optim.Optimizer,
                 device: torch.device):
        """
        Trainer.

        Args:
            rssm: RSSM model instance
            agent: agent instance
            optimizer: optimizer instance
            device: compute device
        """
        self.rssm = rssm
        self.optimizer = optimizer
        self.device = device
        self.agent = agent
        self.writer = SummaryWriter()  # tensorboard logger

    def train_batch(self, batch_size: int, seq_len: int, iteration: int,
                    save_images: bool = False):
        """
        Train on a single batch.

        Args:
            batch_size: batch size
            seq_len: sequence length
            iteration: current iteration
            save_images: whether to save reconstruction images
        """
        # Sample training data
        obs, actions, rewards, dones = self.agent.buffer.sample(batch_size, seq_len)

        # Preprocess the batch
        actions = torch.tensor(actions).long().to(self.device)
        actions = F.one_hot(actions, self.rssm.action_dim).float()
        obs = torch.tensor(obs).float().to(self.device)
        rewards = torch.tensor(rewards).float().to(self.device)
        dones = torch.tensor(dones).float().to(self.device)

        # Encode the observations
        encoded_obs = self.rssm.encoder(obs.reshape(-1, *obs.shape[2:]).permute(0, 3, 1, 2))
        encoded_obs = encoded_obs.reshape(batch_size, seq_len, -1)

        # Unroll the RSSM
        rollout = self.rssm.generate_rollout(actions, obs=encoded_obs, dones=dones)
        hiddens, prior_states, posterior_states, prior_means, prior_logvars, \
        posterior_means, posterior_logvars = rollout

        # Reconstruct the observations
        hiddens_reshaped = hiddens.reshape(batch_size * seq_len, -1)
        posterior_states_reshaped = posterior_states.reshape(batch_size * seq_len, -1)
        decoded_obs = self.rssm.decoder(hiddens_reshaped, posterior_states_reshaped)
        decoded_obs = decoded_obs.reshape(batch_size, seq_len, *obs.shape[-3:])

        # Predict the rewards
        reward_params = self.rssm.reward_model(hiddens, posterior_states)
        mean, logvar = torch.chunk(reward_params, 2, dim=-1)
        logvar = F.softplus(logvar)
        reward_dist = Normal(mean, torch.exp(logvar))
        predicted_rewards = reward_dist.rsample()

        # Visualization
        if save_images:
            batch_idx = np.random.randint(0, batch_size)
            seq_idx = np.random.randint(0, seq_len - 3)
            fig = self._visualize(obs, decoded_obs, rewards, predicted_rewards,
                                  batch_idx, seq_idx, iteration, grayscale=True)
            if not os.path.exists("reconstructions"):
                os.makedirs("reconstructions")
            fig.savefig(f"reconstructions/iteration_{iteration}.png")
            self.writer.add_figure("Reconstructions", fig, iteration)
            plt.close(fig)

        # Compute the losses
        reconstruction_loss = self._reconstruction_loss(decoded_obs, obs)
        kl_loss = self._kl_loss(prior_means, F.softplus(prior_logvars),
                                posterior_means, F.softplus(posterior_logvars))
        reward_loss = self._reward_loss(rewards, predicted_rewards)

        loss = reconstruction_loss + kl_loss + reward_loss

        # Backpropagation and optimization
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.rssm.parameters(), 1, norm_type=2)
        self.optimizer.step()

        return loss.item(), reconstruction_loss.item(), kl_loss.item(), reward_loss.item()
   
    def train(self, iterations: int, batch_size: int, seq_len: int):
        """
        Run the full training loop.

        Args:
            iterations: total number of iterations
            batch_size: batch size
            seq_len: sequence length
        """
        self.rssm.train()
        iterator = tqdm(range(iterations), desc="Training", total=iterations)
        losses = []
        infos = []
        last_loss = float("inf")

        for i in iterator:
            # Train on one batch
            loss, reconstruction_loss, kl_loss, reward_loss = self.train_batch(
                batch_size, seq_len, i, save_images=i % 100 == 0)

            # Log training metrics
            self.writer.add_scalar("Loss", loss, i)
            self.writer.add_scalar("Reconstruction Loss", reconstruction_loss, i)
            self.writer.add_scalar("KL Loss", kl_loss, i)
            self.writer.add_scalar("Reward Loss", reward_loss, i)

            # Save the best model so far
            if loss < last_loss:
                self.rssm.save("rssm.pth")
                last_loss = loss

            # Keep detailed info
            info = {
                "Loss": loss,
                "Reconstruction Loss": reconstruction_loss,
                "KL Loss": kl_loss,
                "Reward Loss": reward_loss
            }
            losses.append(loss)
            infos.append(info)

            # Periodically log the training status
            if i % 10 == 0:
                logger.info("\n----------------------------")
                logger.info(f"Iteration: {i}")
                logger.info(f"Loss: {loss:.4f}")
                logger.info(f"Running average of last 20 losses: {sum(losses[-20:]) / len(losses[-20:]):.4f}")
                logger.info(f"Reconstruction Loss: {reconstruction_loss:.4f}")
                logger.info(f"KL Loss: {kl_loss:.4f}")
                logger.info(f"Reward Loss: {reward_loss:.4f}")
 
Experiment Example
 
Below is a complete example of training the RSSM in the CarRacing environment:
 
```python
# Environment setup
env = make_env("CarRacing-v2", render_mode="rgb_array", continuous=False, grayscale=True)

# Model hyperparameters
hidden_size = 1024
embedding_dim = 1024
state_dim = 512

# Instantiate the model components
encoder = EncoderCNN(in_channels=1, embedding_dim=embedding_dim)
decoder = DecoderCNN(hidden_size=hidden_size, state_size=state_dim,
                     embedding_size=embedding_dim, output_shape=(1, 128, 128))
reward_model = RewardModel(hidden_dim=hidden_size, state_dim=state_dim)
dynamics_model = DynamicsModel(hidden_dim=hidden_size, state_dim=state_dim,
                               action_dim=5, embedding_dim=embedding_dim)

# Build the RSSM
rssm = RSSM(dynamics_model=dynamics_model,
            encoder=encoder,
            decoder=decoder,
            reward_model=reward_model,
            hidden_dim=hidden_size,
            state_dim=state_dim,
            action_dim=5,
            embedding_dim=embedding_dim,
            device="cuda")

# Training setup
optimizer = torch.optim.Adam(rssm.parameters(), lr=1e-3)
agent = Agent(env, rssm, device="cuda")
trainer = Trainer(rssm, agent, optimizer=optimizer, device="cuda")

# Data collection and training
agent.collect_data(20000)            # collect 20,000 steps of experience
agent.buffer.save("buffer.npz")      # save the replay buffer
trainer.train(10000, 32, 20)         # train for 10,000 iterations
```

Summary

This article walked through a complete PyTorch implementation of RSSM. The RSSM architecture is more involved than a plain VAE or RNN because it mixes stochastic and deterministic states. Implementing it by hand is a good way to understand the theory behind it and what makes it powerful: an RSSM can recursively generate trajectories of future latent states, which is exactly what an agent needs to plan its behavior.

A nice property of this implementation is its moderate computational cost: it can be trained on a single consumer GPU, and with enough patience even on a CPU. The work is based on the paper "Learning Latent Dynamics for Planning from Pixels", which laid the foundation for RSSM-style dynamics models. Follow-up work such as "Dream to Control: Learning Behaviors by Latent Imagination" developed the architecture further; those refinements are worth exploring in future posts, as they offer important insight into MBRL methods.
