from typing import Any, Optional
from composer.loggers import MLFlowLogger
+ import mlflow
import ray
import torch
import torch.distributed as dist
from composer import Trainer
from composer.core import get_precision_context
from composer.optim import DecoupledAdamW
- from composer.utils import dist as composer_dist
+ from composer.utils import create_symlink_file, dist as composer_dist, get_file
from llmfoundry.data import build_dataloader
from omegaconf import OmegaConf as om
from transformers import AutoTokenizer

from compose_rl.controllers import BaseDistributedGPUActor, SPMDActorGroup
from compose_rl.controllers.buffer import Buffer
from compose_rl.algorithms.online.callback_utils import preprocess_batches
+ from databricks.sdk import WorkspaceClient

+ MLFLOW_RUN_NAME = os.environ['COMPOSER_RUN_NAME']  # SHOULD BE SET BY MCLI
+ MLFLOW_EXPERIMENT_NAME = f'/Users/{WorkspaceClient().current_user.me().user_name}/test_single_controller'
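+ # The experiment is placed under the current Databricks user's workspace folder,
+ # resolved at import time via the databricks-sdk WorkspaceClient.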

@contextmanager
def time_it(name: str):
@@ -129,7 +133,7 @@ def build_train_config(self, config: Any):
    'global_train_batch_size': self.device_train_batch_size * self.world_size,
    'device_train_batch_size': self.device_train_batch_size,
    'device_train_microbatch_size': self.device_train_batch_size,
-     'save_folder': self.config.save_folder,
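+     # Assumption: Composer expands the {mlflow_experiment_id} and {mlflow_run_id}
+     # placeholders at runtime, so checkpoints land under this MLflow run's artifact root.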
+     'save_folder': os.path.join('dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}', self.config.save_folder),
    'log_config': self.config.log_config,
    'max_seq_len': self.max_seq_len,
    'python_log_level': self.config.python_log_level,
@@ -220,7 +224,7 @@ def build_ppo_trainer(self):
    loggers=[mlflow_logger],
    device_train_microbatch_size=self.config.device_train_microbatch_size,
    load_path=self.ref_path,
-     save_folder=self.config.save_folder,
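+     # Same placeholder-expanded MLflow artifact path as in build_train_config above.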
+     save_folder=os.path.join('dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}', self.config.save_folder),
    save_interval='1iter',
    autoresume=self.config.autoresume,
)
@@ -338,6 +342,8 @@ def train_1_iter(self):
async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', parameter_buffer: 'ParameterBuffer', inference_server: 'InferenceServer', lock: asyncio.Lock, rollout_semaphore: asyncio.Semaphore, eval_semaphore: asyncio.Semaphore):
    # The overall design: each subcontroller exposes an async `run` function that owns the
    # async primitives, while the rest of the logic stays synchronous and is bridged into
    # the async world with asyncio.to_thread.
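    # Illustrative sketch of the pattern (names are placeholders, not the exact call sites):
    #     result = await asyncio.to_thread(some_sync_method)
    #     async with lock:
    #         experience_buffer.put(result)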
+
+     # TODO: Load the experience buffer from checkpoints; this will make checkpointing work for async.
    for _ in range(num_iterations):
        # Simple example of adding elements to the experience buffer
        # Populate the train actor group with the rollouts and then train
@@ -465,10 +471,12 @@ def __init__(
    self.iter_num = 0

    # Load the latest checkpoint
-     self.latest_checkpoint = os.path.join(self.save_folder, 'latest.symlink')

-     if config.autoresume and os.path.exists(self.latest_checkpoint):
+     self.latest_checkpoint = os.path.join(self.save_folder, 'latest_rollout_agent.symlink')  # TODO: This might need to use the updated path
+
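+     # Existence is checked in the MLflow artifact store (see _artifact_exists below),
+     # not on the local filesystem.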
+     if config.autoresume and _artifact_exists(self.latest_checkpoint):
        print('Autoresuming from checkpoint for RolloutAgent.')
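+         # Materialize the checkpoint symlink locally (fetching it if remote) before unpickling it below.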
+         get_file(self.latest_checkpoint, self.latest_checkpoint, overwrite=True)
        with open(self.latest_checkpoint, 'rb') as f:
            checkpoint = pickle.load(f)
        self.iter_num = checkpoint['iter_num']
@@ -529,9 +537,13 @@ def get_next_iter_rollouts(self):
            'streaming_dataloader': streaming_dataloader_state_dict,
        }, f)

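+     # Upload this iteration's checkpoint into the run's artifact store under save_folder_iter.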
+     mlflow.log_artifact(checkpoint_path, save_folder_iter, run_id=_get_mlflow_run_id())
+
    if os.path.exists(self.latest_checkpoint):
        os.remove(self.latest_checkpoint)
-     os.symlink(checkpoint_path, self.latest_checkpoint)
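+     # Unlike os.symlink, composer.utils.create_symlink_file writes a small text file
+     # containing the target path, which can be uploaded to an object store and later
+     # resolved by get_file.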
+     create_symlink_file(checkpoint_path, self.latest_checkpoint)
+
+     mlflow.log_artifact(self.latest_checkpoint, self.config.save_folder, run_id=_get_mlflow_run_id())
    return iter_data

async def run(self, num_iterations: int, experience_buffer: 'ExperienceBuffer', lock: asyncio.Lock, rollout_semaphore: asyncio.Semaphore):
@@ -728,6 +740,80 @@ async def train_async(self, max_duration: int | str):
    await asyncio.gather(train_task, rollout_task, eval_task)
    self.train_actor.collective_methods.close_trainer()

+ def _get_mlflow_run_id() -> Optional[str]:
+     return os.environ.get('MLFLOW_RUN_ID', None)
+
+ def _setup_mlflow(config: Any):
+     print('setting up mlflow')
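+     # A short-lived gloo process group is used only so all ranks can agree on a
+     # single MLflow run id; it is torn down at the end of this function.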
+     dist.init_process_group(backend='gloo')
+     # Create a new MLFlow run to be used for the entire run
+     mlflow.set_tracking_uri('databricks')
+
+     # Get (or create) the MLflow experiment
+     experiment = mlflow.get_experiment_by_name(MLFLOW_EXPERIMENT_NAME)
+     if experiment is None:
+         experiment_id = mlflow.create_experiment(MLFLOW_EXPERIMENT_NAME)
+     else:
+         experiment_id = experiment.experiment_id
+     mlflow.set_experiment(experiment_id=experiment_id)
+
+     run_id = None
+     if composer_dist.get_global_rank() == 0:
+         # Find a preexisting run if it exists
+         existing_runs = mlflow.search_runs(
+             experiment_ids=[experiment_id],
+             filter_string=f'tags.run_name = "{MLFLOW_RUN_NAME}"',
+             output_format='list',
+         ) if config.autoresume else []
+         if len(existing_runs) > 0:
+             run_id = existing_runs[0].info.run_id
+             print(f'Resuming mlflow run with run id: {run_id}')
+         else:
+             run_id = mlflow.start_run(run_name=MLFLOW_RUN_NAME).info.run_id
+             print(f'Creating new mlflow run with run id: {run_id}')
+         broadcast_list = [run_id]
+
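+     # Rank 0 supplies the run id; every rank participates in the broadcast and reads it back.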
+     composer_dist.broadcast_object_list(broadcast_list, src=0)
+
+     # Set all the right environment variables
+     run_id = broadcast_list[0]
+     assert run_id is not None and experiment_id is not None, 'Run ID and experiment ID must be set'
+     os.environ['MLFLOW_RUN_ID'] = run_id
+     os.environ['MLFLOW_EXPERIMENT_ID'] = experiment_id
+     os.environ['MLFLOW_TRACKING_URI'] = 'databricks'
+
+     dist.destroy_process_group()
+
+
+ def _artifact_exists(artifact_path: str) -> bool:
+     """Return True if artifact_path exists (file or directory) for the run."""
+     client = mlflow.MlflowClient()
+     artifact_path = artifact_path.strip('/')
+
+     run_id = _get_mlflow_run_id()
+     assert run_id is not None, 'Run ID must be set'
+
+     # Walk down the path parts level-by-level
+     parent = ''
+     if artifact_path:
+         parts = artifact_path.split('/')
+         for i, part in enumerate(parts):
+             entries = {os.path.basename(fi.path): fi for fi in client.list_artifacts(run_id, parent)}
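+             # MlflowClient.list_artifacts is non-recursive, so each iteration inspects
+             # exactly one level of the artifact tree.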
+             if part not in entries:
+                 return False
+             fi = entries[part]
+             is_last = (i == len(parts) - 1)
+             if not is_last and not fi.is_dir:
+                 # Trying to descend into a file
+                 return False
+             parent = fi.path  # descend
+
+     # If we got here, the path exists (root or found item).
+     return True
+

def _run_single_controller_ppo(
    config: Any,
@@ -744,6 +830,8 @@ def _run_single_controller_ppo(
    # Disable Ray's setting of CUDA_VISIBLE_DEVICES; we will set it manually
    os.environ['RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES'] = '1'

+     _setup_mlflow(config)
+
    with start_ray_server() as _address:
        # only rank 0 is the master controller
        if dist.get_rank() == 0: