Skip to content

Commit ad8b2b9

Browse files
committed
hyperpose compatible:
(1)maxpool and batchnorm dataformat debuged,support "channels_first" (2)vgg forward fixed
1 parent faf18cb commit ad8b2b9

3 files changed

Lines changed: 24 additions & 68 deletions

File tree

tensorlayer/layers/normalization.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def _bias_scale(x, b, data_format):
9292
if data_format == 'NHWC':
9393
return x * b
9494
elif data_format == 'NCHW':
95-
return x * _to_channel_first_bias(b)
95+
return x * b
9696
else:
9797
raise ValueError('invalid data_format: %s' % data_format)
9898

@@ -102,7 +102,7 @@ def _bias_add(x, b, data_format):
102102
if data_format == 'NHWC':
103103
return tf.add(x, b)
104104
elif data_format == 'NCHW':
105-
return tf.add(x, _to_channel_first_bias(b))
105+
return tf.add(x, b)
106106
else:
107107
raise ValueError('invalid data_format: %s' % data_format)
108108

@@ -291,9 +291,9 @@ def forward(self, inputs):
291291
if self.axes is None:
292292
self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]
293293

294+
mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
294295
if self.is_train:
295296
# update moving_mean and moving_var
296-
mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
297297
self.moving_mean = moving_averages.assign_moving_average(
298298
self.moving_mean, mean, self.decay, zero_debias=False
299299
)

0 commit comments

Comments
 (0)