Skip to content

Commit 5967ce9

Browse files
author
Omar Samir Mohammed
committed
Enabling the batch_size to propogate to the get_batch function
1 parent c88343d commit 5967ce9

File tree

1 file changed

+11
-5
lines changed

1 file changed

+11
-5
lines changed

experiment.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@
1111
import data_generation
1212
import models
1313

14-
batch_size = 32
14+
batch_size = 16
1515
seq_len = 100 # This is equivalent to time steps of the sequence in keras
1616
input_size = 1
1717
hidden_size = 51
18-
nb_layers = 1
1918
target_size = 1
2019
nb_samples = 1000
2120
nb_epochs = 20
@@ -36,17 +35,22 @@
3635
val_loss = 0
3736
rnn.train(True)
3837
for batch, i in enumerate(range(0, X_train.size(0) - 1, batch_size)):
39-
data, targets = data_generation.get_batch(X_train, y_train, i)
38+
# print "batch = ", batch, " -- i = ", i
39+
data, targets = data_generation.get_batch(X_train, y_train, i, batch_size=batch_size)
40+
# print ("X_train.size() = {}, y_train.size() = {} \n data.size() = {}, targets.size() = {}"
41+
# .format(X_train.size(), y_train.size(), data.size(), targets.size()))
4042
output = rnn(data)
4143
optimizer.zero_grad()
4244
loss = loss_fn(output, targets)
4345
loss.backward()
4446
optimizer.step()
4547
training_loss += loss.data[0]
48+
# print "-"*30
49+
# exit()
4650
training_loss /= batch
4751
rnn.train(False)
4852
for batch, i in enumerate(range(0, X_val.size(0) - 1, batch_size)):
49-
data, targets = data_generation.get_batch(X_val, y_val, i)
53+
data, targets = data_generation.get_batch(X_val, y_val, i, batch_size=batch_size)
5054
output = rnn(data)
5155
loss = loss_fn(output, targets)
5256
val_loss += loss.data[0]
@@ -63,7 +67,7 @@
6367
list2 = []
6468
for batch, i in enumerate(range(0, X_test.size(0) - 1, batch_size)):
6569
print i
66-
data, targets = data_generation.get_batch(X_test, y_test, i)
70+
data, targets = data_generation.get_batch(X_test, y_test, i, batch_size=batch_size)
6771
output = rnn(data)
6872
loss = loss_fn(output, targets)
6973
test_loss += loss.data[0]
@@ -86,6 +90,8 @@
8690
output = torch.squeeze(output).data.cpu().numpy()
8791
plt.figure()
8892
plt.plot(output)
93+
plt.xlabel("Time step")
94+
plt.ylabel("Signal amplitude")
8995
plt.show()
9096
"""
9197
Generating sequences - attempt 2 --> Concatenating the output with the input, and feed the new data point to the model.

0 commit comments

Comments
 (0)