Skip to content

Commit 5ac9505

Browse files
author
Eduardo de Leon
committed
Add mlflow loader, with log_model and load_model support
1 parent e6266f8 commit 5ac9505

File tree

3 files changed

+96
-0
lines changed

3 files changed

+96
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
mlruns
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import json
2+
import os
3+
import yaml
4+
5+
from tempfile import TemporaryDirectory
6+
7+
import interpret
8+
9+
10+
def load_model(*args, **kwargs):
11+
import mlflow.pyfunc
12+
return mlflow.pyfunc.load_model(*args, **kwargs)
13+
14+
15+
def _load_pyfunc(path):
16+
import cloudpickle as pickle
17+
with open(os.path.join(path, "model.pkl"), "rb") as f:
18+
return pickle.load(f)
19+
20+
def _save_model(model, output_path):
21+
import cloudpickle as pickle
22+
if not os.path.exists(output_path):
23+
os.mkdir(output_path)
24+
with open(os.path.join(output_path, "model.pkl"), "wb") as stream:
25+
pickle.dump(model, stream)
26+
try:
27+
with open(os.path.join(output_path, "global_explanation.json"), "w") as stream:
28+
json.dump(model.explain_global().data(-1)["mli"], stream)
29+
except ValueError as e:
30+
raise Exception("Unsupported glassbox model type {}. Failed with error {}.".format(type(model), e))
31+
32+
def log_model(path, model):
33+
try:
34+
import mlflow.pyfunc
35+
except ImportError as e:
36+
raise Exception("Could not log_model to mlflow. Missing mlflow dependency, pip install mlflow to resolve the error: {}.".format(e))
37+
38+
with TemporaryDirectory() as tempdir:
39+
_save_model(model, tempdir)
40+
41+
conda_env = {"name": "mlflow-env",
42+
"channels": ["defaults"],
43+
"dependencies": ["interpret=".format(interpret.version.__version__),
44+
"cloudpickle==0.5.8"
45+
]
46+
}
47+
conda_path = os.path.join(tempdir, "conda.yaml") # TODO Open issue and bug fix for dict support
48+
with open(conda_path, "w") as stream:
49+
yaml.dump(conda_env, stream)
50+
mlflow.pyfunc.log_model(path, loader_module="interpret.glassbox.mlflow", data_path=tempdir, conda_env=conda_path)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) 2019 Microsoft Corporation
2+
# Distributed under the MIT software license
3+
4+
import json
5+
import os
6+
7+
import pytest
8+
9+
from sklearn.datasets import load_breast_cancer, load_boston
10+
from sklearn.linear_model import LogisticRegression as SKLogistic
11+
from sklearn.linear_model import Lasso as SKLinear
12+
13+
from interpret.glassbox.linear import LogisticRegression, LinearRegression
14+
from interpret.glassbox.mlflow import load_model, log_model
15+
16+
17+
@pytest.fixture()
18+
def glassbox_model():
19+
boston = load_boston()
20+
return LinearRegression(feature_names=boston.feature_names, random_state=1)
21+
22+
23+
@pytest.fixture()
24+
def model():
25+
return SKLinear(random_state=1)
26+
27+
28+
def test_linear_regression_save_load(glassbox_model, model):
29+
boston = load_boston()
30+
X, y = boston.data, boston.target
31+
32+
model.fit(X, y)
33+
glassbox_model.fit(X, y)
34+
35+
save_location = "save_location"
36+
log_model(save_location, glassbox_model)
37+
38+
39+
import mlflow
40+
glassbox_model_loaded = load_model("runs:/{}/{}".format(mlflow.active_run().info.run_id, save_location))
41+
42+
name = "name"
43+
explanation_glassbox_data = glassbox_model.explain_global(name).data(-1)["mli"]
44+
explanation_glassbox_data_loaded = glassbox_model_loaded.explain_global(name).data(-1)["mli"]
45+
assert explanation_glassbox_data == explanation_glassbox_data_loaded

0 commit comments

Comments
 (0)