@@ -38,13 +38,15 @@ def forward(self, x):
38
38
x_type = x .dtype
39
39
if self .training :
40
40
var = x .var (dim = (0 , 2 , 3 ), unbiased = False , keepdim = True )
41
- self .running_var .copy_ (self .momentum * var .detach () + (1 - self .momentum ) * self .running_var )
41
+ n = x .numel () / x .shape [1 ]
42
+ self .running_var .copy_ (
43
+ var .detach () * self .momentum * (n / (n - 1 )) + self .running_var * (1 - self .momentum ))
42
44
else :
43
45
var = self .running_var
44
46
45
47
if self .apply_act :
46
48
v = self .v .to (dtype = x_type )
47
- d = ( x * v ) + (x .var (dim = (2 , 3 ), unbiased = False , keepdim = True ) + self .eps ).sqrt ().to (dtype = x_type )
49
+ d = x * v + (x .var (dim = (2 , 3 ), unbiased = False , keepdim = True ) + self .eps ).sqrt ().to (dtype = x_type )
48
50
d = d .max ((var + self .eps ).sqrt ().to (dtype = x_type ))
49
51
x = x / d
50
52
return x * self .weight + self .bias
@@ -74,8 +76,8 @@ def forward(self, x):
74
76
B , C , H , W = x .shape
75
77
assert C % self .groups == 0
76
78
if self .apply_act :
77
- n = (x * self .v ).sigmoid (). reshape ( B , self . groups , - 1 )
79
+ n = x * (x * self .v ).sigmoid ()
78
80
x = x .reshape (B , self .groups , - 1 )
79
- x = n / (x .var (dim = - 1 , unbiased = False , keepdim = True ) + self .eps ).sqrt ()
81
+ x = n . reshape ( B , self . groups , - 1 ) / (x .var (dim = - 1 , unbiased = False , keepdim = True ) + self .eps ).sqrt ()
80
82
x = x .reshape (B , C , H , W )
81
83
return x * self .weight + self .bias
0 commit comments