Skip to content

Commit 1b6defb

Browse files
committed
added save load initial step document
1 parent c297c10 commit 1b6defb

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

steps/20_save_load_DL/step.md

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# Save/Load for DL Estimators
2+
3+
Contributors: ['AurumnPegasus']
4+
5+
## Introduction
6+
7+
Initially proposed in sktime issue [#3128](https://github.com/alan-turing-institute/sktime/pull/3128), we need to introduce a `save` and `load` functionality for estimators so as to easily store and load fitted models.
8+
9+
Later, Franz introduced a design for save/load estimators in a more general way in sktime issue [#3336](https://github.com/alan-turing-institute/sktime/pull/3336), and the solution I plan to propose here is built on the same.
10+
11+
## Contents
12+
13+
[TOC]
14+
15+
## Problem Statement
16+
17+
18+
### Current Implementation
19+
20+
The current implementation for `serialisation` and `deserialiation` is based on `__getstate__` and `__setstate__` functions implemented in `sklearn`'s `BaseEstimator`. It is done using pickle, where the user simply has to:
21+
22+
```python
23+
import pickle
24+
vecm = VECM()
25+
vecm.fit(train, fh=fh)
26+
save_output = pickle.dumps(vecm)
27+
----------------------------------------------
28+
model = pickle.loads(save_output)
29+
model.predict(fh=fh)
30+
```
31+
32+
### Problems
33+
34+
The issue here is that for general DL Estimators, you cannot do that, because of the `optimizer` parameter. The `optimizer` parameter uses lambda function in its inherent implementation, which can not be pickled in a straightforward manner.
35+
36+
Hence, we need to find a better and more general solution which would allow us to save and load the DL estimators as well.
37+
38+
## Solution
39+
40+
In this case, we want to use the base design proposed by Franz in [#3336](https://github.com/alan-turing-institute/sktime/pull/3336).
41+
42+
As proposed by him:
43+
44+
In the BaseObject class, we add three functions:
45+
46+
```python
47+
def save(self, path=None):
48+
import pickle
49+
if path is None:
50+
return (type(self), pickle.dumps(self))
51+
52+
from zipfile import ZipFile
53+
with ZipFile(path) as zipfile:
54+
with zipfile.open("metadata", mode="w") as meta_file:
55+
meta_file.write(type(self))
56+
with zipfile.open("object", mode="w") as object:
57+
object.write(pickle.dumps(self))
58+
return ZipFile(path)
59+
60+
def load_from_serial(cls, serial):
61+
import pickle
62+
return pickle.loads(serial)
63+
64+
def load_from_path(cld, serial):
65+
import pickle
66+
return pickle.loads(serial)
67+
```
68+
69+
For DL Estimator, we will overwrite this in a base class for all DL Estimators (which is in design phase currently [#26](https://github.com/sktime/enhancement-proposals/pull/26))
70+
71+
```python
72+
class BaseDeepClass():
73+
def __getstate__(self):
74+
out = self.__dict__.copy()
75+
del out['optimizer']
76+
del out['optimizer_']
77+
return out
78+
79+
def save(self, path=None):
80+
import pickle
81+
if path is None:
82+
return (type(self), pickle.dumps(self))
83+
84+
from zipfile import ZipFile
85+
with ZipFile(path) as zipfile:
86+
with zipfile.open("metadata", mode="w") as meta_file:
87+
meta_file.write(type(self))
88+
with zipfile.open("object", mode="w") as object:
89+
object.write(pickle.dumps(self))
90+
with zipfile.open("model", mode="w") as model:
91+
model.write(self.model_.save(path))
92+
return ZipFile(path)
93+
94+
def load_from_path(cls, serial):
95+
# supposed to return the keras model directly
96+
return keras.load(serial)
97+
```
98+
99+

0 commit comments

Comments
 (0)