Skip to content

Commit 62c49f6

Browse files
authored
Added TF collection of dict parameters and parameters (#9)
1 parent ba5c605 commit 62c49f6

File tree

1 file changed

+42
-8
lines changed

1 file changed

+42
-8
lines changed

tensorlayerx/nn/core/core_tensorflow.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def __init__(self, name=None, act=None, *args, **kwargs):
5656
self._params = OrderedDict()
5757
self._layers = OrderedDict()
5858
self._params_list = OrderedDict()
59+
self._params_dict = OrderedDict()
5960
self._params_status = OrderedDict()
6061
self._parameter_layout_dict = {}
6162
self._create_time = int(time.time() * 1e9)
@@ -105,6 +106,7 @@ def __init__(self, name=None, act=None, *args, **kwargs):
105106

106107
# weights check state
107108
self._check = False
109+
self.trainable = True
108110

109111
def extend_repr(self):
110112
"""
@@ -149,6 +151,9 @@ def __setattr__(self, name, value):
149151
elif isinstance(value, ParameterList):
150152
self.set_attr_for_parameter_tuple(name, value)
151153

154+
elif isinstance(value, ParameterDict):
155+
self.set_attr_for_parameter_dict(name, value)
156+
152157
elif isinstance(value, Module):
153158
if layers is None:
154159
raise AttributeError("Can not assign layers before Module.__init__() call.")
@@ -250,6 +255,26 @@ def _set_mode_for_layers(self, is_train):
250255
if isinstance(layer, Module):
251256
layer.is_train = is_train
252257

258+
def set_attr_for_parameter_dict(self, name, value):
259+
"""Set attr for parameter in ParameterDict."""
260+
params = self.__dict__.get('_params')
261+
params_dict = self.__dict__.get('_params_dict')
262+
if params is None:
263+
raise AttributeError("For 'Module', can not assign params before Module.__init__() is called.")
264+
exist_names = set("")
265+
for item in value:
266+
self.insert_param_to_layer(item, value[item], check_name=False)
267+
if item in exist_names:
268+
raise ValueError("The value {} , its name '{}' already exists.".
269+
format(value[item], item))
270+
exist_names.add(item)
271+
272+
if name in self.__dict__:
273+
del self.__dict__[name]
274+
if name in params:
275+
del params[name]
276+
params_dict[name] = value
277+
253278
def set_attr_for_parameter_tuple(self, name, value):
254279
"""Set attr for parameter in ParameterTuple."""
255280
params = self.__dict__.get('_params')
@@ -368,6 +393,10 @@ def __getattr__(self, name):
368393
if name in params_list:
369394
para_list = params_list[name]
370395
return para_list
396+
if '_params_dict' in self.__dict__:
397+
params_dict = self.__dict__['_params_dict']
398+
if name in params_dict:
399+
return params_dict[name]
371400
raise AttributeError("'{}' object has no attribute '{}'.".format(type(self).__name__, name))
372401

373402
def __delattr__(self, name):
@@ -1027,9 +1056,10 @@ class Parameter(Module):
10271056
10281057
"""
10291058

1030-
def __new__(self, data=None, requires_grad=True, name=None):
1059+
def __new__(self, data=None, name=None):
1060+
instance = super().__new__(self)
10311061
if name is None:
1032-
prefix = self.__class__.__name__.lower()
1062+
prefix = 'parameter'
10331063

10341064
if _global_layer_name_dict.get(prefix) is not None:
10351065
_global_layer_name_dict[prefix] += 1
@@ -1047,9 +1077,13 @@ def __new__(self, data=None, requires_grad=True, name=None):
10471077
pass
10481078
else:
10491079
_global_layer_name_dict[name] = 0
1080+
if data is None:
1081+
return instance
1082+
else:
1083+
return instance(data, name)
10501084

1051-
self.name = name
1052-
return tf.Variable(initial_value=data, trainable=requires_grad, name=name)
1085+
def __call__(self, data=None, name=None, **kwargs):
1086+
return tf.Variable(initial_value=data, name=name)
10531087

10541088

10551089
class ParameterList(Module):
@@ -1219,10 +1253,10 @@ def __setitem__(self, key, parameter):
12191253
def __delitem__(self, key):
12201254
del self._params[key]
12211255

1222-
def __setattr__(self, key, value):
1223-
if not hasattr(self, key) and not isinstance(value, tf.Variable):
1224-
warnings.warn("Setting attributes on ParameterDict is not supported.")
1225-
super(ParameterDict, self).__setattr__(key, value)
1256+
# def __setattr__(self, key, value):
1257+
# if not hasattr(self, key) and not isinstance(value, tf.Variable):
1258+
# warnings.warn("Setting attributes on ParameterDict is not supported.")
1259+
# super(ParameterDict, self).__setattr__(key, value)
12261260

12271261
def __len__(self) -> int:
12281262
return len(self._params)

0 commit comments

Comments
 (0)