16 changes: 16 additions & 0 deletions compose_rl/utils/__init__.py
@@ -53,6 +53,15 @@
switch_left_to_right_padding,
print_batch_shapes,
)
from compose_rl.utils.mlflow_utils import (
get_mlflow_run_id,
get_valid_mlflow_experiment_name,
setup_mlflow,
artifact_exists_on_mlflow,
get_mlflow_absolute_path_for_save_folder,
get_mlflow_relative_path_for_save_folder,
validate_save_folder,
)

__all__ = [
'get_mb_load_balancing_loss',
@@ -103,4 +112,11 @@
'remove_boxed',
'ray_utils',
'print_batch_shapes',
'get_mlflow_run_id',
'get_valid_mlflow_experiment_name',
'setup_mlflow',
'artifact_exists_on_mlflow',
'get_mlflow_absolute_path_for_save_folder',
'get_mlflow_relative_path_for_save_folder',
'validate_save_folder',
]
167 changes: 167 additions & 0 deletions compose_rl/utils/mlflow_utils.py
@@ -0,0 +1,167 @@
import os
from typing import Any, Optional

import mlflow
import torch.distributed as dist

from composer.utils import dist as composer_dist


def get_mlflow_run_id() -> Optional[str]:
return os.environ.get('MLFLOW_RUN_ID', None)


def get_mlflow_relative_path_for_save_folder(save_folder: str) -> str:
"""Returns the relative path for the given save folder"""
return save_folder.lstrip('/')


def get_mlflow_absolute_path_for_save_folder(save_folder: str) -> str:
"""Returns the mlflow artifact path for the given save folder"""
mlflow_prefix = 'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}'
mlflow_artifact_path = os.path.join(mlflow_prefix, 'artifacts', get_mlflow_relative_path_for_save_folder(save_folder))
return mlflow_artifact_path


def validate_save_folder(save_folder: str) -> None:
"""Validates the save folder"""
if save_folder.startswith("dbfs:/"):
raise ValueError(f"Using dbfs save_folder ({save_folder}) to store checkpoints is not supported. Please use a local save_folder.")


def artifact_exists_on_mlflow(artifact_path: str) -> bool:
"""Return True if artifact_path exists (file or directory) for the run.

The artifact_path must be a relative path under the run's artifact root
(as returned by get_mlflow_relative_path_for_save_folder), not a dbfs:/ URI.
"""
client = mlflow.MlflowClient()
run_id = get_mlflow_run_id()
assert run_id is not None, "Run ID must be set"

# Walk down the path parts level-by-level
parent = ""
if artifact_path:
parts = artifact_path.split("/")
for i, part in enumerate(parts):
entries = {os.path.basename(fi.path): fi for fi in client.list_artifacts(run_id, parent)}
if part not in entries:
return False
fi = entries[part]
is_last = (i == len(parts) - 1)
if not is_last and not fi.is_dir:
# trying to descend into a file
return False
parent = fi.path # descend

# If we got here, the path exists (root or found item).
return True


def get_valid_mlflow_experiment_name(config: Any) -> str:
"""Fixes the experiment name to be an absolute path for mlflow.

MLflow on Databricks requires the experiment name to be an absolute
workspace path. If it is not, we prefix it with the current user's
/Users/<user_name>/ directory to make it absolute.
"""
mlflow_experiment_name = config.loggers.mlflow.experiment_name
if mlflow_experiment_name.startswith('/'):
return mlflow_experiment_name
else:
from databricks.sdk import WorkspaceClient
return f'/Users/{WorkspaceClient().current_user.me().user_name}/{mlflow_experiment_name}'


def get_mlflow_run_name(config: Any) -> str:
"""Gets the mlflow run name from the config.

If the run name is not set in the config, it will return the COMPOSER_RUN_NAME environment variable
as this is set for interactive mode as well.
"""
try:
return config.loggers.mlflow.tags.run
except Exception:
return os.environ['COMPOSER_RUN_NAME']

# NOTE: This doesn't work yet for a few reasons:
# 1. Downloading nested mlflow artifacts doesn't work correctly due to the MlflowObjectStore
# having issues. For instance, https://github.com/mosaicml/composer/blob/4ae29b1afec56ce2d54f6fa07a7f9578a0d364b0/composer/utils/object_store/mlflow_object_store.py#L465-L476
# requires `tmp_path = os.path.join(tmp_dir, os.path.basename(artifact_path))` instead of what it currently
# does. By doing that, the symlink can be loaded correctly.
# 2. If save_folder is an absolute path (e.g. /tmp/checkpoints), the symlink will be created using this
# absolute path. This is not a valid symlink in mlflow so we need to do some os.path gymnastics to
# support absolute paths for save_folder.
# 3. We also need to support save_folder being a dbfs path eventually.
# Proposed Approach:
# - Create an MlflowCheckpointActor (allowing us to set WORLD_SIZE=1)
#   and create functions within it, based on MlflowObjectStore, that
#   safely handle dbfs paths and absolute paths (see the sketch after
#   this file's diff).
Collaborator comment on lines +86 to +98:
This PR is NOT ready for review since there's a lot of os.path gymnastics that we are doing for supporting saving things to mlflow artifacts. I am going to keep this PR on hold for now until we have time to think of a more resilient solution that addresses the problems here. (cc: @irenedea @bowenyang008)

def get_file(path: str, destination: str, overwrite: bool = True):
"""
A helper function to get a file from mlflow. The existing mlflow utils code
uses dist operations which isn't supported in the RolloutAgent so this approach
works around that limitation.
"""
from composer.utils.file_helpers import parse_uri, get_file as composer_get_file
from composer.utils.object_store import MLFlowObjectStore
backend, _, path = parse_uri(path)
assert backend == 'dbfs', "Only dbfs paths are supported"
object_store = MLFlowObjectStore(path)
composer_get_file(path, destination, object_store, overwrite)


def setup_mlflow(config: Any):
"""
Sets up mlflow for the current process.

This function should be called before any other mlflow functions are called.
It will set the mlflow experiment and run. It will create both if they don't exist.
It will set all environment variables needed for mlflow.
"""
dist.init_process_group(backend='gloo')
mlflow.set_tracking_uri('databricks')

mlflow_experiment_name = get_valid_mlflow_experiment_name(config)
setattr(config.loggers.mlflow, 'experiment_name', mlflow_experiment_name)
mlflow_run_name = get_mlflow_run_name(config)
setattr(config.loggers.mlflow.tags, 'run', mlflow_run_name)

# Get the mlflow experiment if it exists, otherwise create it, then broadcast the id to all ranks.
experiment_id = None
if composer_dist.get_global_rank() == 0:
experiment = mlflow.get_experiment_by_name(mlflow_experiment_name)
if experiment is None:
experiment_id = mlflow.create_experiment(mlflow_experiment_name)
else:
experiment_id = experiment.experiment_id
experiment_id_broadcast_list = [experiment_id]
composer_dist.broadcast_object_list(experiment_id_broadcast_list, src=0)
experiment_id = experiment_id_broadcast_list[0]

mlflow.set_experiment(experiment_id=experiment_id)

# Get the mlflow run if it exists and we are autoresuming, otherwise create it, then broadcast the id to all ranks.
run_id = None
if composer_dist.get_global_rank() == 0:
existing_runs = mlflow.search_runs(
experiment_ids=[experiment_id],
filter_string=f'tags.run_name = "{mlflow_run_name}"',
output_format='list',
) if config.autoresume else []
if len(existing_runs) > 0:
run_id = existing_runs[0].info.run_id
print(f'Resuming mlflow run with run id: {run_id}')
else:
run_id = mlflow.start_run(run_name=mlflow_run_name).info.run_id
print(f'Creating new mlflow run with run id: {run_id}')
run_id_broadcast_list = [run_id]
composer_dist.broadcast_object_list(run_id_broadcast_list, src=0)
run_id = run_id_broadcast_list[0]

# set all the right environment variables
assert run_id is not None and experiment_id is not None, "Run ID and experiment ID must be set"
os.environ['MLFLOW_RUN_ID'] = run_id
os.environ['MLFLOW_EXPERIMENT_ID'] = experiment_id
os.environ['MLFLOW_TRACKING_URI'] = 'databricks'

dist.destroy_process_group()
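The "Proposed Approach" in the NOTE above exists only as prose, so here is a minimal sketch of what such an actor could look like. The class name MlflowCheckpointActor and its upload/download methods are hypothetical (not part of this PR); the sketch relies only on MlflowClient and mlflow.log_artifact calls that already appear in this diff, and it omits retries and error handling.

import os

import ray


@ray.remote
class MlflowCheckpointActor:
    """Hypothetical single-process owner of MLflow checkpoint I/O (WORLD_SIZE=1)."""

    def __init__(self, experiment_id: str, run_id: str):
        # Pin the env so any code running inside the actor sees a world size of 1
        # and the same MLflow experiment/run as the rest of the job.
        os.environ['WORLD_SIZE'] = '1'
        os.environ['MLFLOW_EXPERIMENT_ID'] = experiment_id
        os.environ['MLFLOW_RUN_ID'] = run_id
        os.environ['MLFLOW_TRACKING_URI'] = 'databricks'

    def upload(self, local_path: str, artifact_dir: str) -> None:
        # artifact_dir is a run-relative directory, mirroring the local save_folder layout.
        import mlflow
        mlflow.log_artifact(local_path, artifact_dir, run_id=os.environ['MLFLOW_RUN_ID'])

    def download(self, artifact_path: str, destination: str) -> str:
        # Download through MlflowClient directly, sidestepping the nested-artifact
        # issue with MlflowObjectStore described in the NOTE above.
        import mlflow
        client = mlflow.MlflowClient()
        return client.download_artifacts(os.environ['MLFLOW_RUN_ID'], artifact_path, destination)

A driver would create one such actor and route every checkpoint upload and download through it, so no collective ops are needed inside the RolloutAgent.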
85 changes: 66 additions & 19 deletions test_single_controller_ppo.py
@@ -22,13 +22,14 @@
from typing import Any, Optional

from composer.loggers import MLFlowLogger
import mlflow
import ray
import torch
import torch.distributed as dist
from composer import Trainer
from composer.core import get_precision_context
from composer.optim import DecoupledAdamW
from composer.utils import dist as composer_dist
from composer.utils import create_symlink_file, dist as composer_dist
from llmfoundry.data import build_dataloader
from omegaconf import OmegaConf as om
from transformers import AutoTokenizer
@@ -44,6 +45,15 @@
create_vllm_engines,
_vllm_generate,
)
from compose_rl.utils.mlflow_utils import (
get_mlflow_run_id,
setup_mlflow,
artifact_exists_on_mlflow,
get_mlflow_absolute_path_for_save_folder,
get_mlflow_relative_path_for_save_folder,
validate_save_folder,
get_file,
)
from compose_rl.utils.ray_utils import start_ray_server, uninstall_megablocks_if_exists
from compose_rl.controllers import BaseDistributedGPUActor, SPMDActorGroup
from compose_rl.controllers.buffer import Buffer
@@ -121,6 +131,8 @@ def build_train_config(self, config: Any):
variables = om.to_container(self.config.variables, resolve=True)
algorithm_config = self.config.algorithms

mlflow_save_folder = get_mlflow_absolute_path_for_save_folder(self.config.save_folder)

self.train_config = {
'seed': self.config.seed,
'model': self.model_config,
@@ -131,7 +143,7 @@
'global_train_batch_size': self.device_train_batch_size * self.world_size,
'device_train_batch_size': self.device_train_batch_size,
'device_train_microbatch_size': self.device_train_batch_size,
'save_folder': self.config.save_folder,
'save_folder': mlflow_save_folder,
'log_config': self.config.log_config,
'max_seq_len': self.max_seq_len,
'python_log_level': self.config.python_log_level,
@@ -186,6 +198,8 @@ def build_ppo_trainer(self):
dummy_distributed_sampler = torch.utils.data.distributed.DistributedSampler(dummy_dataset)
dummy_dataloader = torch.utils.data.DataLoader(dummy_dataset, sampler=dummy_distributed_sampler)

# TODO: We might be able to skip part of the setup here as some mlflow
# environment variables are set in the setup_mlflow function
mlflow_logger = MLFlowLogger(
experiment_name=self.config.loggers.mlflow.experiment_name,
run_name=self.config.loggers.mlflow.tags.run,
@@ -206,6 +220,8 @@
SpeedMonitor(window_size=10),
]

mlflow_save_folder = get_mlflow_absolute_path_for_save_folder(self.config.save_folder)

self.ppo_trainer = Trainer(
model=model,
optimizers=optimizer,
@@ -217,7 +233,7 @@
loggers=[mlflow_logger],
device_train_microbatch_size=self.config.device_train_microbatch_size,
load_path=self.ref_path,
save_folder=self.config.save_folder,
save_folder=mlflow_save_folder,
save_interval=self.config.save_interval,
autoresume=self.config.autoresume,
)
@@ -335,6 +351,8 @@ def train_1_iter(self):
async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', parameter_buffer: 'ParameterBuffer', inference_server: 'InferenceServer', lock: asyncio.Lock, rollout_semaphore: asyncio.Semaphore, eval_semaphore: asyncio.Semaphore):
# The overall design right now: each subcontroller has an async def run function that is responsible
# for async primitives, while the rest of the logic stays sync and asyncio.to_thread bridges the async and sync worlds.

# TODO: Load the experience buffer from checkpoints; this will make checkpointing work for async training
for _ in range(num_iterations):
# Simple example of adding elements to the experience buffer
# Populate the train actor group with the rollouts and then train
@@ -455,24 +473,36 @@ def __init__(
self.tokenizer_pad_token_id = ray.get(self.streaming_dataset_actor.get_tokenizer_pad_token_id.remote())
self.prompt_handler_config = ray.get(self.streaming_dataset_actor.get_prompt_handler_config.remote())
self.max_gen_len = self.prompt_handler_config['max_gen_len']

# Load iter_num from the checkpoint
self.save_folder = os.path.join(config.save_folder, 'RolloutAgent')

self.iter_num = 0

# Load the latest checkpoint
self.latest_checkpoint = os.path.join(self.save_folder, 'latest.symlink')

if config.autoresume and os.path.exists(self.latest_checkpoint):
self.local_save_folder = os.path.join(config.save_folder, 'RolloutAgent')
# We need to format the full path correctly for MlflowObjectStore to be created.
self.mlflow_absolute_save_folder = get_mlflow_absolute_path_for_save_folder(self.local_save_folder).format(
mlflow_experiment_id=os.environ['MLFLOW_EXPERIMENT_ID'],
mlflow_run_id=os.environ['MLFLOW_RUN_ID'],
)
self.mlflow_relative_save_folder = get_mlflow_relative_path_for_save_folder(self.local_save_folder)

# Load the latest checkpoint if we are autoresuming.
# Note that since we check whether the checkpoint exists with
# mlflow.client.list_artifacts, we need to use the relative path to
# the checkpoint (i.e. not include dbfs:/.../{mlflow_experiment_id}/{mlflow_run_id}
# in the path). Note: we don't support UC Volumes for storage, since
# artifact_exists_on_mlflow would not work with them.
self.latest_checkpoint_path = os.path.join(self.local_save_folder, 'latest_rollout_agent.symlink')
self.mlflow_latest_checkpoint_absolute_path = os.path.join(self.mlflow_absolute_save_folder, 'latest_rollout_agent.symlink')
self.mlflow_latest_checkpoint_relative_path = os.path.join(self.mlflow_relative_save_folder, 'latest_rollout_agent.symlink')

if config.autoresume and artifact_exists_on_mlflow(self.mlflow_latest_checkpoint_relative_path):
print('Autoresuming from checkpoint for RolloutAgent.')
with open(self.latest_checkpoint, 'rb') as f:
get_file(self.mlflow_latest_checkpoint_absolute_path, self.latest_checkpoint_path, overwrite=True)
print(f'Got autoresume checkpoint from mlflow: {self.latest_checkpoint_path}')
with open(self.latest_checkpoint_path, 'rb') as f:
checkpoint = pickle.load(f)
self.iter_num = checkpoint['iter_num']
print('Loading streaming dataloader state dict for RolloutAgent.', checkpoint['streaming_dataloader'])
self.streaming_dataset_actor.load_dataloader_state_dict.remote(checkpoint['streaming_dataloader'])


def get_next_iter_rollouts(self):
"""
Gets the next rollouts from the inference server.
@@ -510,25 +540,39 @@ def get_next_iter_rollouts(self):
processed_sequences = torch.cat([all_prompts, padded_responses], dim=-1)
iter_data['sequences'] = processed_sequences

save_folder_iter = os.path.join(self.save_folder, f'iter_{self.iter_num}')
checkpoint_path = os.path.join(save_folder_iter, 'checkpoint.pt')
save_folder_for_curr_iter = os.path.join(self.local_save_folder, f'iter_{self.iter_num}')
checkpoint_path = os.path.join(save_folder_for_curr_iter, 'checkpoint.pt')
self.iter_num += 1

streaming_dataloader_state_dict = ray.get(self.streaming_dataset_actor.get_dataloader_state_dict.remote())
print('Streaming dataloader state dict for RolloutAgent.', streaming_dataloader_state_dict)

# make sure that the folder path exists
os.makedirs(save_folder_iter, exist_ok=True)
os.makedirs(save_folder_for_curr_iter, exist_ok=True)
with open(checkpoint_path, 'wb') as f:
pickle.dump({
'iter_data': iter_data,
'iter_num': self.iter_num,
'streaming_dataloader': streaming_dataloader_state_dict,
}, f)

if os.path.exists(self.latest_checkpoint):
os.remove(self.latest_checkpoint)
os.symlink(checkpoint_path, self.latest_checkpoint)
# log the checkpoint to mlflow
mlflow.log_artifact(
checkpoint_path,
get_mlflow_relative_path_for_save_folder(save_folder_for_curr_iter),
run_id=get_mlflow_run_id(),
)

if os.path.exists(self.latest_checkpoint_path):
os.remove(self.latest_checkpoint_path)
create_symlink_file(checkpoint_path, self.latest_checkpoint_path)

# log the latest checkpoint to mlflow
mlflow.log_artifact(
self.latest_checkpoint_path,
self.mlflow_relative_save_folder,
run_id=get_mlflow_run_id(),
)
return iter_data

async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', lock: asyncio.Lock, rollout_semaphore: asyncio.Semaphore):
@@ -740,6 +784,9 @@ def _run_single_controller_ppo(
# Disable setting CUDA_VISIBLE_DEVICES by ray, we will set it manually
os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1'

validate_save_folder(config.save_folder)
setup_mlflow(config)

with start_ray_server() as _address:
# only rank 0 is the master controller
if dist.get_rank() == 0:
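For readers following the RolloutAgent changes above, here is a condensed, standalone sketch of how the mlflow path helpers are meant to compose for autoresume. The /tmp/checkpoints folder is a hypothetical example, and the MLFLOW_* environment variables are assumed to have been populated by setup_mlflow(); the actual PR does the equivalent inside RolloutAgent.__init__.

import os

from compose_rl.utils.mlflow_utils import (
    artifact_exists_on_mlflow,
    get_file,
    get_mlflow_absolute_path_for_save_folder,
    get_mlflow_relative_path_for_save_folder,
)

local_folder = os.path.join('/tmp/checkpoints', 'RolloutAgent')  # hypothetical save_folder
symlink_name = 'latest_rollout_agent.symlink'

# Absolute dbfs:/ form: the template placeholders are filled with the experiment/run
# ids exported by setup_mlflow(); this is the form needed to build an MLFlowObjectStore
# for downloads.
absolute_symlink = os.path.join(
    get_mlflow_absolute_path_for_save_folder(local_folder).format(
        mlflow_experiment_id=os.environ['MLFLOW_EXPERIMENT_ID'],
        mlflow_run_id=os.environ['MLFLOW_RUN_ID'],
    ),
    symlink_name,
)

# Relative form (no dbfs:/ prefix): this is what MlflowClient.list_artifacts walks,
# so it is what artifact_exists_on_mlflow() expects.
relative_symlink = os.path.join(
    get_mlflow_relative_path_for_save_folder(local_folder),
    symlink_name,
)

if artifact_exists_on_mlflow(relative_symlink):
    get_file(absolute_symlink, os.path.join(local_folder, symlink_name), overwrite=True)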
2 changes: 1 addition & 1 deletion yamls/single-controller-grpo-workflow.yaml
@@ -157,7 +157,7 @@ parameters:
verbose: false
cpu_offload: false
mixed_precision: PURE
state_dict_type: sharded
state_dict_type: full # TODO: Sharded state dicts are having issues. Need to investigate further.
use_orig_params: true
forward_prefetch: true
backward_prefetch: BACKWARD_PRE