Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,9 @@ def grpo_train(
)
timeout.start_iterations()

logger.log_config_dict(master_config, "train_config.json")
logger.log_config_dict(master_config, "train_config.yaml")

NEED_REFIT = True
# If policy_generation is None, use the policy as the generation interface (megatron framework backend)
if policy_generation is None:
Expand Down Expand Up @@ -1174,6 +1177,10 @@ def async_grpo_train(
from nemo_rl.algorithms.async_utils import AsyncTrajectoryCollector, ReplayBuffer

timer = Timer()

logger.log_config_dict(master_config, "train_config.json")
logger.log_config_dict(master_config, "train_config.yaml")

NEED_REFIT = True

# Setup generation interface
Expand Down
3 changes: 3 additions & 0 deletions nemo_rl/algorithms/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,9 @@ def sft_train(
)
timeout.start_iterations()

logger.log_config_dict(master_config, "train_config.json")
logger.log_config_dict(master_config, "train_config.yaml")

if sft_save_state is None:
sft_save_state = _default_sft_save_state()
current_epoch = 0
Expand Down
25 changes: 25 additions & 0 deletions nemo_rl/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import swanlab
import torch
import wandb
import yaml
from matplotlib import pyplot as plt
from prometheus_client.parser import text_string_to_metric_families
from prometheus_client.samples import Sample
Expand Down Expand Up @@ -1016,6 +1017,30 @@ def log_plot_token_mult_prob_error(

plt.close(fig)

def log_config_dict(self, config: dict, filename: str):
"""Log a config dictionary to a json or yaml file.

Args:
config: Dict to log
filename: Filename to log to (within the log directory)
"""
assert isinstance(config, dict)

# Create full path within log directory
filepath = os.path.join(self.base_log_dir, filename)
os.makedirs(os.path.dirname(filepath), exist_ok=True)

if filepath.endswith(".json"):
with open(filepath, "w") as f:
print(json.dumps(config, indent=2), end="", file=f)
elif filepath.endswith(".yaml") or filepath.endswith(".yml"):
with open(filepath, "w") as f:
yaml.safe_dump(config, f, sort_keys=False)
else:
raise NotImplementedError

print(f"Logged config dict to {filepath!r}", flush=True)

def __del__(self) -> None:
"""Clean up resources when the logger is destroyed."""
if self.gpu_monitor:
Expand Down
Loading