
Commit 2fd1e79

moved to mlflow utils file

1 parent 22cacab

3 files changed: +181, -141 lines changed

compose_rl/utils/__init__.py

16 additions, 0 deletions

@@ -53,6 +53,15 @@
     switch_left_to_right_padding,
     print_batch_shapes,
 )
+from compose_rl.utils.mlflow_utils import (
+    get_mlflow_run_id,
+    get_valid_mlflow_experiment_name,
+    setup_mlflow,
+    artifact_exists_on_mlflow,
+    get_mlflow_absolute_path_for_save_folder,
+    get_mlflow_relative_path_for_save_folder,
+    validate_save_folder,
+)
 
 __all__ = [
     'get_mb_load_balancing_loss',
@@ -103,4 +112,11 @@
     'remove_boxed',
     'ray_utils',
     'print_batch_shapes',
+    'get_mlflow_run_id',
+    'get_valid_mlflow_experiment_name',
+    'setup_mlflow',
+    'artifact_exists_on_mlflow',
+    'get_mlflow_absolute_path_for_save_folder',
+    'get_mlflow_relative_path_for_save_folder',
+    'validate_save_folder',
 ]
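With these names re-exported, downstream code can import the helpers straight from the package root. A minimal sketch of that usage (illustrative only, not part of this commit; the run id value is made up):

    import os

    from compose_rl.utils import get_mlflow_run_id

    os.environ['MLFLOW_RUN_ID'] = 'abc123'   # hypothetical run id
    assert get_mlflow_run_id() == 'abc123'   # helper simply reads the env var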

compose_rl/utils/mlflow_utils.py

134 additions, 0 deletions (new file)

import os
from typing import Any, Optional

import mlflow
import torch.distributed as dist

from composer.utils import dist as composer_dist


def get_mlflow_run_id() -> Optional[str]:
    return os.environ.get('MLFLOW_RUN_ID', None)


def get_valid_mlflow_experiment_name(config: Any) -> str:
    """Fixes the experiment name to be an absolute path for mlflow.

    MLflow requires the experiment name to be an absolute path.
    If the experiment name is not an absolute path, we turn it
    into an absolute path.
    """
    mlflow_experiment_name = config.loggers.mlflow.experiment_name
    if mlflow_experiment_name.startswith('/'):
        return mlflow_experiment_name
    else:
        from databricks.sdk import WorkspaceClient
        return f'/Users/{WorkspaceClient().current_user.me().user_name}/{mlflow_experiment_name}'


def get_mlflow_relative_path_for_save_folder(save_folder: str) -> str:
    """Returns the relative path for the given save folder.

    Relative paths in MLflow need to be of the format: `artifacts/{relative_path}`.
    """
    return os.path.join('artifacts', save_folder.lstrip('/'))


def get_mlflow_absolute_path_for_save_folder(save_folder: str) -> str:
    """Returns the mlflow artifact path for the given save folder."""
    mlflow_prefix = 'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}'
    mlflow_artifact_path = os.path.join(
        mlflow_prefix,
        get_mlflow_relative_path_for_save_folder(save_folder),
    )
    return mlflow_artifact_path


def validate_save_folder(save_folder: str) -> None:
    """Validates the save folder."""
    if save_folder.startswith('dbfs:/'):
        raise ValueError(
            f'Using dbfs save_folder ({save_folder}) to store checkpoints is '
            'not supported. Please use a local save_folder.',
        )


def artifact_exists_on_mlflow(artifact_path: str) -> bool:
    """Return True if artifact_path exists (file or directory) for the run.

    Artifact path needs to be a relative path to the save folder.
    """
    client = mlflow.MlflowClient()
    run_id = get_mlflow_run_id()
    assert run_id is not None, 'Run ID must be set'

    # Walk down the path parts level-by-level.
    parent = ''
    if artifact_path:
        parts = artifact_path.split('/')
        for i, part in enumerate(parts):
            entries = {
                os.path.basename(fi.path): fi
                for fi in client.list_artifacts(run_id, parent)
            }
            if part not in entries:
                return False
            fi = entries[part]
            is_last = (i == len(parts) - 1)
            if not is_last and not fi.is_dir:
                # Trying to descend into a file.
                return False
            parent = fi.path  # descend

    # If we got here, the path exists (root or found item).
    return True


def setup_mlflow(config: Any):
    """Sets up mlflow for the current process.

    This function should be called before any other mlflow functions are called.
    It will set the mlflow experiment and run, creating both if they don't exist,
    and it will set all environment variables needed for mlflow.
    """
    dist.init_process_group(backend='gloo')
    mlflow.set_tracking_uri('databricks')

    # The mlflow experiment name needs to be an absolute path for databricks mlflow.
    mlflow_experiment_name = get_valid_mlflow_experiment_name(config)
    setattr(config.loggers.mlflow, 'experiment_name', mlflow_experiment_name)
    # COMPOSER_RUN_NAME is set for interactive mode as well.
    mlflow_run_name = os.environ['COMPOSER_RUN_NAME']
    setattr(config.loggers.mlflow, 'run_name', mlflow_run_name)

    # Get the mlflow experiment if it exists, otherwise create it, and broadcast it to all ranks.
    experiment_id = None
    if composer_dist.get_global_rank() == 0:
        experiment = mlflow.get_experiment_by_name(mlflow_experiment_name)
        if experiment is None:
            experiment_id = mlflow.create_experiment(mlflow_experiment_name)
        else:
            experiment_id = experiment.experiment_id
    experiment_id_broadcast_list = [experiment_id]
    composer_dist.broadcast_object_list(experiment_id_broadcast_list, src=0)
    experiment_id = experiment_id_broadcast_list[0]

    mlflow.set_experiment(experiment_id=experiment_id)

    # Get the mlflow run if it exists and we are autoresuming, otherwise create it, and broadcast it to all ranks.
    run_id = None
    if composer_dist.get_global_rank() == 0:
        existing_runs = mlflow.search_runs(
            experiment_ids=[experiment_id],
            filter_string=f'tags.run_name = "{mlflow_run_name}"',
            output_format='list',
        ) if config.autoresume else []
        if len(existing_runs) > 0:
            run_id = existing_runs[0].info.run_id
            print(f'Resuming mlflow run with run id: {run_id}')
        else:
            run_id = mlflow.start_run(run_name=mlflow_run_name).info.run_id
            print(f'Creating new mlflow run with run id: {run_id}')
    run_id_broadcast_list = [run_id]
    composer_dist.broadcast_object_list(run_id_broadcast_list, src=0)
    run_id = run_id_broadcast_list[0]

    # Set all the right environment variables.
    assert run_id is not None and experiment_id is not None, 'Run ID and experiment ID must be set'
    os.environ['MLFLOW_RUN_ID'] = run_id
    os.environ['MLFLOW_EXPERIMENT_ID'] = experiment_id
    os.environ['MLFLOW_TRACKING_URI'] = 'databricks'

    dist.destroy_process_group()
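The two path helpers are pure string functions, so their behavior can be sanity-checked directly. A small sketch of expected inputs and outputs (the save folder value below is made up):

    from compose_rl.utils.mlflow_utils import (
        get_mlflow_absolute_path_for_save_folder,
        get_mlflow_relative_path_for_save_folder,
        validate_save_folder,
    )

    save_folder = '/checkpoints/ppo'    # hypothetical local save folder
    validate_save_folder(save_folder)   # passes: not a dbfs:/ path

    print(get_mlflow_relative_path_for_save_folder(save_folder))
    # artifacts/checkpoints/ppo
    print(get_mlflow_absolute_path_for_save_folder(save_folder))
    # dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}/artifacts/checkpoints/ppo
    # (the {mlflow_experiment_id}/{mlflow_run_id} placeholders are left unformatted by the helper)

    validate_save_folder('dbfs:/checkpoints/ppo')  # raises ValueError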
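setup_mlflow reads and mutates the MLflow logger settings on the config, requires COMPOSER_RUN_NAME in the environment, and opens a temporary gloo process group, so it is meant to run under a distributed launcher (e.g. torchrun or the Composer launcher) with Databricks credentials available. A hedged sketch of how it might be wired into a training entry point; the SimpleNamespace config and all values below are hypothetical stand-ins, not part of this commit:

    import os
    from types import SimpleNamespace

    from compose_rl.utils import get_mlflow_run_id, setup_mlflow

    # Stand-in for the real training config; setup_mlflow only touches
    # config.loggers.mlflow.* and config.autoresume.
    config = SimpleNamespace(
        autoresume=True,
        loggers=SimpleNamespace(
            mlflow=SimpleNamespace(experiment_name='compose-rl-experiments'),
        ),
    )

    os.environ.setdefault('COMPOSER_RUN_NAME', 'ppo-run')  # normally set by the launcher

    setup_mlflow(config)        # resolves the experiment/run and exports MLFLOW_* env vars
    print(get_mlflow_run_id())  # every rank can now read the shared run id from the environment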
