diff --git a/rsl_rl/algorithms/__init__.py b/rsl_rl/algorithms/__init__.py index effcaa28..68efbc4b 100644 --- a/rsl_rl/algorithms/__init__.py +++ b/rsl_rl/algorithms/__init__.py @@ -7,5 +7,6 @@ from .distillation import Distillation from .ppo import PPO +from .rl2_ppo import RL2PPO -__all__ = ["PPO", "Distillation"] +__all__ = ["PPO", "Distillation", "RL2PPO"] diff --git a/rsl_rl/algorithms/rl2_ppo.py b/rsl_rl/algorithms/rl2_ppo.py new file mode 100644 index 00000000..c6aaa932 --- /dev/null +++ b/rsl_rl/algorithms/rl2_ppo.py @@ -0,0 +1,486 @@ +# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.optim as optim +from itertools import chain + +from rsl_rl.modules import ActorCritic +from rsl_rl.modules import ActorCriticRL2 +from rsl_rl.modules.rnd import RandomNetworkDistillation +from rsl_rl.storage import RolloutStorage +from rsl_rl.utils import string_to_callable + + +class RL2PPO: + """Proximal Policy Optimization algorithm (https://arxiv.org/abs/1707.06347).""" + + policy: ActorCriticRL2 + """The actor critic module.""" + + def __init__( + self, + policy, + num_learning_epochs=1, + num_mini_batches=1, + clip_param=0.2, + gamma=0.998, + lam=0.95, + value_loss_coef=1.0, + entropy_coef=0.0, + learning_rate=1e-3, + max_grad_norm=1.0, + use_clipped_value_loss=True, + schedule="fixed", + desired_kl=0.01, + device="cpu", + normalize_advantage_per_mini_batch=False, + # RND parameters + rnd_cfg: dict | None = None, + # Symmetry parameters + symmetry_cfg: dict | None = None, + # Distributed training parameters + multi_gpu_cfg: dict | None = None, + chunk_size=1, + ): + # device-related parameters + self.device = device + self.is_multi_gpu = multi_gpu_cfg is not None + # Multi-GPU parameters + if multi_gpu_cfg is not None: + self.gpu_global_rank = multi_gpu_cfg["global_rank"] + self.gpu_world_size = multi_gpu_cfg["world_size"] + else: + self.gpu_global_rank = 0 + self.gpu_world_size = 1 + + # RND components + if rnd_cfg is not None: + # Create RND module + self.rnd = RandomNetworkDistillation(device=self.device, **rnd_cfg) + # Create RND optimizer + params = self.rnd.predictor.parameters() + self.rnd_optimizer = optim.Adam(params, lr=rnd_cfg.get("learning_rate", 1e-3)) + else: + self.rnd = None + self.rnd_optimizer = None + + # Symmetry components + if symmetry_cfg is not None: + # Check if symmetry is enabled + use_symmetry = symmetry_cfg["use_data_augmentation"] or symmetry_cfg["use_mirror_loss"] + # Print that we are not using symmetry + if not use_symmetry: + print("Symmetry not used for learning. 
We will use it for logging instead.") + # If function is a string then resolve it to a function + if isinstance(symmetry_cfg["data_augmentation_func"], str): + symmetry_cfg["data_augmentation_func"] = string_to_callable(symmetry_cfg["data_augmentation_func"]) + # Check valid configuration + if symmetry_cfg["use_data_augmentation"] and not callable(symmetry_cfg["data_augmentation_func"]): + raise ValueError( + "Data augmentation enabled but the function is not callable:" + f" {symmetry_cfg['data_augmentation_func']}" + ) + # Store symmetry configuration + self.symmetry = symmetry_cfg + else: + self.symmetry = None + + # PPO components + self.policy = policy + self.policy.to(self.device) + # Create optimizer + self.optimizer = optim.Adam(self.policy.parameters(), lr=learning_rate) + # Create rollout storage + self.storage: RolloutStorage = None # type: ignore + self.transition = RolloutStorage.Transition() + + # PPO parameters + self.clip_param = clip_param + self.num_learning_epochs = num_learning_epochs + self.num_mini_batches = num_mini_batches + self.value_loss_coef = value_loss_coef + self.entropy_coef = entropy_coef + self.gamma = gamma + self.lam = lam + self.max_grad_norm = max_grad_norm + self.use_clipped_value_loss = use_clipped_value_loss + self.desired_kl = desired_kl + self.schedule = schedule + self.learning_rate = learning_rate + self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch + self.chunk_size = chunk_size + + def init_storage( + self, training_type, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, actions_shape + ): + # create memory for RND as well :) + if self.rnd: + rnd_state_shape = [self.rnd.num_states] + else: + rnd_state_shape = None + # create rollout storage + self.storage = RolloutStorage( + training_type, + num_envs, + num_transitions_per_env, + actor_obs_shape, + critic_obs_shape, + actions_shape, + rnd_state_shape, + self.device, + ) + + def act(self, obs, critic_obs, prev_action): + if self.policy.is_recurrent: + self.transition.hidden_states = self.policy.get_hidden_states() + # compute the actions and values + self.transition.actions = self.policy.act(obs, prev_action).detach() + self.transition.values = self.policy.evaluate(critic_obs, prev_action).detach() + self.transition.actions_log_prob = self.policy.get_actions_log_prob(self.transition.actions).detach() + self.transition.action_mean = self.policy.action_mean.detach() + self.transition.action_sigma = self.policy.action_std.detach() + # need to record obs and critic_obs before env.step() + self.transition.observations = obs + self.transition.privileged_observations = critic_obs + return self.transition.actions + + def process_env_step(self, rewards, dones, infos): + # Record the rewards and dones + # Note: we clone here because later on we bootstrap the rewards based on timeouts + self.transition.rewards = rewards.clone() + self.transition.dones = dones + + # Compute the intrinsic rewards and add to extrinsic rewards + if self.rnd: + # Obtain curiosity gates / observations from infos + rnd_state = infos["observations"]["rnd_state"] + # Compute the intrinsic rewards + # note: rnd_state is the gated_state after normalization if normalization is used + self.intrinsic_rewards, rnd_state = self.rnd.get_intrinsic_reward(rnd_state) + # Add intrinsic rewards to extrinsic rewards + self.transition.rewards += self.intrinsic_rewards + # Record the curiosity gates + self.transition.rnd_state = rnd_state.clone() + + # Bootstrapping on time outs + if "time_outs" in 
infos: + self.transition.rewards += self.gamma * torch.squeeze( + self.transition.values * infos["time_outs"].unsqueeze(1).to(self.device), 1 + ) + + # record the transition + self.storage.add_transitions(self.transition) + self.transition.clear() + self.policy.reset(dones) + + def compute_returns(self, last_critic_obs, prev_action): + # compute value for the last step + last_values = self.policy.evaluate(last_critic_obs, prev_action).detach() + self.storage.compute_returns( + last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch + ) + + def update(self): # noqa: C901 + mean_value_loss = 0 + mean_surrogate_loss = 0 + mean_entropy = 0 + # -- RND loss + if self.rnd: + mean_rnd_loss = 0 + else: + mean_rnd_loss = None + # -- Symmetry loss + if self.symmetry: + mean_symmetry_loss = 0 + else: + mean_symmetry_loss = None + + # generator for mini batches + # modify for RL^2 + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + if self.policy.is_recurrent: + # generator = self.storage.recurrent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) + generator = self.storage.debug_chunk_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs, self.chunk_size) + else: + generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + + # iterate over batches + # addtionally output prev_action_batch + for ( + obs_batch, + prev_actions_batch, + critic_obs_batch, + actions_batch, + target_values_batch, + advantages_batch, + returns_batch, + old_actions_log_prob_batch, + old_mu_batch, + old_sigma_batch, + hid_states_batch, + masks_batch, + rnd_state_batch, + ) in generator: + + # number of augmentations per sample + # we start with 1 and increase it if we use symmetry augmentation + num_aug = 1 + # original batch size + original_batch_size = obs_batch.shape[0] + + # check if we should normalize advantages per mini batch + if self.normalize_advantage_per_mini_batch: + with torch.no_grad(): + advantages_batch = (advantages_batch - advantages_batch.mean()) / (advantages_batch.std() + 1e-8) + + # Perform symmetric augmentation + if self.symmetry and self.symmetry["use_data_augmentation"]: + # augmentation using symmetry + data_augmentation_func = self.symmetry["data_augmentation_func"] + # returned shape: [batch_size * num_aug, ...] 
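+                # note: prev_actions_batch is not augmented here; only the observations and actions below are.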
+ obs_batch, actions_batch = data_augmentation_func( + obs=obs_batch, actions=actions_batch, env=self.symmetry["_env"], obs_type="policy" + ) + critic_obs_batch, _ = data_augmentation_func( + obs=critic_obs_batch, actions=None, env=self.symmetry["_env"], obs_type="critic" + ) + # compute number of augmentations per sample + num_aug = int(obs_batch.shape[0] / original_batch_size) + # repeat the rest of the batch + # -- actor + old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(num_aug, 1) + # -- critic + target_values_batch = target_values_batch.repeat(num_aug, 1) + advantages_batch = advantages_batch.repeat(num_aug, 1) + returns_batch = returns_batch.repeat(num_aug, 1) + + # Recompute actions log prob and entropy for current batch of transitions + # Note: we need to do this because we updated the policy with the new parameters + # -- actor + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # RL^2 addtionally input prev_actions + self.policy.act(obs_batch, prev_actions_batch, masks=masks_batch, hidden_states=hid_states_batch[0]) + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + actions_log_prob_batch = self.policy.get_actions_log_prob(actions_batch) + # -- critic + value_batch = self.policy.evaluate(critic_obs_batch, prev_actions_batch, masks=masks_batch, hidden_states=hid_states_batch[1]) + # -- entropy + # we only keep the entropy of the first augmentation (the original one) + mu_batch = self.policy.action_mean[:original_batch_size] + sigma_batch = self.policy.action_std[:original_batch_size] + entropy_batch = self.policy.entropy[:original_batch_size] + + # KL + if self.desired_kl is not None and self.schedule == "adaptive": + with torch.inference_mode(): + kl = torch.sum( + torch.log(sigma_batch / old_sigma_batch + 1.0e-5) + + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) + / (2.0 * torch.square(sigma_batch)) + - 0.5, + axis=-1, + ) + kl_mean = torch.mean(kl) + + # Reduce the KL divergence across all GPUs + if self.is_multi_gpu: + torch.distributed.all_reduce(kl_mean, op=torch.distributed.ReduceOp.SUM) + kl_mean /= self.gpu_world_size + + # Update the learning rate + # Perform this adaptation only on the main process + # TODO: Is this needed? If KL-divergence is the "same" across all GPUs, + # then the learning rate should be the same across all GPUs. 
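+                    # Adaptive-KL schedule (see the lines below): the learning rate is divided by 1.5
+                    # (floored at 1e-5) when the measured KL exceeds twice the target, and multiplied
+                    # by 1.5 (capped at 1e-2) when the KL drops below half the target.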
+ if self.gpu_global_rank == 0: + if kl_mean > self.desired_kl * 2.0: + self.learning_rate = max(1e-5, self.learning_rate / 1.5) + elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0: + self.learning_rate = min(1e-2, self.learning_rate * 1.5) + + # Update the learning rate for all GPUs + if self.is_multi_gpu: + lr_tensor = torch.tensor(self.learning_rate, device=self.device) + torch.distributed.broadcast(lr_tensor, src=0) + self.learning_rate = lr_tensor.item() + + # Update the learning rate for all parameter groups + for param_group in self.optimizer.param_groups: + param_group["lr"] = self.learning_rate + + # Surrogate loss + ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)) + surrogate = -torch.squeeze(advantages_batch) * ratio + surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp( + ratio, 1.0 - self.clip_param, 1.0 + self.clip_param + ) + surrogate_loss = torch.max(surrogate, surrogate_clipped).mean() + + # Value function loss + if self.use_clipped_value_loss: + value_clipped = target_values_batch + (value_batch - target_values_batch).clamp( + -self.clip_param, self.clip_param + ) + value_losses = (value_batch - returns_batch).pow(2) + value_losses_clipped = (value_clipped - returns_batch).pow(2) + value_loss = torch.max(value_losses, value_losses_clipped).mean() + else: + value_loss = (returns_batch - value_batch).pow(2).mean() + + loss = surrogate_loss + self.value_loss_coef * value_loss - self.entropy_coef * entropy_batch.mean() + + # Symmetry loss + if self.symmetry: + # obtain the symmetric actions + # if we did augmentation before then we don't need to augment again + if not self.symmetry["use_data_augmentation"]: + data_augmentation_func = self.symmetry["data_augmentation_func"] + obs_batch, _ = data_augmentation_func( + obs=obs_batch, actions=None, env=self.symmetry["_env"], obs_type="policy" + ) + # compute number of augmentations per sample + num_aug = int(obs_batch.shape[0] / original_batch_size) + + # actions predicted by the actor for symmetrically-augmented observations + mean_actions_batch = self.policy.act_inference(obs_batch.detach().clone()) + + # compute the symmetrically augmented actions + # note: we are assuming the first augmentation is the original one. + # We do not use the action_batch from earlier since that action was sampled from the distribution. + # However, the symmetry loss is computed using the mean of the distribution. 
+ action_mean_orig = mean_actions_batch[:original_batch_size] + _, actions_mean_symm_batch = data_augmentation_func( + obs=None, actions=action_mean_orig, env=self.symmetry["_env"], obs_type="policy" + ) + + # compute the loss (we skip the first augmentation as it is the original one) + mse_loss = torch.nn.MSELoss() + symmetry_loss = mse_loss( + mean_actions_batch[original_batch_size:], actions_mean_symm_batch.detach()[original_batch_size:] + ) + # add the loss to the total loss + if self.symmetry["use_mirror_loss"]: + loss += self.symmetry["mirror_loss_coeff"] * symmetry_loss + else: + symmetry_loss = symmetry_loss.detach() + + # Random Network Distillation loss + if self.rnd: + # predict the embedding and the target + predicted_embedding = self.rnd.predictor(rnd_state_batch) + target_embedding = self.rnd.target(rnd_state_batch).detach() + # compute the loss as the mean squared error + mseloss = torch.nn.MSELoss() + rnd_loss = mseloss(predicted_embedding, target_embedding) + + # Compute the gradients + # -- For PPO + self.optimizer.zero_grad() + loss.backward() + # -- For RND + if self.rnd: + self.rnd_optimizer.zero_grad() # type: ignore + rnd_loss.backward() + + # Collect gradients from all GPUs + if self.is_multi_gpu: + self.reduce_parameters() + + # Apply the gradients + # -- For PPO + nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.optimizer.step() + # -- For RND + if self.rnd_optimizer: + self.rnd_optimizer.step() + + # Store the losses + mean_value_loss += value_loss.item() + mean_surrogate_loss += surrogate_loss.item() + mean_entropy += entropy_batch.mean().item() + # -- RND loss + if mean_rnd_loss is not None: + mean_rnd_loss += rnd_loss.item() + # -- Symmetry loss + if mean_symmetry_loss is not None: + mean_symmetry_loss += symmetry_loss.item() + + # -- For PPO + num_updates = self.num_learning_epochs * self.num_mini_batches + mean_value_loss /= num_updates + mean_surrogate_loss /= num_updates + mean_entropy /= num_updates + # -- For RND + if mean_rnd_loss is not None: + mean_rnd_loss /= num_updates + # -- For Symmetry + if mean_symmetry_loss is not None: + mean_symmetry_loss /= num_updates + # -- Clear the storage + self.storage.clear() + + # construct the loss dictionary + loss_dict = { + "value_function": mean_value_loss, + "surrogate": mean_surrogate_loss, + "entropy": mean_entropy, + } + if self.rnd: + loss_dict["rnd"] = mean_rnd_loss + if self.symmetry: + loss_dict["symmetry"] = mean_symmetry_loss + + return loss_dict + + """ + Helper functions + """ + + def broadcast_parameters(self): + """Broadcast model parameters to all GPUs.""" + # obtain the model parameters on current GPU + model_params = [self.policy.state_dict()] + if self.rnd: + model_params.append(self.rnd.predictor.state_dict()) + # broadcast the model parameters + torch.distributed.broadcast_object_list(model_params, src=0) + # load the model parameters on all GPUs from source GPU + self.policy.load_state_dict(model_params[0]) + if self.rnd: + self.rnd.predictor.load_state_dict(model_params[1]) + + def reduce_parameters(self): + """Collect gradients from all GPUs and average them. + + This function is called after the backward pass to synchronize the gradients across all GPUs. 
+ """ + # Create a tensor to store the gradients + grads = [param.grad.view(-1) for param in self.policy.parameters() if param.grad is not None] + if self.rnd: + grads += [param.grad.view(-1) for param in self.rnd.parameters() if param.grad is not None] + all_grads = torch.cat(grads) + + # Average the gradients across all GPUs + torch.distributed.all_reduce(all_grads, op=torch.distributed.ReduceOp.SUM) + all_grads /= self.gpu_world_size + + # Get all parameters + all_params = self.policy.parameters() + if self.rnd: + all_params = chain(all_params, self.rnd.parameters()) + + # Update the gradients for all parameters with the reduced gradients + offset = 0 + for param in all_params: + if param.grad is not None: + numel = param.numel() + # copy data back from shared buffer + param.grad.data.copy_(all_grads[offset : offset + numel].view_as(param.grad.data)) + # update the offset for the next parameter + offset += numel diff --git a/rsl_rl/modules/__init__.py b/rsl_rl/modules/__init__.py index 0a96bd93..ebf83593 100644 --- a/rsl_rl/modules/__init__.py +++ b/rsl_rl/modules/__init__.py @@ -11,6 +11,7 @@ from .rnd import RandomNetworkDistillation from .student_teacher import StudentTeacher from .student_teacher_recurrent import StudentTeacherRecurrent +from .actor_critic_rl2 import ActorCriticRL2 __all__ = [ "ActorCritic", @@ -19,4 +20,5 @@ "RandomNetworkDistillation", "StudentTeacher", "StudentTeacherRecurrent", + "ActorCriticRL2", ] diff --git a/rsl_rl/modules/actor_critic_rl2.py b/rsl_rl/modules/actor_critic_rl2.py new file mode 100644 index 00000000..0a0ddb7e --- /dev/null +++ b/rsl_rl/modules/actor_critic_rl2.py @@ -0,0 +1,154 @@ +# Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION +# All rights reserved. +# +# SPDX-License-Identifier: BSD-3-Clause + +from __future__ import annotations + +import warnings + +from rsl_rl.modules import ActorCritic +from rsl_rl.networks import Memory +from rsl_rl.utils import resolve_nn_activation, unpad_trajectories + +import torch + +class ActorCriticRL2(ActorCritic): + """ + Minimal RL^2 actor-critic: + - Actor input: obs + prev_action + - Critic input: obs + """ + is_recurrent = True + + def __init__( + self, + num_actor_obs, + num_critic_obs, + num_actions, + actor_hidden_dims=[256, 256, 256], + critic_hidden_dims=[256, 256, 256], + activation="elu", + rnn_type="lstm", + rnn_hidden_dim=256, + rnn_num_layers=1, + init_noise_std=1.0, + **kwargs, + ): + if "rnn_hidden_size" in kwargs: + warnings.warn( + "The argument `rnn_hidden_size` is deprecated, use `rnn_hidden_dim` instead.", + DeprecationWarning, + ) + if rnn_hidden_dim == 256: + rnn_hidden_dim = kwargs.pop("rnn_hidden_size") + if kwargs: + print( + "ActorCriticRL2.__init__ got unexpected arguments, which will be ignored: " + + str(kwargs.keys()) + ) + + # ActorCritic base expects num_actor_obs = rnn_hidden_dim + super().__init__( + num_actor_obs=rnn_hidden_dim + num_actor_obs, # 现在的实验条件默认critic_obs=actor_obs + num_critic_obs=rnn_hidden_dim + num_critic_obs, + num_actions=num_actions, + actor_hidden_dims=actor_hidden_dims, + critic_hidden_dims=critic_hidden_dims, + activation=activation, + init_noise_std=init_noise_std, + ) + + activation = resolve_nn_activation(activation) + + # Actor RNN input = obs + prev_action + self.memory_a = Memory( + input_size=num_actor_obs+num_actions, + type=rnn_type, + num_layers=rnn_num_layers, + hidden_size=rnn_hidden_dim, + ) + # Critic RNN input = critic_obs + prev_action + self.memory_c = Memory( + input_size=num_critic_obs+num_actions, + type=rnn_type, + 
num_layers=rnn_num_layers,
+            hidden_size=rnn_hidden_dim,
+        )
+
+        print(f"Actor RNN: {self.memory_a}")
+        print(f"Critic RNN: {self.memory_c}")
+
+    def reset(self, dones=None):
+        self.memory_a.reset(dones)
+        self.memory_c.reset(dones)
+
+    # def act(self, observations, prev_actions, masks=None, hidden_states=None):
+    #     # concat obs + prev_action along last dim
+    #     if hidden_states is not None and observations.dim() == 2:
+    #         # add a dummy time dimension
+    #         observations = observations.unsqueeze(0)
+    #         prev_actions = prev_actions.unsqueeze(0)
+    #     # input_a = torch.cat([observations, prev_actions], dim=-1)
+    #     input_a = observations
+    #     input_a = self.memory_a(input_a, masks, hidden_states, chunk_mode=True)
+    #     input_a = input_a.squeeze(0)
+    #     # if observations and input_a do not have matching dimensions, handle it with the mask
+    #     if input_a.shape[:-1] != observations.shape[:-1] and masks is not None:
+    #         # this case can come up during self.alg.update()
+    #         # masked_obs = unpad_trajectories(observations, masks)
+    #         # masked_obs = masked_obs.squeeze(0)
+    #         # input_a = torch.cat([input_a, masked_obs], dim=-1)
+    #         masked_obs = observations.squeeze(0)
+    #     else:
+    #         # input_a = torch.cat([input_a, observations], dim=-1)
+    #         masked_obs = observations  # (num_envs, obs_dim)
+    #     return super().act(masked_obs)
+
+    def act(self, observations, prev_actions, masks=None, hidden_states=None):
+        if masks is not None:
+            pass
+        input_a = torch.cat([observations, prev_actions], dim=-1)
+        input_a = self.memory_a(input_a, masks, hidden_states)
+        mlp_a_input = torch.cat([input_a.squeeze(0), observations], dim=-1)
+        return super().act(mlp_a_input)
+
+    # not used by the training scripts, so it should not affect training; left unchanged for now
+    def act_inference(self, observations, prev_actions):
+        input_a = torch.cat([observations, prev_actions], dim=-1)
+        input_a = self.memory_a(input_a)
+        return super().act_inference(input_a.squeeze(0))
+
+    def evaluate(self, critic_observations, prev_action, masks=None, hidden_states=None):
+        input_c = torch.cat([critic_observations, prev_action], dim=-1)
+        # the actor and critic share a single RNN
+        input_c = self.memory_a(input_c, masks, hidden_states)
+        mlp_c_input = torch.cat([input_c.squeeze(0), critic_observations], dim=-1)
+        return super().evaluate(mlp_c_input)
+
+    # # we switched to sharing one RNN between critic and actor, fed the same context concatenated with the obs
+    # def evaluate(self, observations, prev_actions, masks=None, hidden_states=None):
+    #     # concat obs + prev_action along last dim
+    #     if hidden_states is not None and observations.dim() == 2:
+    #         # add a dummy time dimension
+    #         observations = observations.unsqueeze(0)
+    #         prev_actions = prev_actions.unsqueeze(0)
+    #     # input_a = torch.cat([observations, prev_actions], dim=-1)
+    #     input_a = observations
+    #     input_a = self.memory_c(input_a, masks, hidden_states, chunk_mode=True)
+    #     input_a = input_a.squeeze(0)
+    #     # if observations and input_a do not have matching dimensions, handle it with the mask
+    #     if input_a.shape[:-1] != observations.shape[:-1] and masks is not None:
+    #         # masked_obs = unpad_trajectories(observations, masks)
+    #         # masked_obs = masked_obs.squeeze(0)
+    #         # input_a = torch.cat([input_a, masked_obs], dim=-1)
+    #         # masked_obs = observations
+    #         masked_obs = observations.squeeze(0)
+    #     else:
+    #         # input_a = torch.cat([input_a, observations], dim=-1)
+    #         masked_obs = observations
+    #     return super().evaluate(masked_obs)
+
+    # always return memory_a's hidden states for both actor and critic, to stay aligned with the upstream API and minimize changes
+    def get_hidden_states(self):
+        return self.memory_a.hidden_states, self.memory_a.hidden_states
\ No newline at end of file
diff --git a/rsl_rl/runners/on_policy_runner.py b/rsl_rl/runners/on_policy_runner.py
index ecda87ae..82cd423d 100644
--- a/rsl_rl/runners/on_policy_runner.py
+++ b/rsl_rl/runners/on_policy_runner.py @@ -12,7 +12,7 @@ from collections import deque import rsl_rl -from rsl_rl.algorithms import PPO, Distillation +from rsl_rl.algorithms import PPO, Distillation, RL2PPO from rsl_rl.env import VecEnv from rsl_rl.modules import ( ActorCritic, @@ -20,6 +20,7 @@ EmpiricalNormalization, StudentTeacher, StudentTeacherRecurrent, + ActorCriticRL2, ) from rsl_rl.utils import store_code_state @@ -38,7 +39,7 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev self._configure_multi_gpu() # resolve training type depending on the algorithm - if self.alg_cfg["class_name"] == "PPO": + if self.alg_cfg["class_name"] == "PPO" or self.alg_cfg["class_name"] == "RL2PPO": self.training_type = "rl" elif self.alg_cfg["class_name"] == "Distillation": self.training_type = "distillation" @@ -69,7 +70,7 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev # evaluate the policy class policy_class = eval(self.policy_cfg.pop("class_name")) - policy: ActorCritic | ActorCriticRecurrent | StudentTeacher | StudentTeacherRecurrent = policy_class( + policy: ActorCritic | ActorCriticRecurrent | StudentTeacher | StudentTeacherRecurrent | ActorCriticRL2 = policy_class( num_obs, num_privileged_obs, self.env.num_actions, **self.policy_cfg ).to(self.device) @@ -93,7 +94,7 @@ def __init__(self, env: VecEnv, train_cfg: dict, log_dir: str | None = None, dev # initialize algorithm alg_class = eval(self.alg_cfg.pop("class_name")) - self.alg: PPO | Distillation = alg_class(policy, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg) + self.alg: PPO | Distillation | RL2PPO = alg_class(policy, device=self.device, **self.alg_cfg, multi_gpu_cfg=self.multi_gpu_cfg) # store training configuration self.num_steps_per_env = self.cfg["num_steps_per_env"] @@ -167,6 +168,10 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals obs, extras = self.env.get_observations() privileged_obs = extras["observations"].get(self.privileged_obs_type, obs) obs, privileged_obs = obs.to(self.device), privileged_obs.to(self.device) + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # zero-init prev_action for RL^2 + prev_actions = torch.zeros(self.env.action_space.shape, device=self.device) + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< self.train_mode() # switch to train mode (for dropout for example) # Book keeping @@ -190,6 +195,11 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals # TODO: Do we need to synchronize empirical normalizers? # Right now: No, because they all should converge to the same values "asymptotically". 
+ # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # init prev_action for RL^2 + prev_actions = torch.zeros(self.env.action_space.shape, device=self.device) + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + # Start training start_iter = self.current_learning_iteration tot_iter = start_iter + num_learning_iterations @@ -199,12 +209,26 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals with torch.inference_mode(): for _ in range(self.num_steps_per_env): # Sample actions - actions = self.alg.act(obs, privileged_obs) + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + if isinstance(self.alg, RL2PPO): + # special process for RL2PPO + actions = self.alg.act(obs, privileged_obs, prev_actions) + prev_actions = actions.detach() + else: + actions = self.alg.act(obs, privileged_obs) + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< # Step the environment obs, rewards, dones, infos = self.env.step(actions.to(self.env.device)) # Move to device obs, rewards, dones = (obs.to(self.device), rewards.to(self.device), dones.to(self.device)) # perform normalization + + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # Reset prev_actions at episode boundary + if isinstance(self.alg, RL2PPO): + prev_actions[dones] = torch.zeros(self.env.action_space.shape, device=self.device) + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + obs = self.obs_normalizer(obs) if self.privileged_obs_type is not None: privileged_obs = self.privileged_obs_normalizer( @@ -254,7 +278,10 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals # compute returns if self.training_type == "rl": - self.alg.compute_returns(privileged_obs) + if isinstance(self.alg, RL2PPO): + self.alg.compute_returns(obs, prev_actions) + else: + self.alg.compute_returns(privileged_obs) # update policy loss_dict = self.alg.update() @@ -285,6 +312,7 @@ def learn(self, num_learning_iterations: int, init_at_random_ep_len: bool = Fals if self.log_dir is not None and not self.disable_logs: self.save(os.path.join(self.log_dir, f"model_{self.current_learning_iteration}.pt")) + # 改用训练总步数作为x轴 def log(self, locs: dict, width: int = 80, pad: int = 35): # Compute the collection size collection_size = self.num_steps_per_env * self.env.num_envs * self.gpu_world_size @@ -293,6 +321,114 @@ def log(self, locs: dict, width: int = 80, pad: int = 35): self.tot_time += locs["collection_time"] + locs["learn_time"] iteration_time = locs["collection_time"] + locs["learn_time"] + # -- Episode info + ep_string = "" + if locs["ep_infos"]: + for key in locs["ep_infos"][0]: + infotensor = torch.tensor([], device=self.device) + for ep_info in locs["ep_infos"]: + # handle scalar and zero dimensional tensor infos + if key not in ep_info: + continue + if not isinstance(ep_info[key], torch.Tensor): + ep_info[key] = torch.Tensor([ep_info[key]]) + if len(ep_info[key].shape) == 0: + ep_info[key] = ep_info[key].unsqueeze(0) + infotensor = torch.cat((infotensor, ep_info[key].to(self.device))) + value = torch.mean(infotensor) + # log to logger and terminal + if "/" in key: + self.writer.add_scalar(key, value, self.tot_timesteps) + ep_string += f"""{f'{key}:':>{pad}} {value:.4f}\n""" + else: + self.writer.add_scalar("Episode/" + key, value, self.tot_timesteps) + ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n""" + + mean_std = self.alg.policy.action_std.mean() + fps = int(collection_size / (locs["collection_time"] + locs["learn_time"])) + + 
# -- Losses + for key, value in locs["loss_dict"].items(): + self.writer.add_scalar(f"Loss/{key}", value, self.tot_timesteps) + self.writer.add_scalar("Loss/learning_rate", self.alg.learning_rate, self.tot_timesteps) + + # -- Policy + self.writer.add_scalar("Policy/mean_noise_std", mean_std.item(), self.tot_timesteps) + + # -- Performance + self.writer.add_scalar("Perf/total_fps", fps, self.tot_timesteps) + self.writer.add_scalar("Perf/collection time", locs["collection_time"], self.tot_timesteps) + self.writer.add_scalar("Perf/learning_time", locs["learn_time"], self.tot_timesteps) + + # -- Training + if len(locs["rewbuffer"]) > 0: + # separate logging for intrinsic and extrinsic rewards + if self.alg.rnd: + self.writer.add_scalar("Rnd/mean_extrinsic_reward", statistics.mean(locs["erewbuffer"]), self.tot_timesteps) + self.writer.add_scalar("Rnd/mean_intrinsic_reward", statistics.mean(locs["irewbuffer"]), self.tot_timesteps) + self.writer.add_scalar("Rnd/weight", self.alg.rnd.weight, self.tot_timesteps) + # everything else + self.writer.add_scalar("Train/mean_reward", statistics.mean(locs["rewbuffer"]), self.tot_timesteps) + self.writer.add_scalar("Train/mean_episode_length", statistics.mean(locs["lenbuffer"]), self.tot_timesteps) + if self.logger_type != "wandb": # wandb does not support non-integer x-axis logging + self.writer.add_scalar("Train/mean_reward/time", statistics.mean(locs["rewbuffer"]), self.tot_time) + self.writer.add_scalar( + "Train/mean_episode_length/time", statistics.mean(locs["lenbuffer"]), self.tot_time + ) + + str = f" \033[1m Learning iteration {locs['it']}/{locs['tot_iter']} \033[0m " + + if len(locs["rewbuffer"]) > 0: + log_string = ( + f"""{'#' * width}\n""" + f"""{str.center(width, ' ')}\n\n""" + f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ + 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" + f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" + ) + # -- Losses + for key, value in locs["loss_dict"].items(): + log_string += f"""{f'Mean {key} loss:':>{pad}} {value:.4f}\n""" + # -- Rewards + if self.alg.rnd: + log_string += ( + f"""{'Mean extrinsic reward:':>{pad}} {statistics.mean(locs['erewbuffer']):.2f}\n""" + f"""{'Mean intrinsic reward:':>{pad}} {statistics.mean(locs['irewbuffer']):.2f}\n""" + ) + log_string += f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" + # -- episode info + log_string += f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""" + else: + log_string = ( + f"""{'#' * width}\n""" + f"""{str.center(width, ' ')}\n\n""" + f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ + 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" + f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" + ) + for key, value in locs["loss_dict"].items(): + log_string += f"""{f'{key}:':>{pad}} {value:.4f}\n""" + + log_string += ep_string + log_string += ( + f"""{'-' * width}\n""" + f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n""" + f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n""" + f"""{'Time elapsed:':>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time))}\n""" + f"""{'ETA:':>{pad}} {time.strftime("%H:%M:%S", time.gmtime(self.tot_time / (locs['it'] - locs['start_iter'] + 1) * ( + locs['start_iter'] + locs['num_learning_iterations'] - locs['it'])))}\n""" + ) + print(log_string) + + # 旧版 + def log_backup(self, locs: dict, width: int = 80, pad: int = 35): + # Compute the collection size + collection_size 
= self.num_steps_per_env * self.env.num_envs * self.gpu_world_size + # Update total time-steps and time + self.tot_timesteps += collection_size + self.tot_time += locs["collection_time"] + locs["learn_time"] + iteration_time = locs["collection_time"] + locs["learn_time"] + # -- Episode info ep_string = "" if locs["ep_infos"]: diff --git a/rsl_rl/storage/rollout_storage.py b/rsl_rl/storage/rollout_storage.py index 42b8c9fd..6eb61a68 100644 --- a/rsl_rl/storage/rollout_storage.py +++ b/rsl_rl/storage/rollout_storage.py @@ -314,3 +314,474 @@ def recurrent_mini_batch_generator(self, num_mini_batches, num_epochs=8): ), masks_batch, rnd_state_batch first_traj = last_traj + + + # for reinforcement learning with recurrent networks + RL^2 style input (obs + prev_action) + def recurrent_mini_batch_generator_with_prev_action(self, num_mini_batches, num_epochs=8): + if self.training_type != "rl": + raise ValueError("This function is only available for reinforcement learning training.") + + # pad observations + padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones) + if self.privileged_observations is not None: + padded_privileged_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones) + else: + padded_privileged_obs_trajectories = padded_obs_trajectories + + # pad rnd state if exists + if self.rnd_state_shape is not None: + padded_rnd_state_trajectories, _ = split_and_pad_trajectories(self.rnd_state, self.dones) + else: + padded_rnd_state_trajectories = None + + # pad actions (for RL^2 input as prev_action) + padded_action_trajectories, _ = split_and_pad_trajectories(self.actions, self.dones) + + mini_batch_size = self.num_envs // num_mini_batches + for ep in range(num_epochs): + first_traj = 0 + for i in range(num_mini_batches): + start = i * mini_batch_size + stop = (i + 1) * mini_batch_size + + dones = self.dones.squeeze(-1) + last_was_done = torch.zeros_like(dones, dtype=torch.bool) + last_was_done[1:] = dones[:-1] + last_was_done[0] = True + trajectories_batch_size = torch.sum(last_was_done[:, start:stop]) + last_traj = first_traj + trajectories_batch_size + + masks_batch = trajectory_masks[:, first_traj:last_traj] + obs_batch = padded_obs_trajectories[:, first_traj:last_traj] + privileged_obs_batch = padded_privileged_obs_trajectories[:, first_traj:last_traj] + # 新增 prev_action batch (for RL^2 input) + prev_actions_batch = padded_action_trajectories[:, first_traj:last_traj] + # RL^2: shift prev_actions by 1 step + prev_actions_batch = torch.cat( + [torch.zeros_like(prev_actions_batch[:, :1, :], device=prev_actions_batch.device), + prev_actions_batch[:, :-1, :]], dim=1 + ) + + if padded_rnd_state_trajectories is not None: + rnd_state_batch = padded_rnd_state_trajectories[:, first_traj:last_traj] + else: + rnd_state_batch = None + + actions_batch = self.actions[:, start:stop] + old_mu_batch = self.mu[:, start:stop] + old_sigma_batch = self.sigma[:, start:stop] + returns_batch = self.returns[:, start:stop] + advantages_batch = self.advantages[:, start:stop] + values_batch = self.values[:, start:stop] + old_actions_log_prob_batch = self.actions_log_prob[:, start:stop] + + # hidden states reshape + last_was_done = last_was_done.permute(1, 0) + hid_a_batch = [ + saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] + .transpose(1, 0) + .contiguous() + for saved_hidden_states in self.saved_hidden_states_a + ] + hid_c_batch = [ + saved_hidden_states.permute(2, 0, 1, 
3)[last_was_done][first_traj:last_traj] + .transpose(1, 0) + .contiguous() + for saved_hidden_states in self.saved_hidden_states_c + ] + hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch + hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch + + # 新增 prev_actions_batch 输出 + yield obs_batch, prev_actions_batch, privileged_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( + hid_a_batch, + hid_c_batch, + ), masks_batch, rnd_state_batch + + first_traj = last_traj + + + # for reinforcement learning with recurrent networks + RL^2 style input (obs + prev_action) + shuffle + def recurrent_mini_batch_generator_with_prev_action_shuffle(self, num_mini_batches, num_epochs=8): + if self.training_type != "rl": + raise ValueError("This function is only available for reinforcement learning training.") + + # pad observations + padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories(self.observations, self.dones) + if self.privileged_observations is not None: + padded_privileged_obs_trajectories, _ = split_and_pad_trajectories(self.privileged_observations, self.dones) + else: + padded_privileged_obs_trajectories = padded_obs_trajectories + + # pad rnd state if exists + if self.rnd_state_shape is not None: + padded_rnd_state_trajectories, _ = split_and_pad_trajectories(self.rnd_state, self.dones) + else: + padded_rnd_state_trajectories = None + + # pad actions (for RL^2 input as prev_action) + padded_action_trajectories, _ = split_and_pad_trajectories(self.actions, self.dones) + + mini_batch_size = self.num_envs // num_mini_batches + for ep in range(num_epochs): + first_traj = 0 + + # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> + # shuffle the env + shuffle_env_index = torch.randperm(self.num_envs, device=self.device) + + # re-pad the trajectories according to the shuffled env index + padded_obs_trajectories, trajectory_masks = split_and_pad_trajectories( + self.observations[:, shuffle_env_index], + self.dones[:, shuffle_env_index], + ) + if self.privileged_observations is not None: + padded_privileged_obs_trajectories, _ = split_and_pad_trajectories( + self.privileged_observations[:, shuffle_env_index], + self.dones[:, shuffle_env_index], + ) + else: + padded_privileged_obs_trajectories = padded_obs_trajectories + + # re-pad rnd state if exists + if self.rnd_state_shape is not None: + padded_rnd_state_trajectories, _ = split_and_pad_trajectories( + self.rnd_state[:, shuffle_env_index], + self.dones[:, shuffle_env_index], + ) + else: + padded_rnd_state_trajectories = None + + # re-pad actions (for RL^2 input as prev_action) + padded_action_trajectories, _ = split_and_pad_trajectories( + self.actions[:, shuffle_env_index], + self.dones[:, shuffle_env_index], + ) + # <<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<< + + for i in range(num_mini_batches): + start = i * mini_batch_size + stop = (i + 1) * mini_batch_size + shuffled_batch_envs = shuffle_env_index[start:stop] + + dones = self.dones.squeeze(-1) + last_was_done = torch.zeros_like(dones, dtype=torch.bool) + last_was_done[1:] = dones[:-1] + last_was_done[0] = True + trajectories_batch_size = torch.sum(last_was_done[:, shuffled_batch_envs]) + last_traj = first_traj + trajectories_batch_size + + masks_batch = trajectory_masks[:, first_traj:last_traj] + obs_batch = padded_obs_trajectories[:, first_traj:last_traj] + privileged_obs_batch = padded_privileged_obs_trajectories[:, 
first_traj:last_traj] + # 新增 prev_action batch (for RL^2 input) + prev_actions_batch = padded_action_trajectories[:, first_traj:last_traj] + # RL^2: shift prev_actions by 1 step + prev_actions_batch = torch.cat( + [torch.zeros_like(prev_actions_batch[:, :1, :], device=prev_actions_batch.device), + prev_actions_batch[:, :-1, :]], dim=1 + ) + + if padded_rnd_state_trajectories is not None: + rnd_state_batch = padded_rnd_state_trajectories[:, first_traj:last_traj] + else: + rnd_state_batch = None + + actions_batch = self.actions[:, shuffled_batch_envs] + old_mu_batch = self.mu[:, shuffled_batch_envs] + old_sigma_batch = self.sigma[:, shuffled_batch_envs] + returns_batch = self.returns[:, shuffled_batch_envs] + advantages_batch = self.advantages[:, shuffled_batch_envs] + values_batch = self.values[:, shuffled_batch_envs] + old_actions_log_prob_batch = self.actions_log_prob[:, shuffled_batch_envs] + + # hidden states reshape + last_was_done = last_was_done.permute(1, 0) + hid_a_batch = [ + saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] + .transpose(1, 0) + .contiguous() + for saved_hidden_states in self.saved_hidden_states_a + ] + hid_c_batch = [ + saved_hidden_states.permute(2, 0, 1, 3)[last_was_done][first_traj:last_traj] + .transpose(1, 0) + .contiguous() + for saved_hidden_states in self.saved_hidden_states_c + ] + hid_a_batch = hid_a_batch[0] if len(hid_a_batch) == 1 else hid_a_batch + hid_c_batch = hid_c_batch[0] if len(hid_c_batch) == 1 else hid_c_batch + + # 新增 prev_actions_batch 输出 + yield obs_batch, prev_actions_batch, privileged_obs_batch, actions_batch, values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, ( + hid_a_batch, + hid_c_batch, + ), masks_batch, rnd_state_batch + + first_traj = last_traj + + + # chunk 生成器 9.16 + def debug_chunk_mini_batch_generator(self, num_mini_batches, num_epochs=8, chunk_size=1): + """ + Generator that yields mini-batches of experience in chunks of length `chunk_size`, + + Yields: + Tuple: ( + obs, privileged_obs, actions, values, returns, + log_probs, advantages, old_mu, old_sigma, + hid_states(first), masks(None), rnd_state + ) + Shapes: + obs: [chunk_size, batch, obs_dim] + masks: [chunk_size, batch, 1] (1=valid, 0=padding) + hidden_state: [num_layers, batch, hidden_dim] + """ + if self.training_type != "rl": + raise ValueError("This function is only available for reinforcement learning training.") + + T = self.num_transitions_per_env + N = self.num_envs + assert T % chunk_size == 0, "T must be divisible by chunk_size" + + num_chunks = T // chunk_size + batch_size = num_chunks * N + mini_batch_size = batch_size // num_mini_batches + + indices = torch.randperm(batch_size, requires_grad=False, device=self.device) + + # 切分chunk函数 + def make_chunks(x): + T, N = x.shape[:2] + num_chunks = T // chunk_size + + x = x.view(num_chunks, chunk_size, N, *x.shape[2:]) + x = x.transpose(0, 1) + + x = x.reshape(chunk_size, num_chunks * N, *x.shape[3:]) + return x + + # return hidden_chunks_first # [num_layers, batch, hidden_dim] + def get_chunk_hidden_states(saved_hidden_states): + """ + saved_hidden_states: list of [T, num_layers, num_envs, hidden_dim] + returns: tensor [num_layers, total_chunks, hidden_dim] + """ + if len(saved_hidden_states) == 1: + h = saved_hidden_states[0] # [T, num_layers, num_envs, hidden_dim] + else: + h = torch.stack(saved_hidden_states, dim=1) # [T, num_layers, num_envs, hidden_dim] or list stacking + + # take first timestep of each chunk + h = h.permute(2, 
0, 1, 3)  # [num_envs, T, num_layers, hidden_dim]
+            h_chunks = h.view(N, num_chunks, chunk_size, h.shape[2], h.shape[3])
+            h_first = h_chunks[:, :, 0, :, :]  # [num_envs, num_chunks, num_layers, hidden_dim]
+            h_first = h_first.permute(2, 0, 1, 3).reshape(h_first.shape[2], num_chunks*N, h_first.shape[3])
+            return h_first  # [num_layers, batch, hidden_dim]
+
+        # Core
+        observations = make_chunks(self.observations)
+        if self.privileged_observations is not None:
+            privileged_observations = make_chunks(self.privileged_observations)
+        else:
+            privileged_observations = observations
+
+        actions = make_chunks(self.actions)
+        values = make_chunks(self.values)
+        returns = make_chunks(self.returns)
+
+        # prev_actions for RL^2 input
+        prev_actions = torch.zeros_like(self.actions, device=self.actions.device)
+        prev_actions[1:, :, :] = self.actions[:-1, :, :]
+        prev_actions = make_chunks(prev_actions)
+
+        # For PPO
+        old_actions_log_prob = make_chunks(self.actions_log_prob)
+        advantages = make_chunks(self.advantages)
+        old_mu = make_chunks(self.mu)
+        old_sigma = make_chunks(self.sigma)
+
+        # For RND
+        if self.rnd_state_shape is not None:
+            rnd_state = make_chunks(self.rnd_state)
+
+        # For hidden_states_first
+        hid_a_chunks_first = get_chunk_hidden_states(self.saved_hidden_states_a)
+        hid_c_chunks_first = get_chunk_hidden_states(self.saved_hidden_states_c)
+
+        # build an all-True mask so the downstream RNN input stays aligned
+        # the passed-in hidden states are only used when a mask is provided; with mask=None the memory falls back to its own stored hidden states
+        masks_batch = torch.ones((chunk_size, mini_batch_size), dtype=torch.bool, device=observations.device)
+
+        for epoch in range(num_epochs):
+            for i in range(num_mini_batches):
+                # Select the indices for the mini-batch
+                start = i * mini_batch_size
+                end = (i + 1) * mini_batch_size
+                batch_idx = indices[start:end]
+
+                # Create the mini-batch
+                # tensors are [chunk_size, batch, dim], so remember to index the second dimension
+                # -- Core
+                obs_batch = observations[:, batch_idx]
+                privileged_observations_batch = privileged_observations[:, batch_idx]
+                actions_batch = actions[:, batch_idx]
+                prev_actions_batch = prev_actions[:, batch_idx]
+
+                # -- For PPO
+                target_values_batch = values[:, batch_idx]
+                returns_batch = returns[:, batch_idx]
+                old_actions_log_prob_batch = old_actions_log_prob[:, batch_idx]
+                advantages_batch = advantages[:, batch_idx]
+                old_mu_batch = old_mu[:, batch_idx]
+                old_sigma_batch = old_sigma[:, batch_idx]
+
+                # hidden_state_first
+                hid_a_batch = hid_a_chunks_first[:, batch_idx]
+                hid_c_batch = hid_c_chunks_first[:, batch_idx]
+                hid_states_batch = (hid_a_batch, hid_c_batch)
+
+                # -- For RND
+                if self.rnd_state_shape is not None:
+                    rnd_state_batch = rnd_state[:, batch_idx]
+                else:
+                    rnd_state_batch = None
+
+                # yield the mini-batch
+                yield obs_batch, prev_actions_batch, privileged_observations_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, (
+                    hid_a_batch,
+                    hid_c_batch,
+                ), masks_batch, rnd_state_batch
+
+
+    # old chunk generator, now deprecated
+    def chunk_mini_batch_generator(self, num_mini_batches, num_epochs=8, chunk_size=2):
+        """
+        Fully vectorized chunk mini-batch generator.
+        Outputs are identical to the chunked generator above, but with much less Python overhead.
+ + Shapes: + obs: [chunk_size, batch, obs_dim] + masks: [chunk_size, batch] + hidden_state: [num_layers, batch, hidden_dim] + """ + if self.training_type != "rl": + raise ValueError("This function is only available for reinforcement learning training.") + + num_envs = self.num_envs + T = self.num_transitions_per_env + num_chunks_per_env = T // chunk_size + total_chunks = num_envs * num_chunks_per_env + + device = self.device + + # ========== 1. Permute to [num_envs, T, dim] ========== + obs_env = self.observations.permute(1, 0, 2) # [num_envs, T, obs_dim] + priv_obs_env = obs_env if self.privileged_observations is None else self.privileged_observations.permute(1, 0, 2) + actions_env = self.actions.permute(1, 0, 2) + values_env = self.values.permute(1, 0, 2) + returns_env = self.returns.permute(1, 0, 2) + log_probs_env = self.actions_log_prob.permute(1, 0, 2) + advantages_env = self.advantages.permute(1, 0, 2) + old_mu_env = self.mu.permute(1, 0, 2) + old_sigma_env = self.sigma.permute(1, 0, 2) + dones_env = self.dones.permute(1, 0, 2) + + prev_actions_env = torch.zeros_like(actions_env) + prev_actions_env[:, 1:, :] = actions_env[:, :-1, :] + + rnd_env = None if self.rnd_state_shape is None else self.rnd_state.permute(1, 0, 2) + + # ========== 2. Create chunk indices ========== + # chunk shape: [num_envs, num_chunks_per_env, chunk_size, dim] + def make_chunks(x): + x = x.view(num_envs, num_chunks_per_env, chunk_size, -1) + return x + + obs_chunks = make_chunks(obs_env) + priv_obs_chunks = make_chunks(priv_obs_env) + actions_chunks = make_chunks(actions_env) + prev_actions_chunks = make_chunks(prev_actions_env) + values_chunks = make_chunks(values_env) + returns_chunks = make_chunks(returns_env) + log_probs_chunks = make_chunks(log_probs_env) + advantages_chunks = make_chunks(advantages_env) + old_mu_chunks = make_chunks(old_mu_env) + old_sigma_chunks = make_chunks(old_sigma_env) + masks_chunks = (~make_chunks(dones_env).bool().squeeze(-1)) # [num_envs, num_chunks, chunk_size] + + if rnd_env is not None: + rnd_chunks = make_chunks(rnd_env) + + # ========== 3. Prepare hidden states (first timestep of each chunk) ========== + def get_chunk_hidden_states(saved_hidden_states): + """ + saved_hidden_states: list of [T, num_layers, num_envs, hidden_dim] + returns: tensor [num_layers, total_chunks, hidden_dim] + """ + if len(saved_hidden_states) == 1: + h = saved_hidden_states[0] # [T, num_layers, num_envs, hidden_dim] + else: + h = torch.stack(saved_hidden_states, dim=1) # [T, num_layers, num_envs, hidden_dim] or list stacking + + # take first timestep of each chunk + h = h.permute(2, 0, 1, 3) # [num_envs, T, num_layers, hidden_dim] + h_chunks = h.view(num_envs, num_chunks_per_env, chunk_size, h.shape[2], h.shape[3]) + h_first = h_chunks[:, :, 0, :, :] # [num_envs, num_chunks, num_layers, hidden_dim] + h_first = h_first.permute(2, 0, 1, 3).reshape(h_first.shape[2], total_chunks, h_first.shape[3]) + return h_first # [num_layers, total_chunks, hidden_dim] + + hid_a_chunks = get_chunk_hidden_states(self.saved_hidden_states_a) + hid_c_chunks = get_chunk_hidden_states(self.saved_hidden_states_c) + hid_states_chunks = (hid_a_chunks, hid_c_chunks) + + # ========== 4. 
Flatten chunks ========== + def flatten_chunks(x): + return x.reshape(total_chunks, chunk_size, *x.shape[3:]) # [total_chunks, chunk_size, dim] + + obs_chunks = flatten_chunks(obs_chunks) + priv_obs_chunks = flatten_chunks(priv_obs_chunks) + actions_chunks = flatten_chunks(actions_chunks) + prev_actions_chunks = flatten_chunks(prev_actions_chunks) + values_chunks = flatten_chunks(values_chunks) + returns_chunks = flatten_chunks(returns_chunks) + log_probs_chunks = flatten_chunks(log_probs_chunks) + advantages_chunks = flatten_chunks(advantages_chunks) + old_mu_chunks = flatten_chunks(old_mu_chunks) + old_sigma_chunks = flatten_chunks(old_sigma_chunks) + masks_chunks = flatten_chunks(masks_chunks) # [total_chunks, chunk_size] + if rnd_env is not None: + rnd_chunks = flatten_chunks(rnd_chunks) + + # ========== 5. Shuffle chunks ========== + indices = torch.randperm(total_chunks, device=device) + + mini_batch_size = total_chunks // num_mini_batches + + for epoch in range(num_epochs): + for i in range(num_mini_batches): + batch_idx = indices[i * mini_batch_size: (i + 1) * mini_batch_size] + + # index all tensors + obs_batch = obs_chunks[batch_idx].transpose(0, 1) + prev_actions_batch = prev_actions_chunks[batch_idx].transpose(0, 1) + priv_obs_batch = priv_obs_chunks[batch_idx].transpose(0, 1) + actions_batch = actions_chunks[batch_idx].transpose(0, 1) + values_batch = values_chunks[batch_idx].transpose(0, 1) + returns_batch = returns_chunks[batch_idx].transpose(0, 1) + log_probs_batch = log_probs_chunks[batch_idx].transpose(0, 1) + advantages_batch = advantages_chunks[batch_idx].transpose(0, 1) + old_mu_batch = old_mu_chunks[batch_idx].transpose(0, 1) + old_sigma_batch = old_sigma_chunks[batch_idx].transpose(0, 1) + masks_batch = masks_chunks[batch_idx].transpose(0, 1) + hid_a_batch = hid_a_chunks[:, batch_idx, :] + hid_c_batch = hid_c_chunks[:, batch_idx, :] + hid_states_batch = (hid_a_batch, hid_c_batch) + rnd_state_batch = None if rnd_env is None else rnd_chunks[batch_idx].transpose(0, 1) + + yield ( + obs_batch, prev_actions_batch, priv_obs_batch, actions_batch, + values_batch, advantages_batch, returns_batch, + log_probs_batch, old_mu_batch, old_sigma_batch, + hid_states_batch, masks_batch, rnd_state_batch + ) \ No newline at end of file
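
Note on the chunked RL^2 batching above: debug_chunk_mini_batch_generator reshapes every rollout tensor from [T, N, ...] (time, envs) into [chunk_size, (T // chunk_size) * N, ...] and conditions the policy on the action from the previous timestep (zeros at t = 0). Below is a minimal standalone sketch of that reshaping, with toy sizes chosen purely for illustration; it mirrors the patch but is not part of it.

import torch

T, N, dim, chunk_size = 8, 4, 3, 2  # toy rollout: 8 steps, 4 envs, 3-dim actions
actions = torch.randn(T, N, dim)

# RL^2 conditioning: the previous action at step t is the action taken at t - 1 (zeros at t = 0).
prev_actions = torch.zeros_like(actions)
prev_actions[1:] = actions[:-1]

def make_chunks(x, chunk_size):
    # [T, N, ...] -> [chunk_size, (T // chunk_size) * N, ...];
    # chunk c holds timesteps c*chunk_size .. (c+1)*chunk_size - 1 of every env.
    T, N = x.shape[:2]
    num_chunks = T // chunk_size
    x = x.view(num_chunks, chunk_size, N, *x.shape[2:])
    x = x.transpose(0, 1)
    return x.reshape(chunk_size, num_chunks * N, *x.shape[3:])

print(make_chunks(prev_actions, chunk_size).shape)  # torch.Size([2, 16, 3])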