1
+ import os
2
+ from typing import Any , Optional
3
+
4
+ import mlflow
5
+ import torch .distributed as dist
6
+
7
+ from composer .utils import dist as composer_dist
8
+
9
+
10
def get_mlflow_run_id() -> Optional[str]:
    """Return the current MLflow run id from the environment, or None if unset."""
    return os.environ.get('MLFLOW_RUN_ID')
12
+
13
+
14
def get_valid_mlflow_experiment_name(config: Any) -> str:
    """Fixes the experiment name to be an absolute path for mlflow.

    MLflow requires the experiment name to be an absolute path. If the
    configured name is not absolute, it is rewritten to live under the
    current Databricks user's folder.
    """
    name = config.loggers.mlflow.experiment_name
    if not name.startswith('/'):
        # Lazy import: only needed (and only available) on Databricks.
        from databricks.sdk import WorkspaceClient
        username = WorkspaceClient().current_user.me().user_name
        name = f'/Users/{username}/{name}'
    return name
27
+
28
+
29
def get_mlflow_relative_path_for_save_folder(save_folder: str) -> str:
    """Return the MLflow-relative path for the given save folder.

    Relative paths in mlflow need to be of the format: ``artifacts/{relative_path}``.
    """
    trimmed = save_folder.lstrip('/')
    return os.path.join('artifacts', trimmed)
35
+
36
+
37
def get_mlflow_absolute_path_for_save_folder(save_folder: str) -> str:
    """Return the dbfs mlflow artifact path for the given save folder.

    The ``{mlflow_experiment_id}`` / ``{mlflow_run_id}`` placeholders are left
    unformatted on purpose; consumers substitute the real ids later.
    """
    prefix = 'dbfs:/databricks/mlflow-tracking/{mlflow_experiment_id}/{mlflow_run_id}'
    relative = os.path.join('artifacts', save_folder.lstrip('/'))
    return os.path.join(prefix, relative)
42
+
43
+
44
def validate_save_folder(save_folder: str) -> None:
    """Raise ValueError if the save folder points at dbfs (unsupported)."""
    if not save_folder.startswith("dbfs:/"):
        return
    raise ValueError(
        f"Using dbfs save_folder ({save_folder}) to store checkpoints is not supported. Please use a local save_folder."
    )
48
+
49
+
50
def artifact_exists_on_mlflow(artifact_path: str) -> bool:
    """Return True if artifact_path exists (file or directory) for the run.

    Artifact path needs to be a relative path to the save folder. The path is
    resolved level-by-level because the MLflow client can only list one
    artifact directory at a time.

    Args:
        artifact_path: Path relative to the run's artifact root; an empty
            string refers to the root itself and always exists.

    Returns:
        True if every component of ``artifact_path`` exists for the run.

    Raises:
        RuntimeError: If MLFLOW_RUN_ID is not set in the environment.
    """
    client = mlflow.MlflowClient()
    run_id = get_mlflow_run_id()
    # Explicit raise instead of `assert`: asserts are stripped under
    # `python -O`, which would let a missing run id flow into the client calls.
    if run_id is None:
        raise RuntimeError("Run ID must be set")

    # Walk down the path parts level-by-level
    parent = ""
    if artifact_path:
        parts = artifact_path.split("/")
        for i, part in enumerate(parts):
            # Map basename -> FileInfo for the current directory level.
            entries = {os.path.basename(fi.path): fi for fi in client.list_artifacts(run_id, parent)}
            if part not in entries:
                return False
            fi = entries[part]
            is_last = (i == len(parts) - 1)
            if not is_last and not fi.is_dir:
                # trying to descend into a file
                return False
            parent = fi.path  # descend

    # If we got here, the path exists (root or found item).
    return True
76
+
77
+
78
def setup_mlflow(config: Any):
    """
    Sets up mlflow for the current process.

    This function should be called before any other mlflow functions are called.
    It will set the mlflow experiment and run. It will create both if they don't exist.
    It will set all environment variables needed for mlflow.

    Args:
        config: Run configuration; reads and mutates ``config.loggers.mlflow``
            (``experiment_name``, ``run_name``) and reads ``config.autoresume``.
    """
    # NOTE(review): creates a fresh gloo process group just for the broadcasts
    # below and destroys it at the end — assumes no process group is already
    # initialized in this process; confirm against callers.
    dist.init_process_group(backend='gloo')
    mlflow.set_tracking_uri('databricks')

    # mlflow experiment name needs to be an absolute path for databricks mlflow.
    mlflow_experiment_name = get_valid_mlflow_experiment_name(config)
    setattr(config.loggers.mlflow, 'experiment_name', mlflow_experiment_name)
    # COMPOSER_RUN_NAME is set for interactive mode as well.
    mlflow_run_name = os.environ['COMPOSER_RUN_NAME']
    setattr(config.loggers.mlflow, 'run_name', mlflow_run_name)

    # get mlflow experiment if it exists, otherwise create it and set it to all ranks.
    # Only rank 0 talks to the tracking server; the resulting id is then
    # broadcast, so every rank must reach broadcast_object_list (collective call).
    experiment_id = None
    if composer_dist.get_global_rank() == 0:
        experiment = mlflow.get_experiment_by_name(mlflow_experiment_name)
        if experiment is None:
            experiment_id = mlflow.create_experiment(mlflow_experiment_name)
        else:
            experiment_id = experiment.experiment_id
    experiment_id_broadcast_list = [experiment_id]
    composer_dist.broadcast_object_list(experiment_id_broadcast_list, src=0)
    experiment_id = experiment_id_broadcast_list[0]

    mlflow.set_experiment(experiment_id=experiment_id)

    # get mlflow run if it exists and we are autoresuming, otherwise create it and set it to all ranks.
    run_id = None
    if composer_dist.get_global_rank() == 0:
        # Search by the run-name tag only when autoresuming; otherwise always
        # start a brand-new run.
        existing_runs = mlflow.search_runs(
            experiment_ids=[experiment_id],
            filter_string=f'tags.run_name = "{mlflow_run_name}"',
            output_format='list',
        ) if config.autoresume else []
        if len(existing_runs) > 0:
            run_id = existing_runs[0].info.run_id
            print(f'Resuming mlflow run with run id: {run_id}')
        else:
            run_id = mlflow.start_run(run_name=mlflow_run_name).info.run_id
            print(f'Creating new mlflow run with run id: {run_id}')
    run_id_broadcast_list = [run_id]
    composer_dist.broadcast_object_list(run_id_broadcast_list, src=0)
    run_id = run_id_broadcast_list[0]

    # set all the right environment variables
    assert run_id is not None and experiment_id is not None, "Run ID and experiment ID must be set"
    os.environ['MLFLOW_RUN_ID'] = run_id
    os.environ['MLFLOW_EXPERIMENT_ID'] = experiment_id
    os.environ['MLFLOW_TRACKING_URI'] = 'databricks'

    dist.destroy_process_group()