Commit e1bfa13

Author: Justin Johnson (committed)
Message: update code for PyTorch 0.4
1 parent: 0f1b88a

File tree: 8 files changed, +240 -238 lines


README.md

Lines changed: 116 additions & 115 deletions
Large diffs are not rendered by default.
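The README diff is not rendered, but it documents the same PyTorch 0.4 changes that the script diffs below apply: Tensor and Variable are merged, requires_grad lives on the Tensor, scalars are read with .item() rather than .data[0], devices are chosen with torch.device, and in-place weight updates happen under torch.no_grad(). A minimal illustrative sketch of that pattern (not taken from the README; shapes are arbitrary):

import torch

# 0.4 style: no Variable wrapper; requires_grad lives on the Tensor itself.
w = torch.randn(3, 3, requires_grad=True)
x = torch.randn(3, 3)

loss = x.mm(w).pow(2).sum()
print(loss.item())            # .item() replaces the old loss.data[0]
loss.backward()

with torch.no_grad():         # replaces manual .data arithmetic
  w -= 1e-2 * w.grad
  w.grad.zero_()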

autograd/two_layer_net_autograd.py

Lines changed: 37 additions & 40 deletions
@@ -1,68 +1,65 @@
 import torch
-from torch.autograd import Variable

 """
 A fully-connected ReLU network with one hidden layer and no biases, trained to
 predict y from x by minimizing squared Euclidean distance.

 This implementation computes the forward pass using operations on PyTorch
-Variables, and uses PyTorch autograd to compute gradients.
+Tensors, and uses PyTorch autograd to compute gradients.

-A PyTorch Variable is a wrapper around a PyTorch Tensor, and represents a node
-in a computational graph. If x is a Variable then x.data is a Tensor giving its
-value, and x.grad is another Variable holding the gradient of x with respect to
-some scalar value.
-
-PyTorch Variables have the same API as PyTorch tensors: (almost) any operation
-you can do on a Tensor you can also do on a Variable; the difference is that
-autograd allows you to automatically compute gradients.
+When we create a PyTorch Tensor with requires_grad=True, then operations
+involving that Tensor will not just compute values; they will also build up
+a computational graph in the background, allowing us to easily backpropagate
+through the graph to compute gradients of some Tensors with respect to a
+downstream loss. Concretely if x is a Tensor with x.requires_grad == True then
+after backpropagation x.grad will be another Tensor holding the gradient of x
+with respect to some scalar value.
 """

-dtype = torch.FloatTensor
-# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU
+device = torch.device('cpu')
+# device = torch.device('cuda') # Uncomment this to run on GPU

 # N is batch size; D_in is input dimension;
 # H is hidden dimension; D_out is output dimension.
 N, D_in, H, D_out = 64, 1000, 100, 10

-# Create random Tensors to hold input and outputs, and wrap them in Variables.
-# Setting requires_grad=False indicates that we do not need to compute gradients
-# with respect to these Variables during the backward pass.
-x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
-y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)
+# Create random Tensors to hold input and outputs
+x = torch.randn(N, D_in, device=device)
+y = torch.randn(N, D_out, device=device)

-# Create random Tensors for weights, and wrap them in Variables.
-# Setting requires_grad=True indicates that we want to compute gradients with
-# respect to these Variables during the backward pass.
-w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
-w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)
+# Create random Tensors for weights; setting requires_grad=True means that we
+# want to compute gradients for these Tensors during the backward pass.
+w1 = torch.randn(D_in, H, device=device, requires_grad=True)
+w2 = torch.randn(H, D_out, device=device, requires_grad=True)

 learning_rate = 1e-6
 for t in range(500):
-  # Forward pass: compute predicted y using operations on Variables; these
-  # are exactly the same operations we used to compute the forward pass using
-  # Tensors, but we do not need to keep references to intermediate values since
-  # we are not implementing the backward pass by hand.
+  # Forward pass: compute predicted y using operations on Tensors. Since w1 and
+  # w2 have requires_grad=True, operations involving these Tensors will cause
+  # PyTorch to build a computational graph, allowing automatic computation of
+  # gradients. Since we are no longer implementing the backward pass by hand we
+  # don't need to keep references to intermediate values.
   y_pred = x.mm(w1).clamp(min=0).mm(w2)

-  # Compute and print loss using operations on Variables.
-  # Now loss is a Variable of shape (1,) and loss.data is a Tensor of shape
-  # (1,); loss.data[0] is a scalar value holding the loss.
+  # Compute and print loss. Loss is a Tensor of shape (), and loss.item()
+  # is a Python number giving its value.
   loss = (y_pred - y).pow(2).sum()
-  print(t, loss.data[0])
+  print(t, loss.item())

   # Use autograd to compute the backward pass. This call will compute the
-  # gradient of loss with respect to all Variables with requires_grad=True.
-  # After this call w1.grad and w2.grad will be Variables holding the gradient
+  # gradient of loss with respect to all Tensors with requires_grad=True.
+  # After this call w1.grad and w2.grad will be Tensors holding the gradient
   # of the loss with respect to w1 and w2 respectively.
   loss.backward()

-  # Update weights using gradient descent; w1.data and w2.data are Tensors,
-  # w1.grad and w2.grad are Variables and w1.grad.data and w2.grad.data are
-  # Tensors.
-  w1.data -= learning_rate * w1.grad.data
-  w2.data -= learning_rate * w2.grad.data
+  # Update weights using gradient descent. For this step we just want to mutate
+  # the values of w1 and w2 in-place; we don't want to build up a computational
+  # graph for the update steps, so we use the torch.no_grad() context manager
+  # to prevent PyTorch from building a computational graph for the updates
+  with torch.no_grad():
+    w1 -= learning_rate * w1.grad
+    w2 -= learning_rate * w2.grad

-  # Manually zero the gradients after running the backward pass
-  w1.grad.data.zero_()
-  w2.grad.data.zero_()
+    # Manually zero the gradients after running the backward pass
+    w1.grad.zero_()
+    w2.grad.zero_()
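As a quick check of the behaviour the updated docstring describes (requires_grad, a 0-dimensional loss, .item(), and torch.no_grad()), a small illustrative snippet along these lines could be run; it is a sketch with arbitrary shapes, not part of the commit:

import torch

w = torch.randn(3, requires_grad=True)
x = torch.randn(3)                 # no requires_grad: no gradient is tracked

loss = (w * x).sum()
print(loss.shape)                  # torch.Size([]) -- a 0-dimensional Tensor
print(loss.item())                 # a plain Python float

loss.backward()
print(w.grad)                      # equals x, i.e. d(loss)/dw
print(x.grad)                      # None, since x did not require grad

with torch.no_grad():
  w -= 0.1 * w.grad                # update without extending the graph
  print(w.requires_grad)           # still True; only recording was paused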

autograd/two_layer_net_custom_function.py

Lines changed: 38 additions & 37 deletions
@@ -1,12 +1,11 @@
 import torch
-from torch.autograd import Variable

 """
 A fully-connected ReLU network with one hidden layer and no biases, trained to
 predict y from x by minimizing squared Euclidean distance.

 This implementation computes the forward pass using operations on PyTorch
-Variables, and uses PyTorch autograd to compute gradients.
+Tensors, and uses PyTorch autograd to compute gradients.

 In this implementation we implement our own custom autograd function to perform
 the ReLU function.
@@ -18,62 +17,64 @@ class MyReLU(torch.autograd.Function):
   torch.autograd.Function and implementing the forward and backward passes
   which operate on Tensors.
   """
-  def forward(self, input):
+  @staticmethod
+  def forward(ctx, x):
     """
-    In the forward pass we receive a Tensor containing the input and return a
-    Tensor containing the output. You can cache arbitrary Tensors for use in the
-    backward pass using the save_for_backward method.
+    In the forward pass we receive a context object and a Tensor containing the
+    input; we must return a Tensor containing the output, and we can use the
+    context object to cache objects for use in the backward pass.
     """
-    self.save_for_backward(input)
-    return input.clamp(min=0)
+    ctx.save_for_backward(x)
+    return x.clamp(min=0)

-  def backward(self, grad_output):
+  def backward(ctx, grad_output):
     """
-    In the backward pass we receive a Tensor containing the gradient of the loss
-    with respect to the output, and we need to compute the gradient of the loss
-    with respect to the input.
+    In the backward pass we receive the context object and a Tensor containing
+    the gradient of the loss with respect to the output produced during the
+    forward pass. We can retrieve cached data from the context object, and must
+    compute and return the gradient of the loss with respect to the input to the
+    forward function.
     """
-    input, = self.saved_tensors
-    grad_input = grad_output.clone()
-    grad_input[input < 0] = 0
-    return grad_input
+    x, = ctx.saved_tensors
+    grad_x = grad_output.clone()
+    grad_x[x < 0] = 0
+    return grad_x


-dtype = torch.FloatTensor
-# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU
+device = torch.device('cpu')
+# device = torch.device('cuda') # Uncomment this to run on GPU

 # N is batch size; D_in is input dimension;
 # H is hidden dimension; D_out is output dimension.
 N, D_in, H, D_out = 64, 1000, 100, 10

-# Create random Tensors to hold input and outputs, and wrap them in Variables.
-x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
-y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)
+# Create random Tensors to hold input and output
+x = torch.randn(N, D_in, device=device)
+y = torch.randn(N, D_out, device=device)

 # Create random Tensors for weights, and wrap them in Variables.
-w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True)
-w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True)
+w1 = torch.randn(D_in, H, device=device, requires_grad=True)
+w2 = torch.randn(H, D_out, device=device, requires_grad=True)

 learning_rate = 1e-6
 for t in range(500):
-  # Construct an instance of our MyReLU class to use in our network
-  relu = MyReLU()
-
-  # Forward pass: compute predicted y using operations on Variables; we compute
-  # ReLU using our custom autograd operation.
-  y_pred = relu(x.mm(w1)).mm(w2)
-
+  # Forward pass: compute predicted y using operations on Tensors; we call our
+  # custom ReLU implementation using the MyReLU.apply function
+  y_pred = MyReLU.apply(x.mm(w1)).mm(w2)
+
   # Compute and print loss
   loss = (y_pred - y).pow(2).sum()
-  print(t, loss.data[0])
+  print(t, loss.item())

   # Use autograd to compute the backward pass.
   loss.backward()

-  # Update weights using gradient descent
-  w1.data -= learning_rate * w1.grad.data
-  w2.data -= learning_rate * w2.grad.data
+  with torch.no_grad():
+    # Update weights using gradient descent
+    w1 -= learning_rate * w1.grad
+    w2 -= learning_rate * w2.grad
+
+    # Manually zero the gradients after running the backward pass
+    w1.grad.zero_()
+    w2.grad.zero_()

-  # Manually zero the gradients after running the backward pass
-  w1.grad.data.zero_()
-  w2.grad.data.zero_()
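After porting a custom Function to the ctx-based @staticmethod style shown above, torch.autograd.gradcheck can compare the hand-written backward against numerical gradients. A hedged sketch (not part of the commit; backward is also marked @staticmethod here, and double precision is used because gradcheck expects it):

import torch

class MyReLU(torch.autograd.Function):
  @staticmethod
  def forward(ctx, x):
    ctx.save_for_backward(x)
    return x.clamp(min=0)

  @staticmethod
  def backward(ctx, grad_output):
    x, = ctx.saved_tensors
    grad_x = grad_output.clone()
    grad_x[x < 0] = 0
    return grad_x

x = torch.randn(8, 8, dtype=torch.double, requires_grad=True)
# If an element lands within ~1e-6 of 0 (the ReLU kink) the numerical
# estimate can disagree; rerunning or nudging the input away from 0 fixes it.
print(torch.autograd.gradcheck(MyReLU.apply, (x,)))   # prints True on success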

nn/dynamic_net.py

Lines changed: 3 additions & 4 deletions
@@ -1,6 +1,5 @@
 import random
 import torch
-from torch.autograd import Variable

 """
 To showcase the power of PyTorch dynamic graphs, we will implement a very strange
@@ -46,8 +45,8 @@ def forward(self, x):
 N, D_in, H, D_out = 64, 1000, 100, 10

 # Create random Tensors to hold inputs and outputs, and wrap them in Variables
-x = Variable(torch.randn(N, D_in))
-y = Variable(torch.randn(N, D_out), requires_grad=False)
+x = torch.randn(N, D_in)
+y = torch.randn(N, D_out)

 # Construct our model by instantiating the class defined above
 model = DynamicNet(D_in, H, D_out)
@@ -62,7 +61,7 @@ def forward(self, x):

   # Compute and print loss
   loss = criterion(y_pred, y)
-  print(t, loss.data[0])
+  print(t, loss.item())

   # Zero gradients, perform a backward pass, and update the weights.
   optimizer.zero_grad()
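The docstring above appeals to PyTorch dynamic graphs: the graph is rebuilt on every forward pass, so plain Python control flow (here a random depth) can change the computation from one iteration to the next. A tiny illustrative sketch, not part of the commit:

import random
import torch

w = torch.randn(4, 4, requires_grad=True)
x = torch.randn(1, 4)

for t in range(3):
  h = x
  # A different number of layers may run on each iteration.
  for _ in range(random.randint(1, 3)):
    h = h.mm(w).clamp(min=0)
  loss = h.sum()
  loss.backward()          # backpropagates through whatever graph was built
  print(t, loss.item())
  w.grad.zero_()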

nn/two_layer_net_module.py

Lines changed: 10 additions & 12 deletions
@@ -1,5 +1,4 @@
 import torch
-from torch.autograd import Variable

 """
 A fully-connected ReLU network with one hidden layer, trained to predict y from x
@@ -22,38 +21,37 @@ def __init__(self, D_in, H, D_out):

   def forward(self, x):
     """
-    In the forward function we accept a Variable of input data and we must return
-    a Variable of output data. We can use Modules defined in the constructor as
-    well as arbitrary operators on Variables.
+    In the forward function we accept a Tensor of input data and we must return
+    a Tensor of output data. We can use Modules defined in the constructor as
+    well as arbitrary (differentiable) operations on Tensors.
     """
     h_relu = self.linear1(x).clamp(min=0)
     y_pred = self.linear2(h_relu)
     return y_pred

-
 # N is batch size; D_in is input dimension;
 # H is hidden dimension; D_out is output dimension.
 N, D_in, H, D_out = 64, 1000, 100, 10

-# Create random Tensors to hold inputs and outputs, and wrap them in Variables
-x = Variable(torch.randn(N, D_in))
-y = Variable(torch.randn(N, D_out), requires_grad=False)
+# Create random Tensors to hold inputs and outputs
+x = torch.randn(N, D_in)
+y = torch.randn(N, D_out)

-# Construct our model by instantiating the class defined above
+# Construct our model by instantiating the class defined above.
 model = TwoLayerNet(D_in, H, D_out)

 # Construct our loss function and an Optimizer. The call to model.parameters()
 # in the SGD constructor will contain the learnable parameters of the two
 # nn.Linear modules which are members of the model.
-criterion = torch.nn.MSELoss(size_average=False)
+loss_fn = torch.nn.MSELoss(size_average=False)
 optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
 for t in range(500):
   # Forward pass: Compute predicted y by passing x to the model
   y_pred = model(x)

   # Compute and print loss
-  loss = criterion(y_pred, y)
-  print(t, loss.data[0])
+  loss = loss_fn(y_pred, y)
+  print(t, loss.item())

   # Zero gradients, perform a backward pass, and update the weights.
   optimizer.zero_grad()
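A side note for readers on releases newer than 0.4: the size_average argument used above was later deprecated in favour of reduction, and the summed-error configuration is spelled reduction='sum'. A small equivalence sketch (not part of the commit):

import torch

loss_old = torch.nn.MSELoss(size_average=False)   # 0.4-era spelling (sum, not mean)
loss_new = torch.nn.MSELoss(reduction='sum')      # later spelling

y_pred = torch.randn(4, 3)
y = torch.randn(4, 3)
print(loss_old(y_pred, y).item(), loss_new(y_pred, y).item())   # same value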

nn/two_layer_net_nn.py

Lines changed: 20 additions & 15 deletions
@@ -1,5 +1,4 @@
 import torch
-from torch.autograd import Variable

 """
 A fully-connected ReLU network with one hidden layer, trained to predict y from x
@@ -10,26 +9,31 @@
 but raw autograd can be a bit too low-level for defining complex neural networks;
 this is where the nn package can help. The nn package defines a set of Modules,
 which you can think of as a neural network layer that has produces output from
-input and may have some trainable weights.
+input and may have some trainable weights or other state.
 """

+device = torch.device('cpu')
+# device = torch.device('cuda') # Uncomment this to run on GPU
+
 # N is batch size; D_in is input dimension;
 # H is hidden dimension; D_out is output dimension.
 N, D_in, H, D_out = 64, 1000, 100, 10

-# Create random Tensors to hold inputs and outputs, and wrap them in Variables.
-x = Variable(torch.randn(N, D_in))
-y = Variable(torch.randn(N, D_out), requires_grad=False)
+# Create random Tensors to hold inputs and outputs
+x = torch.randn(N, D_in, device=device)
+y = torch.randn(N, D_out, device=device)

 # Use the nn package to define our model as a sequence of layers. nn.Sequential
 # is a Module which contains other Modules, and applies them in sequence to
 # produce its output. Each Linear Module computes output from input using a
-# linear function, and holds internal Variables for its weight and bias.
+# linear function, and holds internal Tensors for its weight and bias.
+# After constructing the model we use the .to() method to move it to the
+# desired device.
 model = torch.nn.Sequential(
     torch.nn.Linear(D_in, H),
     torch.nn.ReLU(),
     torch.nn.Linear(H, D_out),
-)
+).to(device)

 # The nn package also contains definitions of popular loss functions; in this
 # case we will use Mean Squared Error (MSE) as our loss function.
@@ -39,25 +43,26 @@
 for t in range(500):
   # Forward pass: compute predicted y by passing x to the model. Module objects
   # override the __call__ operator so you can call them like functions. When
-  # doing so you pass a Variable of input data to the Module and it produces
-  # a Variable of output data.
+  # doing so you pass a Tensor of input data to the Module and it produces
+  # a Tensor of output data.
   y_pred = model(x)

-  # Compute and print loss. We pass Variables containing the predicted and true
+  # Compute and print loss. We pass Tensors containing the predicted and true
   # values of y, and the loss function returns a Variable containing the loss.
   loss = loss_fn(y_pred, y)
-  print(t, loss.data[0])
+  print(t, loss.item())

   # Zero the gradients before running the backward pass.
   model.zero_grad()

   # Backward pass: compute gradient of the loss with respect to all the learnable
   # parameters of the model. Internally, the parameters of each Module are stored
-  # in Variables with requires_grad=True, so this call will compute gradients for
+  # in Tensors with requires_grad=True, so this call will compute gradients for
   # all learnable parameters in the model.
   loss.backward()

-  # Update the weights using gradient descent. Each parameter is a Variable, so
+  # Update the weights using gradient descent. Each parameter is a Tensor, so
   # we can access its data and gradients like we did before.
-  for param in model.parameters():
-    param.data -= learning_rate * param.grad.data
+  with torch.no_grad():
+    for param in model.parameters():
+      param.data -= learning_rate * param.grad
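Rather than toggling the device line by hand, a common alternative (not used in this commit, shown only as an illustrative sketch) is to pick the device based on availability, move the model with .to(device), and create data tensors directly on that device:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = torch.nn.Sequential(
    torch.nn.Linear(10, 5),
    torch.nn.ReLU(),
    torch.nn.Linear(5, 2),
).to(device)                            # moves parameters to the chosen device

x = torch.randn(4, 10, device=device)   # data created directly on the device
print(model(x).shape, next(model.parameters()).device)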

0 commit comments
