Skip to content

Commit 4df9208

Browse files
committed
sets model to eval mode (required due to batchnorm layer)
1 parent 75dd61c commit 4df9208

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

train_classification.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@
7474
points = points.transpose(2,1)
7575
points, target = points.cuda(), target.cuda()
7676
optimizer.zero_grad()
77+
classifier = classifier.train()
7778
pred, _ = classifier(points)
7879
loss = F.nll_loss(pred, target)
7980
loss.backward()
@@ -88,6 +89,7 @@
8889
points, target = Variable(points), Variable(target[:,0])
8990
points = points.transpose(2,1)
9091
points, target = points.cuda(), target.cuda()
92+
classifier = classifier.eval()
9193
pred, _ = classifier(points)
9294
loss = F.nll_loss(pred, target)
9395
pred_choice = pred.data.max(1)[1]

train_segmentation.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
points = points.transpose(2,1)
7272
points, target = points.cuda(), target.cuda()
7373
optimizer.zero_grad()
74+
classifier = classifier.train()
7475
pred, _ = classifier(points)
7576
pred = pred.view(-1, num_classes)
7677
target = target.view(-1,1)[:,0] - 1
@@ -87,7 +88,8 @@
8788
points, target = data
8889
points, target = Variable(points), Variable(target)
8990
points = points.transpose(2,1)
90-
points, target = points.cuda(), target.cuda()
91+
points, target = points.cuda(), target.cuda()
92+
classifier = classifier.eval()
9193
pred, _ = classifier(points)
9294
pred = pred.view(-1, num_classes)
9395
target = target.view(-1,1)[:,0] - 1

0 commit comments

Comments
 (0)