Skip to content

Commit fc8b8af

Browse files
committed
Fix a silly bug in Sample version of EvoNorm missing x* part of swish, update EvoNormBatch to accumulated unbiased variance.
1 parent fa26f6c commit fc8b8af

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

timm/models/layers/evo_norm.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@ def forward(self, x):
3838
x_type = x.dtype
3939
if self.training:
4040
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
41-
self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
41+
n = x.numel() / x.shape[1]
42+
self.running_var.copy_(
43+
var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum))
4244
else:
4345
var = self.running_var
4446

4547
if self.apply_act:
4648
v = self.v.to(dtype=x_type)
47-
d = (x * v) + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
49+
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
4850
d = d.max((var + self.eps).sqrt().to(dtype=x_type))
4951
x = x / d
5052
return x * self.weight + self.bias
@@ -74,8 +76,8 @@ def forward(self, x):
7476
B, C, H, W = x.shape
7577
assert C % self.groups == 0
7678
if self.apply_act:
77-
n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
79+
n = x * (x * self.v).sigmoid()
7880
x = x.reshape(B, self.groups, -1)
79-
x = n / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
81+
x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
8082
x = x.reshape(B, C, H, W)
8183
return x * self.weight + self.bias

0 commit comments

Comments
 (0)