Commit 02df60c

irenedea authored and rithwik-db committed

setup mlflow globally
works comment
1 parent 69538bd commit 02df60c

File tree

1 file changed (+94, -6 lines)

1 file changed

+94
-6
lines changed

test_single_controller_ppo.py

Lines changed: 94 additions & 6 deletions
@@ -22,13 +22,14 @@
 from typing import Any, Optional
 
 from composer.loggers import MLFlowLogger
+import mlflow
 import ray
 import torch
 import torch.distributed as dist
 from composer import Trainer
 from composer.core import get_precision_context
 from composer.optim import DecoupledAdamW
-from composer.utils import dist as composer_dist
+from composer.utils import create_symlink_file, dist as composer_dist, get_file
 from llmfoundry.data import build_dataloader
 from omegaconf import OmegaConf as om
 from transformers import AutoTokenizer
@@ -48,7 +49,10 @@
 from compose_rl.controllers import BaseDistributedGPUActor, SPMDActorGroup
 from compose_rl.controllers.buffer import Buffer
 from compose_rl.algorithms.online.callback_utils import preprocess_batches
+from databricks.sdk import WorkspaceClient
 
+MLFLOW_RUN_NAME = os.environ['COMPOSER_RUN_NAME']  # SHOULD BE SET BY MCLI
+MLFLOW_EXPERIMENT_NAME = f'/Users/{WorkspaceClient().current_user.me().user_name}/test_single_controller'
 
 @contextmanager
 def time_it(name: str):
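
These two module-level constants pin the job to MCLI and Databricks: the run name is taken from the COMPOSER_RUN_NAME environment variable that MCLI is expected to set, and the experiment is placed under the current Databricks user's workspace folder. A minimal sketch of how they resolve, using hypothetical values and assuming valid Databricks credentials for WorkspaceClient():

    import os

    from databricks.sdk import WorkspaceClient

    # Normally set by MCLI; the fallback here is only for local experimentation.
    os.environ.setdefault('COMPOSER_RUN_NAME', 'ppo-debug-run')
    run_name = os.environ['COMPOSER_RUN_NAME']

    # Resolves to something like '/Users/jane.doe@example.com/test_single_controller'.
    user_name = WorkspaceClient().current_user.me().user_name
    experiment_name = f'/Users/{user_name}/test_single_controller'
    print(run_name, experiment_name)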
@@ -129,7 +133,7 @@ def build_train_config(self, config: Any):
             'global_train_batch_size': self.device_train_batch_size * self.world_size,
             'device_train_batch_size': self.device_train_batch_size,
             'device_train_microbatch_size': self.device_train_batch_size,
-            'save_folder': self.config.save_folder,
+            'save_folder': os.path.join('dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}', self.config.save_folder),
             'log_config': self.config.log_config,
             'max_seq_len': self.max_seq_len,
             'python_log_level': self.config.python_log_level,
@@ -220,7 +224,7 @@ def build_ppo_trainer(self):
             loggers=[mlflow_logger],
             device_train_microbatch_size=self.config.device_train_microbatch_size,
             load_path=self.ref_path,
-            save_folder=self.config.save_folder,
+            save_folder=os.path.join('dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}', self.config.save_folder),
             save_interval='1iter',
             autoresume=self.config.autoresume,
         )
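
Both the training config and the PPO Trainer now save into the MLflow artifact root on DBFS instead of a bare folder. The {mlflow_experiment_id} and {mlflow_run_id} pieces are kept as literal placeholders; Composer's MLflow-backed object store is expected to fill them in from the active run (the same IDs that _setup_mlflow, added later in this diff, exports as environment variables). A rough sketch of the resulting path, with made-up IDs:

    import os

    # 'checkpoints' stands in for self.config.save_folder; the IDs below are hypothetical.
    template = os.path.join(
        'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}',
        'checkpoints',
    )
    resolved = template.format(
        mlflow_experiment_id='1234567890123456',
        mlflow_run_id='0123456789abcdef0123456789abcdef',
    )
    # -> dbfs:/databricks/mlflow-tracking/1234567890123456/0123456789abcdef0123456789abcdef/checkpoints
    print(resolved)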
@@ -338,6 +342,8 @@ def train_1_iter(self):
     async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', parameter_buffer: 'ParameterBuffer', inference_server: 'InferenceServer', lock: asyncio.Lock, rollout_semaphore: asyncio.Semaphore, eval_semaphore: asyncio.Semaphore):
         # the overall design rn is we have a async def run function for each of the subcontroller that is responsible for async primitives but leave the rest of the logic to be sync function and use
         # asyncio.to_thread to bridge the async and sync world
+
+        # TODO: Load experience buffer from checkpoints, this will make checkpointing work for async
         for _ in range(num_iterations):
             # Simple example of adding elements to the experience buffer
             # Populate the train actor group with the rollouts and then train
@@ -465,10 +471,12 @@ def __init__(
         self.iter_num = 0
 
         # Load the latest checkpoint
-        self.latest_checkpoint = os.path.join(self.save_folder, 'latest.symlink')
 
-        if config.autoresume and os.path.exists(self.latest_checkpoint):
+        self.latest_checkpoint = os.path.join(self.save_folder, 'latest_rollout_agent.symlink')  # TODO: This might need to use the updated path
+
+        if config.autoresume and _artifact_exists(self.latest_checkpoint):
             print(f'Autoresuming from checkpoint for RolloutAgent.')
+            get_file(self.latest_checkpoint, self.latest_checkpoint, overwrite=True)
             with open(self.latest_checkpoint, 'rb') as f:
                 checkpoint = pickle.load(f)
             self.iter_num = checkpoint['iter_num']
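
Autoresume for the RolloutAgent now goes through MLflow rather than the local filesystem: the symlink file is looked up among the run's artifacts and, if present, downloaded with composer.utils.get_file before being unpickled. A hedged sketch of that resume flow, assuming the helpers added later in this diff (_artifact_exists, _setup_mlflow) are in scope and the save folder value is only illustrative:

    import os
    import pickle

    from composer.utils import get_file

    # Hypothetical save folder; MLFLOW_RUN_ID / MLFLOW_TRACKING_URI are assumed to have
    # been exported by _setup_mlflow() earlier in the program.
    save_folder = 'checkpoints'
    latest_checkpoint = os.path.join(save_folder, 'latest_rollout_agent.symlink')

    if _artifact_exists(latest_checkpoint):
        # Fetch the remote artifact to the same local path, replacing any stale copy,
        # then load the rollout-agent state the way the constructor above does.
        get_file(latest_checkpoint, latest_checkpoint, overwrite=True)
        with open(latest_checkpoint, 'rb') as f:
            checkpoint = pickle.load(f)
        print('resuming RolloutAgent at iteration', checkpoint['iter_num'])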
@@ -529,9 +537,13 @@ def get_next_iter_rollouts(self):
                 'streaming_dataloader': streaming_dataloader_state_dict,
             }, f)
 
+        mlflow.log_artifact(checkpoint_path, save_folder_iter, run_id=_get_mlflow_run_id())
+
         if os.path.exists(self.latest_checkpoint):
             os.remove(self.latest_checkpoint)
-        os.symlink(checkpoint_path, self.latest_checkpoint)
+        create_symlink_file(checkpoint_path, self.latest_checkpoint)
+
+        mlflow.log_artifact(self.latest_checkpoint, self.config.save_folder, run_id=_get_mlflow_run_id())
         return iter_data
 
     async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', lock: asyncio.Lock, rollout_semaphore: asyncio.Semaphore):
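
A real POSIX symlink cannot be uploaded to an artifact store, so the rollout checkpoint path switches from os.symlink to composer.utils.create_symlink_file, which writes a small pointer file naming the actual checkpoint, and both files are pushed to the run with mlflow.log_artifact. A minimal sketch of the same pattern in isolation; the file names are made up and an active MLflow run (e.g. from _setup_mlflow()) is assumed:

    import pickle

    import mlflow
    from composer.utils import create_symlink_file

    checkpoint_path = 'rollout_agent_iter_0000.pkl'     # hypothetical checkpoint file
    latest_checkpoint = 'latest_rollout_agent.symlink'  # pointer file ending in .symlink

    with open(checkpoint_path, 'wb') as f:
        pickle.dump({'iter_num': 0}, f)

    # Upload the real checkpoint, then a pointer that get_file() can follow later.
    mlflow.log_artifact(checkpoint_path, artifact_path='checkpoints')
    create_symlink_file(checkpoint_path, latest_checkpoint)
    mlflow.log_artifact(latest_checkpoint, artifact_path='checkpoints')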
@@ -728,6 +740,80 @@ async def train_async(self, max_duration: int | str):
         await asyncio.gather(train_task, rollout_task, eval_task)
         self.train_actor.collective_methods.close_trainer()
 
+def _get_mlflow_run_id() -> Optional[str]:
+    return os.environ.get('MLFLOW_RUN_ID', None)
+
+def _setup_mlflow(config: Any):
+    print('setting up mlflow')
+    dist.init_process_group(backend='gloo')
+    # Create a new MLFlow run to be used for the entire run
+    mlflow.set_tracking_uri('databricks')
+
+    # get mlflow experiment
+    experiment = mlflow.get_experiment_by_name(MLFLOW_EXPERIMENT_NAME)
+    if experiment is None:
+        experiment_id = mlflow.create_experiment(MLFLOW_EXPERIMENT_NAME)
+    else:
+        experiment_id = experiment.experiment_id
+    mlflow.set_experiment(experiment_id=experiment_id)
+
+
+
+    run_id = None
+    if composer_dist.get_global_rank() == 0:
+        # find a preexisting run if it exists
+        existing_runs = mlflow.search_runs(
+            experiment_ids=[experiment_id],
+            filter_string=f'tags.run_name = "{MLFLOW_RUN_NAME}"',
+            output_format='list',
+        ) if config.autoresume else []
+        if len(existing_runs) > 0:
+            run_id = existing_runs[0].info.run_id
+            print(f'Resuming mlflow run with run id: {run_id}')
+        else:
+            run_id = mlflow.start_run(run_name=MLFLOW_RUN_NAME).info.run_id
+            print(f'Creating new mlflow run with run id: {run_id}')
+    broadcast_list = [run_id]
+
+    composer_dist.broadcast_object_list(broadcast_list, src=0)
+
+    # set all the right environment variables
+    run_id = broadcast_list[0]
+    assert run_id is not None and experiment_id is not None, "Run ID and experiment ID must be set"
+    os.environ['MLFLOW_RUN_ID'] = run_id
+    os.environ['MLFLOW_EXPERIMENT_ID'] = experiment_id
+    os.environ['MLFLOW_TRACKING_URI'] = 'databricks'
+
+    dist.destroy_process_group()
+
+
+def _artifact_exists(artifact_path: str) -> bool:
+    """Return True if artifact_path exists (file or directory) for the run."""
+    client = mlflow.MlflowClient()
+    artifact_path = artifact_path.strip("/")
+
+    run_id = _get_mlflow_run_id()
+    assert run_id is not None, "Run ID must be set"
+
+    # Walk down the path parts level-by-level
+    parent = ""
+    if artifact_path:
+        parts = artifact_path.split("/")
+        for i, part in enumerate(parts):
+            entries = {os.path.basename(fi.path): fi for fi in client.list_artifacts(run_id, parent)}
+            if part not in entries:
+                return False
+            fi = entries[part]
+            is_last = (i == len(parts) - 1)
+            if not is_last and not fi.is_dir:
+                # trying to descend into a file
+                return False
+            parent = fi.path  # descend
+
+    # If we got here, the path exists (root or found item).
+    return True
+
+
 
 def _run_single_controller_ppo(
     config: Any,
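
Taken together, _setup_mlflow creates (or, under autoresume, finds) one MLflow run on rank 0, broadcasts the run id over a temporary gloo process group, and exports it through environment variables so that every later component, Composer's MLFlowLogger included, attaches to the same run. A hedged usage sketch of what any process can do once those variables are set; the metric and artifact names are illustrative and _artifact_exists comes from this file:

    import os

    import mlflow

    # Populated by _setup_mlflow(); this sketch assumes it already ran on every rank.
    mlflow.set_tracking_uri(os.environ['MLFLOW_TRACKING_URI'])
    run_id = os.environ['MLFLOW_RUN_ID']

    # Reattach to the shared run instead of creating a new one, then log to it.
    with mlflow.start_run(run_id=run_id):
        mlflow.log_metric('rollout/iter_num', 0)

    # The same MLFLOW_RUN_ID also drives the artifact helpers defined above.
    print(_artifact_exists('checkpoints/latest_rollout_agent.symlink'))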
@@ -744,6 +830,8 @@ def _run_single_controller_ppo(
     # Disable setting CUDA_VISIBLE_DEVICES by ray, we will set it manually
     os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1'
 
+    _setup_mlflow(config)
+
     with start_ray_server() as _address:
         # only rank 0 is the master controller
         if dist.get_rank() == 0:
