diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 6e6869bd2..c0f7f0717 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -585,6 +585,9 @@ def grpo_train( ) timeout.start_iterations() + logger.log_config_dict(master_config, "train_config.json") + logger.log_config_dict(master_config, "train_config.yaml") + NEED_REFIT = True # If policy_generation is None, use the policy as the generation interface (megatron framework backend) if policy_generation is None: @@ -1174,6 +1177,10 @@ def async_grpo_train( from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer timer = Timer() + + logger.log_config_dict(master_config, "train_config.json") + logger.log_config_dict(master_config, "train_config.yaml") + NEED_REFIT = True # Setup generation interface diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 39e593e30..491de5185 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -362,6 +362,9 @@ def sft_train( ) timeout.start_iterations() + logger.log_config_dict(master_config, "train_config.json") + logger.log_config_dict(master_config, "train_config.yaml") + if sft_save_state is None: sft_save_state = _default_sft_save_state() current_epoch = 0 diff --git a/nemo_rl/utils/logger.py b/nemo_rl/utils/logger.py index fa76f1295..c97f9274b 100644 --- a/nemo_rl/utils/logger.py +++ b/nemo_rl/utils/logger.py @@ -31,6 +31,7 @@ import swanlab import torch import wandb +import yaml from matplotlib import pyplot as plt from prometheus_client.parser import text_string_to_metric_families from prometheus_client.samples import Sample @@ -1016,6 +1017,30 @@ def log_plot_token_mult_prob_error( plt.close(fig) + def log_config_dict(self, config: dict, filename: str): + """Log a config dictionary to a json or yaml file. + + Args: + config: Dict to log + filename: Filename to log to (within the log directory) + """ + assert isinstance(config, dict) + + # Create full path within log directory + filepath = os.path.join(self.base_log_dir, filename) + os.makedirs(os.path.dirname(filepath), exist_ok=True) + + if filepath.endswith(".json"): + with open(filepath, "w") as f: + print(json.dumps(config, indent=2), end="", file=f) + elif filepath.endswith(".yaml") or filepath.endswith(".yml"): + with open(filepath, "w") as f: + yaml.safe_dump(config, f, sort_keys=False) + else: + raise NotImplementedError + + print(f"Logged config dict to {filepath!r}", flush=True) + def __del__(self) -> None: """Clean up resources when the logger is destroyed.""" if self.gpu_monitor: