Skip to content

Commit b897e28

Browse files
author
Eduardo de Leon
committed
Fix conda dependencies and serialization issues
1 parent 5ac9505 commit b897e28

File tree

1 file changed

+29
-3
lines changed
  • python/interpret-core/interpret/glassbox/mlflow

1 file changed

+29
-3
lines changed

python/interpret-core/interpret/glassbox/mlflow/__init__.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,25 @@ def load_model(*args, **kwargs):
1212
return mlflow.pyfunc.load_model(*args, **kwargs)
1313

1414

15+
def _sanitize_explanation_data(data): # TODO Explanations should have a to_json()
16+
if isinstance(data, dict):
17+
for key, val in data.items():
18+
data[key] = _sanitize_explanation_data(data[key])
19+
return data
20+
21+
elif isinstance(data, list):
22+
return [_sanitize_explanation_data[x] for x in data]
23+
else:
24+
# numpy type conversion to python https://stackoverflow.com/questions/9452775 primitive
25+
return data.item() if hasattr(data, "item") else data
26+
27+
1528
def _load_pyfunc(path):
1629
import cloudpickle as pickle
1730
with open(os.path.join(path, "model.pkl"), "rb") as f:
1831
return pickle.load(f)
1932

33+
2034
def _save_model(model, output_path):
2135
import cloudpickle as pickle
2236
if not os.path.exists(output_path):
@@ -25,7 +39,15 @@ def _save_model(model, output_path):
2539
pickle.dump(model, stream)
2640
try:
2741
with open(os.path.join(output_path, "global_explanation.json"), "w") as stream:
28-
json.dump(model.explain_global().data(-1)["mli"], stream)
42+
data = model.explain_global().data(-1)["mli"]
43+
if isinstance(data, list):
44+
data = data[0]
45+
if "global" not in data["explanation_type"]:
46+
raise Exception("Invalid explanation, not global")
47+
for key in data:
48+
if isinstance(data[key], list):
49+
data[key] = [float(x) for x in data[key]]
50+
json.dump(data, stream)
2951
except ValueError as e:
3052
raise Exception("Unsupported glassbox model type {}. Failed with error {}.".format(type(model), e))
3153

@@ -34,14 +56,18 @@ def log_model(path, model):
3456
import mlflow.pyfunc
3557
except ImportError as e:
3658
raise Exception("Could not log_model to mlflow. Missing mlflow dependency, pip install mlflow to resolve the error: {}.".format(e))
59+
import cloudpickle as pickle
3760

3861
with TemporaryDirectory() as tempdir:
3962
_save_model(model, tempdir)
4063

4164
conda_env = {"name": "mlflow-env",
4265
"channels": ["defaults"],
43-
"dependencies": ["interpret=".format(interpret.version.__version__),
44-
"cloudpickle==0.5.8"
66+
"dependencies": ["pip",
67+
{"pip": [
68+
"interpret=={}".format(interpret.version.__version__),
69+
"cloudpickle=={}".format(pickle.__version__)]
70+
}
4571
]
4672
}
4773
conda_path = os.path.join(tempdir, "conda.yaml") # TODO Open issue and bug fix for dict support

0 commit comments

Comments
 (0)