@@ -56,6 +56,7 @@ def __init__(self, name=None, act=None, *args, **kwargs):
56
56
self ._params = OrderedDict ()
57
57
self ._layers = OrderedDict ()
58
58
self ._params_list = OrderedDict ()
59
+ self ._params_dict = OrderedDict ()
59
60
self ._params_status = OrderedDict ()
60
61
self ._parameter_layout_dict = {}
61
62
self ._create_time = int (time .time () * 1e9 )
@@ -105,6 +106,7 @@ def __init__(self, name=None, act=None, *args, **kwargs):
105
106
106
107
# weights check state
107
108
self ._check = False
109
+ self .trainable = True
108
110
109
111
def extend_repr (self ):
110
112
"""
@@ -149,6 +151,9 @@ def __setattr__(self, name, value):
149
151
elif isinstance (value , ParameterList ):
150
152
self .set_attr_for_parameter_tuple (name , value )
151
153
154
+ elif isinstance (value , ParameterDict ):
155
+ self .set_attr_for_parameter_dict (name , value )
156
+
152
157
elif isinstance (value , Module ):
153
158
if layers is None :
154
159
raise AttributeError ("Can not assign layers before Module.__init__() call." )
@@ -250,6 +255,26 @@ def _set_mode_for_layers(self, is_train):
250
255
if isinstance (layer , Module ):
251
256
layer .is_train = is_train
252
257
258
+ def set_attr_for_parameter_dict (self , name , value ):
259
+ """Set attr for parameter in ParameterDict."""
260
+ params = self .__dict__ .get ('_params' )
261
+ params_dict = self .__dict__ .get ('_params_dict' )
262
+ if params is None :
263
+ raise AttributeError ("For 'Module', can not assign params before Module.__init__() is called." )
264
+ exist_names = set ("" )
265
+ for item in value :
266
+ self .insert_param_to_layer (item , value [item ], check_name = False )
267
+ if item in exist_names :
268
+ raise ValueError ("The value {} , its name '{}' already exists." .
269
+ format (value [item ], item ))
270
+ exist_names .add (item )
271
+
272
+ if name in self .__dict__ :
273
+ del self .__dict__ [name ]
274
+ if name in params :
275
+ del params [name ]
276
+ params_dict [name ] = value
277
+
253
278
def set_attr_for_parameter_tuple (self , name , value ):
254
279
"""Set attr for parameter in ParameterTuple."""
255
280
params = self .__dict__ .get ('_params' )
@@ -368,6 +393,10 @@ def __getattr__(self, name):
368
393
if name in params_list :
369
394
para_list = params_list [name ]
370
395
return para_list
396
+ if '_params_dict' in self .__dict__ :
397
+ params_dict = self .__dict__ ['_params_dict' ]
398
+ if name in params_dict :
399
+ return params_dict [name ]
371
400
raise AttributeError ("'{}' object has no attribute '{}'." .format (type (self ).__name__ , name ))
372
401
373
402
def __delattr__ (self , name ):
@@ -1027,9 +1056,10 @@ class Parameter(Module):
1027
1056
1028
1057
"""
1029
1058
1030
- def __new__ (self , data = None , requires_grad = True , name = None ):
1059
+ def __new__ (self , data = None , name = None ):
1060
+ instance = super ().__new__ (self )
1031
1061
if name is None :
1032
- prefix = self . __class__ . __name__ . lower ()
1062
+ prefix = 'parameter'
1033
1063
1034
1064
if _global_layer_name_dict .get (prefix ) is not None :
1035
1065
_global_layer_name_dict [prefix ] += 1
@@ -1047,9 +1077,13 @@ def __new__(self, data=None, requires_grad=True, name=None):
1047
1077
pass
1048
1078
else :
1049
1079
_global_layer_name_dict [name ] = 0
1080
+ if data is None :
1081
+ return instance
1082
+ else :
1083
+ return instance (data , name )
1050
1084
1051
- self . name = name
1052
- return tf .Variable (initial_value = data , trainable = requires_grad , name = name )
1085
+ def __call__ ( self , data = None , name = None , ** kwargs ):
1086
+ return tf .Variable (initial_value = data , name = name )
1053
1087
1054
1088
1055
1089
class ParameterList (Module ):
@@ -1219,10 +1253,10 @@ def __setitem__(self, key, parameter):
1219
1253
def __delitem__ (self , key ):
1220
1254
del self ._params [key ]
1221
1255
1222
- def __setattr__ (self , key , value ):
1223
- if not hasattr (self , key ) and not isinstance (value , tf .Variable ):
1224
- warnings .warn ("Setting attributes on ParameterDict is not supported." )
1225
- super (ParameterDict , self ).__setattr__ (key , value )
1256
+ # def __setattr__(self, key, value):
1257
+ # if not hasattr(self, key) and not isinstance(value, tf.Variable):
1258
+ # warnings.warn("Setting attributes on ParameterDict is not supported.")
1259
+ # super(ParameterDict, self).__setattr__(key, value)
1226
1260
1227
1261
def __len__ (self ) -> int :
1228
1262
return len (self ._params )
0 commit comments