@@ -12,11 +12,25 @@ def load_model(*args, **kwargs):
12
12
return mlflow .pyfunc .load_model (* args , ** kwargs )
13
13
14
14
15
+ def _sanitize_explanation_data (data ): # TODO Explanations should have a to_json()
16
+ if isinstance (data , dict ):
17
+ for key , val in data .items ():
18
+ data [key ] = _sanitize_explanation_data (data [key ])
19
+ return data
20
+
21
+ elif isinstance (data , list ):
22
+ return [_sanitize_explanation_data [x ] for x in data ]
23
+ else :
24
+ # numpy type conversion to python https://stackoverflow.com/questions/9452775 primitive
25
+ return data .item () if hasattr (data , "item" ) else data
26
+
27
+
15
28
def _load_pyfunc (path ):
16
29
import cloudpickle as pickle
17
30
with open (os .path .join (path , "model.pkl" ), "rb" ) as f :
18
31
return pickle .load (f )
19
32
33
+
20
34
def _save_model (model , output_path ):
21
35
import cloudpickle as pickle
22
36
if not os .path .exists (output_path ):
@@ -25,7 +39,15 @@ def _save_model(model, output_path):
25
39
pickle .dump (model , stream )
26
40
try :
27
41
with open (os .path .join (output_path , "global_explanation.json" ), "w" ) as stream :
28
- json .dump (model .explain_global ().data (- 1 )["mli" ], stream )
42
+ data = model .explain_global ().data (- 1 )["mli" ]
43
+ if isinstance (data , list ):
44
+ data = data [0 ]
45
+ if "global" not in data ["explanation_type" ]:
46
+ raise Exception ("Invalid explanation, not global" )
47
+ for key in data :
48
+ if isinstance (data [key ], list ):
49
+ data [key ] = [float (x ) for x in data [key ]]
50
+ json .dump (data , stream )
29
51
except ValueError as e :
30
52
raise Exception ("Unsupported glassbox model type {}. Failed with error {}." .format (type (model ), e ))
31
53
@@ -34,14 +56,18 @@ def log_model(path, model):
34
56
import mlflow .pyfunc
35
57
except ImportError as e :
36
58
raise Exception ("Could not log_model to mlflow. Missing mlflow dependency, pip install mlflow to resolve the error: {}." .format (e ))
59
+ import cloudpickle as pickle
37
60
38
61
with TemporaryDirectory () as tempdir :
39
62
_save_model (model , tempdir )
40
63
41
64
conda_env = {"name" : "mlflow-env" ,
42
65
"channels" : ["defaults" ],
43
- "dependencies" : ["interpret=" .format (interpret .version .__version__ ),
44
- "cloudpickle==0.5.8"
66
+ "dependencies" : ["pip" ,
67
+ {"pip" : [
68
+ "interpret=={}" .format (interpret .version .__version__ ),
69
+ "cloudpickle=={}" .format (pickle .__version__ )]
70
+ }
45
71
]
46
72
}
47
73
conda_path = os .path .join (tempdir , "conda.yaml" ) # TODO Open issue and bug fix for dict support
0 commit comments