Skip to content

Commit 1862d0f

Browse files
committed
make RNN definition independent of global variables
1 parent 96acc31 commit 1862d0f

File tree

4 files changed

+19
-11
lines changed

4 files changed

+19
-11
lines changed

tutorials/06 - Recurrent Neural Network/main-gpu.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@
3838
class RNN(nn.Module):
3939
def __init__(self, input_size, hidden_size, num_layers, num_classes):
4040
super(RNN, self).__init__()
41+
self.hidden_size = hidden_size
42+
self.num_layers = num_layers
4143
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
4244
self.fc = nn.Linear(hidden_size, num_classes)
4345

4446
def forward(self, x):
4547
# Set initial states
46-
h0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size).cuda())
47-
c0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size).cuda())
48+
h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size).cuda())
49+
c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size).cuda())
4850

4951
# Forward propagate RNN
5052
out, _ = self.lstm(x, (h0, c0))
@@ -90,4 +92,4 @@ def forward(self, x):
9092
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
9193

9294
# Save the Model
93-
torch.save(rnn, 'rnn.pkl')
95+
torch.save(rnn, 'rnn.pkl')

tutorials/06 - Recurrent Neural Network/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,15 @@
3838
class RNN(nn.Module):
3939
def __init__(self, input_size, hidden_size, num_layers, num_classes):
4040
super(RNN, self).__init__()
41+
self.hidden_size = hidden_size
42+
self.num_layers = num_layers
4143
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
4244
self.fc = nn.Linear(hidden_size, num_classes)
4345

4446
def forward(self, x):
4547
# Set initial states
46-
h0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size))
47-
c0 = Variable(torch.zeros(num_layers, x.size(0), hidden_size))
48+
h0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
49+
c0 = Variable(torch.zeros(self.num_layers, x.size(0), self.hidden_size))
4850

4951
# Forward propagate RNN
5052
out, _ = self.lstm(x, (h0, c0))

tutorials/07 - Bidirectional Recurrent Neural Network/main-gpu.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,16 @@
3838
class BiRNN(nn.Module):
3939
def __init__(self, input_size, hidden_size, num_layers, num_classes):
4040
super(BiRNN, self).__init__()
41+
self.hidden_size = hidden_size
42+
self.num_layers = num_layers
4143
self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
4244
batch_first=True, bidirectional=True)
4345
self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection
4446

4547
def forward(self, x):
4648
# Set initial states
47-
h0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size)).cuda() # 2 for bidirection
48-
c0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size)).cuda()
49+
h0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)).cuda() # 2 for bidirection
50+
c0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)).cuda()
4951

5052
# Forward propagate RNN
5153
out, _ = self.lstm(x, (h0, c0))
@@ -91,4 +93,4 @@ def forward(self, x):
9193
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
9294

9395
# Save the Model
94-
torch.save(rnn, 'rnn.pkl')
96+
torch.save(rnn, 'rnn.pkl')

tutorials/07 - Bidirectional Recurrent Neural Network/main.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,16 @@
3838
class BiRNN(nn.Module):
3939
def __init__(self, input_size, hidden_size, num_layers, num_classes):
4040
super(BiRNN, self).__init__()
41+
self.hidden_size = hidden_size
42+
self.num_layers = num_layers
4143
self.lstm = nn.LSTM(input_size, hidden_size, num_layers,
4244
batch_first=True, bidirectional=True)
4345
self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection
4446

4547
def forward(self, x):
4648
# Set initial states
47-
h0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size)) # 2 for bidirection
48-
c0 = Variable(torch.zeros(num_layers*2, x.size(0), hidden_size))
49+
h0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size)) # 2 for bidirection
50+
c0 = Variable(torch.zeros(self.num_layers*2, x.size(0), self.hidden_size))
4951

5052
# Forward propagate RNN
5153
out, _ = self.lstm(x, (h0, c0))
@@ -91,4 +93,4 @@ def forward(self, x):
9193
print('Test Accuracy of the model on the 10000 test images: %d %%' % (100 * correct / total))
9294

9395
# Save the Model
94-
torch.save(rnn, 'rnn.pkl')
96+
torch.save(rnn, 'rnn.pkl')

0 commit comments

Comments
 (0)