Skip to content

Commit 6a852ff

Browse files
authored
Merge pull request #1104 from Laicheng0830/fix_bn
fix BatchNorm
2 parents 0b3ca76 + 1469a3d commit 6a852ff

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tensorlayer/layers/normalization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,6 @@ def __init__(
226226
self.moving_var_init = moving_var_init
227227
self.num_features = num_features
228228

229-
self.channel_axis = -1 if data_format == 'channels_last' else 1
230229
self.axes = None
231230

232231
if num_features is not None:
@@ -288,6 +287,7 @@ def build(self, inputs_shape):
288287
def forward(self, inputs):
289288
self._check_input_shape(inputs)
290289

290+
self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1
291291
if self.axes is None:
292292
self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]
293293

0 commit comments

Comments
 (0)