Skip to content

Commit 84b2da2

Browse files
authored
delete cuda()
1 parent 0108b9f commit 84b2da2

File tree

1 file changed

+2
-2
lines changed
  • tutorials/06 - Recurrent Neural Network

1 file changed

+2
-2
lines changed

tutorials/06 - Recurrent Neural Network/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ def forward(self, x):
6464
for epoch in range(num_epochs):
6565
for i, (images, labels) in enumerate(train_loader):
6666
images = Variable(images.view(-1, sequence_length, input_size))
67-
labels = Variable(labels).cuda()
67+
labels = Variable(labels)
6868

6969
# Forward + Backward + Optimize
7070
optimizer.zero_grad()
@@ -90,4 +90,4 @@ def forward(self, x):
9090
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
9191

9292
# Save the Model
93-
torch.save(rnn, 'rnn.pkl')
93+
torch.save(rnn, 'rnn.pkl')

0 commit comments

Comments
 (0)