We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 0b3ca76 + 1469a3d commit 6a852ffCopy full SHA for 6a852ff
tensorlayer/layers/normalization.py
@@ -226,7 +226,6 @@ def __init__(
226
self.moving_var_init = moving_var_init
227
self.num_features = num_features
228
229
- self.channel_axis = -1 if data_format == 'channels_last' else 1
230
self.axes = None
231
232
if num_features is not None:
@@ -288,6 +287,7 @@ def build(self, inputs_shape):
288
287
def forward(self, inputs):
289
self._check_input_shape(inputs)
290
+ self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1
291
if self.axes is None:
292
self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]
293
0 commit comments