diff --git a/compose_rl/utils/__init__.py b/compose_rl/utils/__init__.py
index 8f30bb9a..28dfb836 100644
--- a/compose_rl/utils/__init__.py
+++ b/compose_rl/utils/__init__.py
@@ -53,6 +53,15 @@
     switch_left_to_right_padding,
     print_batch_shapes,
 )
+from compose_rl.utils.mlflow_utils import (
+    get_mlflow_run_id,
+    get_valid_mlflow_experiment_name,
+    setup_mlflow,
+    artifact_exists_on_mlflow,
+    get_mlflow_absolute_path_for_save_folder,
+    get_mlflow_relative_path_for_save_folder,
+    validate_save_folder,
+)

 __all__ = [
     'get_mb_load_balancing_loss',
@@ -103,4 +112,11 @@
     'remove_boxed',
     'ray_utils',
     'print_batch_shapes',
+    'get_mlflow_run_id',
+    'get_valid_mlflow_experiment_name',
+    'setup_mlflow',
+    'artifact_exists_on_mlflow',
+    'get_mlflow_absolute_path_for_save_folder',
+    'get_mlflow_relative_path_for_save_folder',
+    'validate_save_folder',
 ]
diff --git a/compose_rl/utils/mlflow_utils.py b/compose_rl/utils/mlflow_utils.py
new file mode 100644
index 00000000..7aacc022
--- /dev/null
+++ b/compose_rl/utils/mlflow_utils.py
@@ -0,0 +1,167 @@
+import os
+from typing import Any, Optional
+
+import mlflow
+import torch.distributed as dist
+from composer.utils import dist as composer_dist
+
+
+def get_mlflow_run_id() -> Optional[str]:
+    return os.environ.get('MLFLOW_RUN_ID', None)
+
+
+def get_mlflow_relative_path_for_save_folder(save_folder: str) -> str:
+    """Returns the run-relative artifact path for the given save folder."""
+    return save_folder.lstrip('/')
+
+
+def get_mlflow_absolute_path_for_save_folder(save_folder: str) -> str:
+    """Returns the mlflow artifact path for the given save folder."""
+    mlflow_prefix = 'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}'
+    mlflow_artifact_path = os.path.join(mlflow_prefix, 'artifacts', get_mlflow_relative_path_for_save_folder(save_folder))
+    return mlflow_artifact_path
+
+
+def validate_save_folder(save_folder: str) -> None:
+    """Validates the save folder."""
+    if save_folder.startswith("dbfs:/"):
+        raise ValueError(f"Using a dbfs save_folder ({save_folder}) to store checkpoints is not supported. Please use a local save_folder.")
+
+
+def artifact_exists_on_mlflow(artifact_path: str) -> bool:
+    """Returns True if artifact_path exists (file or directory) for the run.
+
+    The artifact path must be relative to the run's artifact root (i.e. the
+    relative path for the save folder, not the dbfs:/ path).
+    """
+    client = mlflow.MlflowClient()
+    run_id = get_mlflow_run_id()
+    assert run_id is not None, "Run ID must be set"
+
+    # Walk down the path parts level-by-level.
+    parent = ""
+    if artifact_path:
+        parts = artifact_path.split("/")
+        for i, part in enumerate(parts):
+            entries = {os.path.basename(fi.path): fi for fi in client.list_artifacts(run_id, parent)}
+            if part not in entries:
+                return False
+            fi = entries[part]
+            is_last = (i == len(parts) - 1)
+            if not is_last and not fi.is_dir:
+                # Trying to descend into a file.
+                return False
+            parent = fi.path  # descend
+
+    # If we got here, the path exists (root or found item).
+    return True
+
+
+def get_valid_mlflow_experiment_name(config: Any) -> str:
+    """Fixes the experiment name to be an absolute path for mlflow.
+
+    MLflow requires the experiment name to be an absolute path. If the
+    experiment name is not an absolute path, we turn it into an absolute
+    path under the current user's workspace directory.
+ """ + mlflow_experiment_name = config.loggers.mlflow.experiment_name + if mlflow_experiment_name.startswith('/'): + return mlflow_experiment_name + else: + from databricks.sdk import WorkspaceClient + return f'/Users/{WorkspaceClient().current_user.me().user_name}/{mlflow_experiment_name}' + + +def get_mlflow_run_name(config: Any) -> str: + """Gets the mlflow run name from the config. + + If the run name is not set in the config, it will return the COMPOSER_RUN_NAME environment variable + as this is set for interactive mode as well. + """ + try: + return config.loggers.mlflow.tags.run + except: + return os.environ['COMPOSER_RUN_NAME'] + +# NOTE: This doesn't work yet for a few reasons: +# 1. Downloading nested mlflow artifacts doesn't work correctly due to the MlflowObjectStore +# having issues. For instance, https://github.com/mosaicml/composer/blob/4ae29b1afec56ce2d54f6fa07a7f9578a0d364b0/composer/utils/object_store/mlflow_object_store.py#L465-L476 +# requires `tmp_path = os.path.join(tmp_dir, os.path.basename(artifact_path))` instead of what it currently +# does. By doing that, the symlink can be loaded correctly. +# 2. If save_folder is an absolute path (e.g. /tmp/checkpoints), the symlink will be created using this +# absolute path. This is not a valid symlink in mlflow so we need to do some os.path gymnastics to +# support absolute paths for save_folder. +# 3. We also need to support save_folder being a dbfs path eventually. +# Proposed Approach +# - Create an MlflowCheckpointActor (allowing us to set WORLD_SIZE=1) +# and create functions within that are based on MlflowObjectStore. +# that safely handle dbfs paths and absolute paths. +def get_file(path: str, destination: str, overwrite: bool = True): + """ + A helper function to get a file from mlflow. The existing mlflow utils code + uses dist operations which isn't supported in the RolloutAgent so this approach + works around that limitation. + """ + from composer.utils.file_helpers import parse_uri, get_file as composer_get_file + from composer.utils.object_store import MLFlowObjectStore + backend, _, path = parse_uri(path) + assert backend == 'dbfs', "Only dbfs paths are supported" + object_store = MLFlowObjectStore(path) + composer_get_file(path, destination, object_store, overwrite) + + +def setup_mlflow(config: Any): + """ + Sets up mlflow for the current process. + + This function should be called before any other mlflow functions are called. + It will set the mlflow experiment and run. It will create both if they don't exist. + It will set all environment variables needed for mlflow. + """ + dist.init_process_group(backend='gloo') + mlflow.set_tracking_uri('databricks') + + mlflow_experiment_name = get_valid_mlflow_experiment_name(config) + setattr(config.loggers.mlflow, 'experiment_name', mlflow_experiment_name) + mlflow_run_name = get_mlflow_run_name(config) + setattr(config.loggers.mlflow.tags, 'run', mlflow_run_name) + + # get mlflow experiment if it exists, otherwise create it and set it to all ranks. 
+    experiment_id = None
+    if composer_dist.get_global_rank() == 0:
+        experiment = mlflow.get_experiment_by_name(mlflow_experiment_name)
+        if experiment is None:
+            experiment_id = mlflow.create_experiment(mlflow_experiment_name)
+        else:
+            experiment_id = experiment.experiment_id
+    experiment_id_broadcast_list = [experiment_id]
+    composer_dist.broadcast_object_list(experiment_id_broadcast_list, src=0)
+    experiment_id = experiment_id_broadcast_list[0]
+
+    mlflow.set_experiment(experiment_id=experiment_id)
+
+    # Get the mlflow run if it exists and we are autoresuming, otherwise create it, and broadcast the ID to all ranks.
+    run_id = None
+    if composer_dist.get_global_rank() == 0:
+        existing_runs = mlflow.search_runs(
+            experiment_ids=[experiment_id],
+            filter_string=f'tags.run_name = "{mlflow_run_name}"',
+            output_format='list',
+        ) if config.autoresume else []
+        if len(existing_runs) > 0:
+            run_id = existing_runs[0].info.run_id
+            print(f'Resuming mlflow run with run id: {run_id}')
+        else:
+            run_id = mlflow.start_run(run_name=mlflow_run_name).info.run_id
+            print(f'Creating new mlflow run with run id: {run_id}')
+    run_id_broadcast_list = [run_id]
+    composer_dist.broadcast_object_list(run_id_broadcast_list, src=0)
+    run_id = run_id_broadcast_list[0]
+
+    # Set all the required environment variables.
+    assert run_id is not None and experiment_id is not None, "Run ID and experiment ID must be set"
+    os.environ['MLFLOW_RUN_ID'] = run_id
+    os.environ['MLFLOW_EXPERIMENT_ID'] = experiment_id
+    os.environ['MLFLOW_TRACKING_URI'] = 'databricks'
+
+    dist.destroy_process_group()
diff --git a/test_single_controller_ppo.py b/test_single_controller_ppo.py
index ad756d4f..1aa9492c 100644
--- a/test_single_controller_ppo.py
+++ b/test_single_controller_ppo.py
@@ -22,13 +22,14 @@
 from typing import Any, Optional

 from composer.loggers import MLFlowLogger
+import mlflow
 import ray
 import torch
 import torch.distributed as dist
 from composer import Trainer
 from composer.core import get_precision_context
 from composer.optim import DecoupledAdamW
-from composer.utils import dist as composer_dist
+from composer.utils import create_symlink_file, dist as composer_dist
 from llmfoundry.data import build_dataloader
 from omegaconf import OmegaConf as om
 from transformers import AutoTokenizer
@@ -44,6 +45,15 @@
     create_vllm_engines,
     _vllm_generate,
 )
+from compose_rl.utils.mlflow_utils import (
+    get_mlflow_run_id,
+    setup_mlflow,
+    artifact_exists_on_mlflow,
+    get_mlflow_absolute_path_for_save_folder,
+    get_mlflow_relative_path_for_save_folder,
+    validate_save_folder,
+    get_file,
+)
 from compose_rl.utils.ray_utils import start_ray_server, uninstall_megablocks_if_exists
 from compose_rl.controllers import BaseDistributedGPUActor, SPMDActorGroup
 from compose_rl.controllers.buffer import Buffer
@@ -121,6 +131,8 @@ def build_train_config(self, config: Any):
         variables = om.to_container(self.config.variables, resolve=True)
         algorithm_config = self.config.algorithms

+        mlflow_save_folder = get_mlflow_absolute_path_for_save_folder(self.config.save_folder)
+
         self.train_config = {
             'seed': self.config.seed,
             'model': self.model_config,
@@ -131,7 +143,7 @@
             'global_train_batch_size': self.device_train_batch_size * self.world_size,
             'device_train_batch_size': self.device_train_batch_size,
             'device_train_microbatch_size': self.device_train_batch_size,
-            'save_folder': self.config.save_folder,
+            'save_folder': mlflow_save_folder,
             'log_config': self.config.log_config,
             'max_seq_len': self.max_seq_len,
             'python_log_level': self.config.python_log_level,
@@ -186,6 +198,8 @@ def build_ppo_trainer(self):
         dummy_distributed_sampler = torch.utils.data.distributed.DistributedSampler(dummy_dataset)
         dummy_dataloader = torch.utils.data.DataLoader(dummy_dataset, sampler=dummy_distributed_sampler)

+        # TODO: We might be able to skip part of the setup here, since some mlflow
+        # environment variables are already set in the setup_mlflow function.
         mlflow_logger = MLFlowLogger(
             experiment_name=self.config.loggers.mlflow.experiment_name,
             run_name=self.config.loggers.mlflow.tags.run,
@@ -206,6 +220,8 @@
             SpeedMonitor(window_size=10),
         ]

+        mlflow_save_folder = get_mlflow_absolute_path_for_save_folder(self.config.save_folder)
+
         self.ppo_trainer = Trainer(
             model=model,
             optimizers=optimizer,
@@ -217,7 +233,7 @@
             loggers=[mlflow_logger],
             device_train_microbatch_size=self.config.device_train_microbatch_size,
             load_path=self.ref_path,
-            save_folder=self.config.save_folder,
+            save_folder=mlflow_save_folder,
             save_interval=self.config.save_interval,
             autoresume=self.config.autoresume,
         )
@@ -335,6 +351,8 @@ def train_1_iter(self):
     async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', parameter_buffer: 'ParameterBuffer', inference_server: 'InferenceServer', lock: asyncio.Lock, rollout_semaphore: asyncio.Semaphore, eval_semaphore: asyncio.Semaphore):
         # the overall design rn is we have a async def run function for each of the subcontroller that is responsible for async primitives but leave the rest of the logic to be sync function and use
         # asyncio.to_thread to bridge the async and sync world
+
+        # TODO: Load the experience buffer from checkpoints; this will make checkpointing work for async training.
         for _ in range(num_iterations):
             # Simple example of adding elements to the experience buffer
             # Populate the train actor group with the rollouts and then train
@@ -455,24 +473,36 @@ def __init__(
         self.tokenizer_pad_token_id = ray.get(self.streaming_dataset_actor.get_tokenizer_pad_token_id.remote())
         self.prompt_handler_config = ray.get(self.streaming_dataset_actor.get_prompt_handler_config.remote())
         self.max_gen_len = self.prompt_handler_config['max_gen_len']
-
-        # Load iter_num from the checkpoint
-        self.save_folder = os.path.join(config.save_folder, 'RolloutAgent')
-        self.iter_num = 0
-        # Load the latest checkpoint
-        self.latest_checkpoint = os.path.join(self.save_folder, 'latest.symlink')
-
-        if config.autoresume and os.path.exists(self.latest_checkpoint):
+        self.local_save_folder = os.path.join(config.save_folder, 'RolloutAgent')
+        # We need to format the full path correctly for the MlflowObjectStore to be created.
+        self.mlflow_absolute_save_folder = get_mlflow_absolute_path_for_save_folder(self.local_save_folder).format(
+            mlflow_experiment_id=os.environ['MLFLOW_EXPERIMENT_ID'],
+            mlflow_run_id=os.environ['MLFLOW_RUN_ID'],
+        )
+        self.mlflow_relative_save_folder = get_mlflow_relative_path_for_save_folder(self.local_save_folder)
+        self.iter_num = 0
+
+        # Load the latest checkpoint if we are autoresuming.
+        # Note that since we are checking if the checkpoint exists with
+        # mlflow.client.list_artifacts, we need to use the relative path to
+        # the checkpoint (i.e. not include dbfs:/.../{mlflow_experiment_id}/{mlflow_run_id}
+        # in the path). Note: we don't support UC Volumes for storage, otherwise
+        # artifact_exists_on_mlflow would not work.
+        self.latest_checkpoint_path = os.path.join(self.local_save_folder, 'latest_rollout_agent.symlink')
+        self.mlflow_latest_checkpoint_absolute_path = os.path.join(self.mlflow_absolute_save_folder, 'latest_rollout_agent.symlink')
+        self.mlflow_latest_checkpoint_relative_path = os.path.join(self.mlflow_relative_save_folder, 'latest_rollout_agent.symlink')
+
+        if config.autoresume and artifact_exists_on_mlflow(self.mlflow_latest_checkpoint_relative_path):
             print(f'Autoresuming from checkpoint for RolloutAgent.')
-            with open(self.latest_checkpoint, 'rb') as f:
+            get_file(self.mlflow_latest_checkpoint_absolute_path, self.latest_checkpoint_path, overwrite=True)
+            print(f'Got autoresume checkpoint from mlflow: {self.latest_checkpoint_path}')
+            with open(self.latest_checkpoint_path, 'rb') as f:
                 checkpoint = pickle.load(f)
                 self.iter_num = checkpoint['iter_num']
                 print(f'Loading streaming dataloader state dict for RolloutAgent.', checkpoint['streaming_dataloader'])
                 self.streaming_dataset_actor.load_dataloader_state_dict.remote(checkpoint['streaming_dataloader'])
-

     def get_next_iter_rollouts(self):
         """
         Gets the next rollouts from the inference server.
@@ -510,15 +540,15 @@ def get_next_iter_rollouts(self):
         processed_sequences = torch.cat([all_prompts, padded_responses], dim=-1)
         iter_data['sequences'] = processed_sequences

-        save_folder_iter = os.path.join(self.save_folder, f'iter_{self.iter_num}')
-        checkpoint_path = os.path.join(save_folder_iter, 'checkpoint.pt')
+        save_folder_for_curr_iter = os.path.join(self.local_save_folder, f'iter_{self.iter_num}')
+        checkpoint_path = os.path.join(save_folder_for_curr_iter, 'checkpoint.pt')
         self.iter_num += 1

         streaming_dataloader_state_dict = ray.get(self.streaming_dataset_actor.get_dataloader_state_dict.remote())
         print(f'Streaming dataloader state dict for RolloutAgent.', streaming_dataloader_state_dict)

         # make sure that the folder path can exist
-        os.makedirs(save_folder_iter, exist_ok=True)
+        os.makedirs(save_folder_for_curr_iter, exist_ok=True)
         with open(checkpoint_path, 'wb') as f:
             pickle.dump({
                 'iter_data': iter_data,
@@ -526,9 +556,23 @@
                 'streaming_dataloader': streaming_dataloader_state_dict,
             }, f)

-        if os.path.exists(self.latest_checkpoint):
-            os.remove(self.latest_checkpoint)
-        os.symlink(checkpoint_path, self.latest_checkpoint)
+        # log the checkpoint to mlflow
+        mlflow.log_artifact(
+            checkpoint_path,
+            get_mlflow_relative_path_for_save_folder(save_folder_for_curr_iter),
+            run_id=get_mlflow_run_id(),
+        )
+
+        if os.path.exists(self.latest_checkpoint_path):
+            os.remove(self.latest_checkpoint_path)
+        create_symlink_file(checkpoint_path, self.latest_checkpoint_path)
+
+        # log the latest checkpoint to mlflow
+        mlflow.log_artifact(
+            self.latest_checkpoint_path,
+            self.mlflow_relative_save_folder,
+            run_id=get_mlflow_run_id(),
+        )

         return iter_data

     async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', lock: asyncio.Lock, rollout_semaphore: asyncio.Semaphore):
@@ -740,6 +784,9 @@ def _run_single_controller_ppo(
     # Disable setting CUDA_VISIBLE_DEVICES by ray, we will set it manually
     os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1'

+    validate_save_folder(config.save_folder)
+    setup_mlflow(config)
+
     with start_ray_server() as _address:
         # only rank 0 is the master controller
         if dist.get_rank() == 0:
diff --git a/yamls/single-controller-grpo-workflow.yaml b/yamls/single-controller-grpo-workflow.yaml
index 699ba5a8..2ed66451 100644
--- a/yamls/single-controller-grpo-workflow.yaml
+++ b/yamls/single-controller-grpo-workflow.yaml
@@ -157,7 +157,7 @@ parameters:
     verbose: false
     cpu_offload: false
     mixed_precision: PURE
-    state_dict_type: sharded
+    state_dict_type: full  # TODO: Sharded state dicts are having issues. Need to investigate further.
     use_orig_params: true
     forward_prefetch: true
     backward_prefetch: BACKWARD_PRE
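
For reference, a minimal usage sketch (not part of the diff) of how the new save-folder helpers in compose_rl/utils/mlflow_utils.py compose, assuming setup_mlflow() has already run and exported MLFLOW_EXPERIMENT_ID and MLFLOW_RUN_ID; the local folder below is illustrative only:

    import os

    from compose_rl.utils.mlflow_utils import (
        get_mlflow_absolute_path_for_save_folder,
        get_mlflow_relative_path_for_save_folder,
        validate_save_folder,
    )

    save_folder = '/tmp/checkpoints'  # hypothetical local save folder
    validate_save_folder(save_folder)  # raises if a dbfs:/ save folder is passed

    # 'tmp/checkpoints' -- the path used when listing or logging artifacts on the run.
    relative = get_mlflow_relative_path_for_save_folder(save_folder)

    # 'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts/tmp/checkpoints',
    # with the placeholders filled from the environment before the path is handed to
    # MLFlowObjectStore, mirroring what the RolloutAgent does in its __init__.
    absolute = get_mlflow_absolute_path_for_save_folder(save_folder).format(
        mlflow_experiment_id=os.environ['MLFLOW_EXPERIMENT_ID'],
        mlflow_run_id=os.environ['MLFLOW_RUN_ID'],
    )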
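Similarly, a condensed sketch of the RolloutAgent autoresume path added by the diff: check for the latest symlink on the MLflow run via its run-relative path, then download it through the dbfs-style absolute path with get_file. The concrete paths are illustrative, and the PR's own NOTE cautions that symlink resolution through MlflowObjectStore is still being worked out:

    import os
    import pickle

    from compose_rl.utils.mlflow_utils import artifact_exists_on_mlflow, get_file

    relative_symlink = 'tmp/checkpoints/RolloutAgent/latest_rollout_agent.symlink'
    local_symlink = '/tmp/checkpoints/RolloutAgent/latest_rollout_agent.symlink'
    absolute_symlink = (
        'dbfs:/databricks/mlflow-tracking/'
        f"{os.environ['MLFLOW_EXPERIMENT_ID']}/{os.environ['MLFLOW_RUN_ID']}/"
        f'artifacts/{relative_symlink}'
    )

    # artifact_exists_on_mlflow walks the artifact tree level by level, so it only
    # needs MLFLOW_RUN_ID to be set and never touches torch.distributed.
    if artifact_exists_on_mlflow(relative_symlink):
        get_file(absolute_symlink, local_symlink, overwrite=True)
        with open(local_symlink, 'rb') as f:
            # Same structure the RolloutAgent pickles: iter_data, iter_num,
            # and the streaming dataloader state dict.
            checkpoint = pickle.load(f)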
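Finally, a purely hypothetical sketch of the MlflowCheckpointActor proposed in the NOTE in mlflow_utils.py; the class name, method names, and environment handling are assumptions, not code from this PR. The idea is a WORLD_SIZE=1 Ray actor that owns an MLFlowObjectStore, so checkpoint upload and download never depend on torch.distributed being initialized:

    import os

    import ray


    @ray.remote(num_cpus=1)
    class MlflowCheckpointActor:
        """Hypothetical single-process owner of an MLFlowObjectStore."""

        def __init__(self, dbfs_path: str):
            # Present a single-process "distributed" environment, per the NOTE's
            # suggestion of pinning WORLD_SIZE=1 inside the actor.
            os.environ.setdefault('WORLD_SIZE', '1')
            os.environ.setdefault('RANK', '0')
            from composer.utils.file_helpers import parse_uri
            from composer.utils.object_store import MLFlowObjectStore
            backend, _, path = parse_uri(dbfs_path)
            assert backend == 'dbfs', 'Only dbfs paths are supported'
            self._store = MLFlowObjectStore(path)

        def upload(self, object_name: str, local_path: str) -> None:
            self._store.upload_object(object_name, local_path)

        def download(self, object_name: str, local_path: str) -> None:
            self._store.download_object(object_name, local_path, overwrite=True)


    # Example usage from the RolloutAgent side (paths are illustrative):
    #   actor = MlflowCheckpointActor.remote(mlflow_absolute_save_folder)
    #   ray.get(actor.upload.remote('iter_0/checkpoint.pt', '/tmp/checkpoints/iter_0/checkpoint.pt'))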