diff --git a/README.md b/README.md
index f58c7c3..65ffc4e 100644
--- a/README.md
+++ b/README.md
@@ -11,10 +11,15 @@ will have a single hidden layer, and will be trained with gradient descent to
 fit random data by minimizing the Euclidean distance between the network output
 and the true output.
 
+**NOTE:** These examples have been updated for PyTorch 0.4, which made several
+major changes to the core PyTorch API. Most notably, prior to 0.4 Tensors had
+to be wrapped in Variable objects to use autograd; this functionality has now
+been added directly to Tensors, and Variables are now deprecated.
+
 ### Table of Contents
 - Warm-up: numpy
 - PyTorch: Tensors
-- PyTorch: Variables and autograd
+- PyTorch: Autograd
 - PyTorch: Defining new autograd functions
 - TensorFlow: Static Graphs
 - PyTorch: nn
@@ -82,37 +87,37 @@ unfortunately numpy won't be enough for modern deep learning.
 
 Here we introduce the most fundamental PyTorch concept: the **Tensor**. A PyTorch
 Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional
-array, and PyTorch provides many functions for operating on these Tensors. Like
-numpy arrays, PyTorch Tensors do not know anything about deep learning or
-computational graphs or gradients; they are a generic tool for scientific
+array, and PyTorch provides many functions for operating on these Tensors.
+Any computation you might want to perform with numpy can also be accomplished
+with PyTorch Tensors; you should think of them as a generic tool for scientific
 computing.
 
 However unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their
-numeric computations. To run a PyTorch Tensor on GPU, you simply need to cast it
-to a new datatype.
+numeric computations. To run a PyTorch Tensor on GPU, you use the `device`
+argument when constructing the Tensor to place it on the GPU.
 
 Here we use PyTorch Tensors to fit a two-layer network to random data. Like the
-numpy example above we need to manually implement the forward and backward
-passes through the network:
+numpy example above, we manually implement the forward and backward
+passes through the network, using operations on PyTorch Tensors:
 
 ```python
 # Code in file tensor/two_layer_net_tensor.py
 import torch
 
-dtype = torch.FloatTensor
-# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU
+device = torch.device('cpu')
+# device = torch.device('cuda') # Uncomment this to run on GPU
 
 # N is batch size; D_in is input dimension;
 # H is hidden dimension; D_out is output dimension.
 N, D_in, H, D_out = 64, 1000, 100, 10
 
 # Create random input and output data
-x = torch.randn(N, D_in).type(dtype)
-y = torch.randn(N, D_out).type(dtype)
+x = torch.randn(N, D_in, device=device)
+y = torch.randn(N, D_out, device=device)
 
 # Randomly initialize weights
-w1 = torch.randn(D_in, H).type(dtype)
-w2 = torch.randn(H, D_out).type(dtype)
+w1 = torch.randn(D_in, H, device=device)
+w2 = torch.randn(H, D_out, device=device)
 
 learning_rate = 1e-6
 for t in range(500):
@@ -121,9 +126,10 @@ for t in range(500):
   h_relu = h.clamp(min=0)
   y_pred = h_relu.mm(w2)
 
-  # Compute and print loss
+  # Compute and print loss; loss is a scalar, and is stored in a PyTorch Tensor
+  # of shape (); we can get its value as a Python number with loss.item().
   loss = (y_pred - y).pow(2).sum()
-  print(t, loss)
+  print(t, loss.item())
 
   # Backprop to compute gradients of w1 and w2 with respect to loss
   grad_y_pred = 2.0 * (y_pred - y)
@@ -138,7 +144,7 @@
   w2 -= learning_rate * grad_w2
 ```
 
-## PyTorch: Variables and autograd
+## PyTorch: Autograd
 
 In the above examples, we had to manually implement both the forward and backward
 passes of our neural network. Manually implementing the backward pass
@@ -154,74 +160,75 @@ When using autograd, the forward pass of your network will define a
 functions that produce output Tensors from input Tensors. Backpropagating through
 this graph then allows you to easily compute gradients.
 
-This sounds complicated, it's pretty simple to use in practice. We wrap our
-PyTorch Tensors in **Variable** objects; a Variable represents a node in a
-computational graph. If `x` is a Variable then `x.data` is a Tensor, and
-`x.grad` is another Variable holding the gradient of `x` with respect to some
-scalar value.
-
-PyTorch Variables have the same API as PyTorch Tensors: (almost) any operation
-that you can perform on a Tensor also works on Variables; the difference is that
-using Variables defines a computational graph, allowing you to automatically
-compute gradients.
-
-Here we use PyTorch Variables and autograd to implement our two-layer network;
+This sounds complicated, but it's pretty simple to use in practice. If we want to
+compute gradients with respect to some Tensor, then we set `requires_grad=True`
+when constructing that Tensor. Any PyTorch operations on that Tensor will cause
+a computational graph to be constructed, allowing us to later perform backpropagation
+through the graph. If `x` is a Tensor with `requires_grad=True`, then after
+backpropagation `x.grad` will be another Tensor holding the gradient of `x` with
+respect to some scalar value.
+
+Sometimes you may wish to prevent PyTorch from building computational graphs when
+performing certain operations on Tensors with `requires_grad=True`; for example
+we usually don't want to backpropagate through the weight update steps when
+training a neural network. In such scenarios we can use the `torch.no_grad()`
+context manager to prevent the construction of a computational graph.
+
+Here we use PyTorch Tensors and autograd to implement our two-layer network;
 now we no longer need to manually implement the backward pass through the
 network:
 
 ```python
 # Code in file autograd/two_layer_net_autograd.py
 import torch
-from torch.autograd import Variable
 
-dtype = torch.FloatTensor
-# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU
+device = torch.device('cpu')
+# device = torch.device('cuda') # Uncomment this to run on GPU
 
 # N is batch size; D_in is input dimension;
 # H is hidden dimension; D_out is output dimension.
 N, D_in, H, D_out = 64, 1000, 100, 10
 
-# Create random Tensors to hold input and outputs, and wrap them in Variables.
-# Setting requires_grad=False indicates that we do not need to compute gradients
-# with respect to these Variables during the backward pass.
-x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False)
-y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False)
+# Create random Tensors to hold input and outputs
+x = torch.randn(N, D_in, device=device)
+y = torch.randn(N, D_out, device=device)
 
-# Create random Tensors for weights, and wrap them in Variables.
-# Setting requires_grad=True indicates that we want to compute gradients with
-# respect to these Variables during the backward pass.
-w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True) -w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True) +# Create random Tensors for weights; setting requires_grad=True means that we +# want to compute gradients for these Tensors during the backward pass. +w1 = torch.randn(D_in, H, device=device, requires_grad=True) +w2 = torch.randn(H, D_out, device=device, requires_grad=True) learning_rate = 1e-6 for t in range(500): - # Forward pass: compute predicted y using operations on Variables; these - # are exactly the same operations we used to compute the forward pass using - # Tensors, but we do not need to keep references to intermediate values since - # we are not implementing the backward pass by hand. + # Forward pass: compute predicted y using operations on Tensors. Since w1 and + # w2 have requires_grad=True, operations involving these Tensors will cause + # PyTorch to build a computational graph, allowing automatic computation of + # gradients. Since we are no longer implementing the backward pass by hand we + # don't need to keep references to intermediate values. y_pred = x.mm(w1).clamp(min=0).mm(w2) - # Compute and print loss using operations on Variables. - # Now loss is a Variable of shape (1,) and loss.data is a Tensor of shape - # (1,); loss.data[0] is a scalar value holding the loss. + # Compute and print loss. Loss is a Tensor of shape (), and loss.item() + # is a Python number giving its value. loss = (y_pred - y).pow(2).sum() - print(t, loss.data[0]) - - # Manually zero the gradients before running the backward pass - w1.grad.data.zero_() - w2.grad.data.zero_() + print(t, loss.item()) # Use autograd to compute the backward pass. This call will compute the - # gradient of loss with respect to all Variables with requires_grad=True. - # After this call w1.grad and w2.grad will be Variables holding the gradient + # gradient of loss with respect to all Tensors with requires_grad=True. + # After this call w1.grad and w2.grad will be Tensors holding the gradient # of the loss with respect to w1 and w2 respectively. loss.backward() - # Update weights using gradient descent; w1.data and w2.data are Tensors, - # w1.grad and w2.grad are Variables and w1.grad.data and w2.grad.data are - # Tensors. - w1.data -= learning_rate * w1.grad.data - w2.data -= learning_rate * w2.grad.data + # Update weights using gradient descent. For this step we just want to mutate + # the values of w1 and w2 in-place; we don't want to build up a computational + # graph for the update steps, so we use the torch.no_grad() context manager + # to prevent PyTorch from building a computational graph for the updates + with torch.no_grad(): + w1 -= learning_rate * w1.grad + w2 -= learning_rate * w2.grad + + # Manually zero the gradients after running the backward pass + w1.grad.zero_() + w2.grad.zero_() ``` ## PyTorch: Defining new autograd functions @@ -234,7 +241,7 @@ with respect to that same scalar value. In PyTorch we can easily define our own autograd operator by defining a subclass of `torch.autograd.Function` and implementing the `forward` and `backward` functions. We can then use our new autograd operator by constructing an instance and calling it -like a function, passing Variables containing input data. +like a function, passing Tensors containing input data. 
In this example we define our own custom autograd function for performing the ReLU nonlinearity, and use it to implement our two-layer network: @@ -242,7 +249,6 @@ nonlinearity, and use it to implement our two-layer network: ```python # Code in file autograd/two_layer_net_custom_function.py import torch -from torch.autograd import Variable class MyReLU(torch.autograd.Function): """ @@ -250,65 +256,68 @@ class MyReLU(torch.autograd.Function): torch.autograd.Function and implementing the forward and backward passes which operate on Tensors. """ - def forward(self, input): + @staticmethod + def forward(ctx, x): """ - In the forward pass we receive a Tensor containing the input and return a - Tensor containing the output. You can cache arbitrary Tensors for use in the - backward pass using the save_for_backward method. + In the forward pass we receive a context object and a Tensor containing the + input; we must return a Tensor containing the output, and we can use the + context object to cache objects for use in the backward pass. """ - self.save_for_backward(input) - return input.clamp(min=0) + ctx.save_for_backward(x) + return x.clamp(min=0) - def backward(self, grad_output): + @staticmethod + def backward(ctx, grad_output): """ - In the backward pass we receive a Tensor containing the gradient of the loss - with respect to the output, and we need to compute the gradient of the loss - with respect to the input. + In the backward pass we receive the context object and a Tensor containing + the gradient of the loss with respect to the output produced during the + forward pass. We can retrieve cached data from the context object, and must + compute and return the gradient of the loss with respect to the input to the + forward function. """ - input, = self.saved_tensors - grad_input = grad_output.clone() - grad_input[input < 0] = 0 - return grad_input + x, = ctx.saved_tensors + grad_x = grad_output.clone() + grad_x[x < 0] = 0 + return grad_x -dtype = torch.FloatTensor -# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU +device = torch.device('cpu') +# device = torch.device('cuda') # Uncomment this to run on GPU # N is batch size; D_in is input dimension; # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold input and outputs, and wrap them in Variables. -x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False) -y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False) +# Create random Tensors to hold input and output +x = torch.randn(N, D_in, device=device) +y = torch.randn(N, D_out, device=device) -# Create random Tensors for weights, and wrap them in Variables. -w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True) -w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True) +# Create random Tensors for weights. +w1 = torch.randn(D_in, H, device=device, requires_grad=True) +w2 = torch.randn(H, D_out, device=device, requires_grad=True) learning_rate = 1e-6 for t in range(500): - # Construct an instance of our MyReLU class to use in our network - relu = MyReLU() - - # Forward pass: compute predicted y using operations on Variables; we compute - # ReLU using our custom autograd operation. 
- y_pred = relu(x.mm(w1)).mm(w2) - + # Forward pass: compute predicted y using operations on Tensors; we call our + # custom ReLU implementation using the MyReLU.apply function + y_pred = MyReLU.apply(x.mm(w1)).mm(w2) + # Compute and print loss loss = (y_pred - y).pow(2).sum() - print(t, loss.data[0]) - - # Manually zero the gradients before running the backward pass - w1.grad.data.zero_() - w2.grad.data.zero_() + print(t, loss.item()) # Use autograd to compute the backward pass. loss.backward() - # Update weights using gradient descent - w1.data -= learning_rate * w1.grad.data - w2.data -= learning_rate * w2.grad.data + with torch.no_grad(): + # Update weights using gradient descent + w1 -= learning_rate * w1.grad + w2 -= learning_rate * w2.grad + + # Manually zero the gradients after running the backward pass + w1.grad.zero_() + w2.grad.zero_() + ``` ## TensorFlow: Static Graphs @@ -421,8 +430,8 @@ raw computational graphs that are useful for building neural networks. In PyTorch, the `nn` package serves this same purpose. The `nn` package defines a set of **Modules**, which are roughly equivalent to neural network layers. A Module receives -input Variables and computes output Variables, but may also hold internal state such as -Variables containing learnable parameters. The `nn` package also defines a set of useful +input Tensors and computes output Tensors, but may also hold internal state such as +Tensors containing learnable parameters. The `nn` package also defines a set of useful loss functions that are commonly used when training neural networks. In this example we use the `nn` package to implement our two-layer network: @@ -430,62 +439,71 @@ In this example we use the `nn` package to implement our two-layer network: ```python # Code in file nn/two_layer_net_nn.py import torch -from torch.autograd import Variable + +device = torch.device('cpu') +# device = torch.device('cuda') # Uncomment this to run on GPU # N is batch size; D_in is input dimension; # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs, and wrap them in Variables. -x = Variable(torch.randn(N, D_in)) -y = Variable(torch.randn(N, D_out), requires_grad=False) +# Create random Tensors to hold inputs and outputs +x = torch.randn(N, D_in, device=device) +y = torch.randn(N, D_out, device=device) # Use the nn package to define our model as a sequence of layers. nn.Sequential # is a Module which contains other Modules, and applies them in sequence to # produce its output. Each Linear Module computes output from input using a -# linear function, and holds internal Variables for its weight and bias. +# linear function, and holds internal Tensors for its weight and bias. +# After constructing the model we use the .to() method to move it to the +# desired device. model = torch.nn.Sequential( torch.nn.Linear(D_in, H), torch.nn.ReLU(), torch.nn.Linear(H, D_out), - ) + ).to(device) # The nn package also contains definitions of popular loss functions; in this -# case we will use Mean Squared Error (MSE) as our loss function. -loss_fn = torch.nn.MSELoss(size_average=False) +# case we will use Mean Squared Error (MSE) as our loss function. Setting +# reduction='sum' means that we are computing the *sum* of squared errors rather +# than the mean; this is for consistency with the examples above where we +# manually compute the loss, but in practice it is more common to use mean +# squared error as a loss by setting reduction='elementwise_mean'. 
+loss_fn = torch.nn.MSELoss(reduction='sum') learning_rate = 1e-4 for t in range(500): # Forward pass: compute predicted y by passing x to the model. Module objects # override the __call__ operator so you can call them like functions. When - # doing so you pass a Variable of input data to the Module and it produces - # a Variable of output data. + # doing so you pass a Tensor of input data to the Module and it produces + # a Tensor of output data. y_pred = model(x) - # Compute and print loss. We pass Variables containing the predicted and true - # values of y, and the loss function returns a Variable containing the loss. + # Compute and print loss. We pass Tensors containing the predicted and true + # values of y, and the loss function returns a Tensor containing the loss. loss = loss_fn(y_pred, y) - print(t, loss.data[0]) + print(t, loss.item()) # Zero the gradients before running the backward pass. model.zero_grad() # Backward pass: compute gradient of the loss with respect to all the learnable # parameters of the model. Internally, the parameters of each Module are stored - # in Variables with requires_grad=True, so this call will compute gradients for + # in Tensors with requires_grad=True, so this call will compute gradients for # all learnable parameters in the model. loss.backward() - # Update the weights using gradient descent. Each parameter is a Variable, so + # Update the weights using gradient descent. Each parameter is a Tensor, so # we can access its data and gradients like we did before. - for param in model.parameters(): - param.data -= learning_rate * param.grad.data + with torch.no_grad(): + for param in model.parameters(): + param.data -= learning_rate * param.grad ``` ## PyTorch: optim -Up to this point we have updated the weights of our models by manually mutating the -`.data` member for Variables holding learnable parameters. This is not a huge burden +Up to this point we have updated the weights of our models by manually mutating +Tensors holding learnable parameters. This is not a huge burden for simple optimization algorithms like stochastic gradient descent, but in practice we often train neural networks using more sophisiticated optimizers like AdaGrad, RMSProp, Adam, etc. @@ -499,15 +517,14 @@ will optimize the model using the Adam algorithm provided by the `optim` package ```python # Code in file nn/two_layer_net_optim.py import torch -from torch.autograd import Variable # N is batch size; D_in is input dimension; # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs, and wrap them in Variables. -x = Variable(torch.randn(N, D_in)) -y = Variable(torch.randn(N, D_out), requires_grad=False) +# Create random Tensors to hold inputs and outputs. +x = torch.randn(N, D_in) +y = torch.randn(N, D_out) # Use the nn package to define our model and loss function. model = torch.nn.Sequential( @@ -515,12 +532,12 @@ model = torch.nn.Sequential( torch.nn.ReLU(), torch.nn.Linear(H, D_out), ) -loss_fn = torch.nn.MSELoss(size_average=False) +loss_fn = torch.nn.MSELoss(reduction='sum') # Use the optim package to define an Optimizer that will update the weights of # the model for us. Here we will use Adam; the optim package contains many other -# optimization algoriths. The first argument to the Adam constructor tells the -# optimizer which Variables it should update. +# optimization algorithms. The first argument to the Adam constructor tells the +# optimizer which Tensors it should update. 
learning_rate = 1e-4 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) for t in range(500): @@ -529,10 +546,10 @@ for t in range(500): # Compute and print loss. loss = loss_fn(y_pred, y) - print(t, loss.data[0]) + print(t, loss.item()) # Before the backward pass, use the optimizer object to zero all of the - # gradients for the variables it will update (which are the learnable weights + # gradients for the Tensors it will update (which are the learnable weights # of the model) optimizer.zero_grad() @@ -547,15 +564,14 @@ for t in range(500): ## PyTorch: Custom nn Modules Sometimes you will want to specify models that are more complex than a sequence of existing Modules; for these cases you can define your own Modules by subclassing -`nn.Module` and defining a `forward` which receives input Variables and produces -output Variables using other modules or other autograd operations on Variables. +`nn.Module` and defining a `forward` which receives input Tensors and produces +output Tensors using other modules or other autograd operations on Tensors. In this example we implement our two-layer network as a custom Module subclass: ```python # Code in file nn/two_layer_net_module.py import torch -from torch.autograd import Variable class TwoLayerNet(torch.nn.Module): def __init__(self, D_in, H, D_out): @@ -569,38 +585,37 @@ class TwoLayerNet(torch.nn.Module): def forward(self, x): """ - In the forward function we accept a Variable of input data and we must return - a Variable of output data. We can use Modules defined in the constructor as - well as arbitrary operators on Variables. + In the forward function we accept a Tensor of input data and we must return + a Tensor of output data. We can use Modules defined in the constructor as + well as arbitrary (differentiable) operations on Tensors. """ h_relu = self.linear1(x).clamp(min=0) y_pred = self.linear2(h_relu) return y_pred - # N is batch size; D_in is input dimension; # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs, and wrap them in Variables -x = Variable(torch.randn(N, D_in)) -y = Variable(torch.randn(N, D_out), requires_grad=False) +# Create random Tensors to hold inputs and outputs +x = torch.randn(N, D_in) +y = torch.randn(N, D_out) -# Construct our model by instantiating the class defined above +# Construct our model by instantiating the class defined above. model = TwoLayerNet(D_in, H, D_out) # Construct our loss function and an Optimizer. The call to model.parameters() # in the SGD constructor will contain the learnable parameters of the two # nn.Linear modules which are members of the model. -criterion = torch.nn.MSELoss(size_average=False) +loss_fn = torch.nn.MSELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) for t in range(500): # Forward pass: Compute predicted y by passing x to the model y_pred = model(x) # Compute and print loss - loss = criterion(y_pred, y) - print(t, loss.data[0]) + loss = loss_fn(y_pred, y) + print(t, loss.item()) # Zero gradients, perform a backward pass, and update the weights. optimizer.zero_grad() @@ -626,7 +641,6 @@ We can easily implement this model as a Module subclass: # Code in file nn/dynamic_net.py import random import torch -from torch.autograd import Variable class DynamicNet(torch.nn.Module): def __init__(self, D_in, H, D_out): @@ -664,16 +678,16 @@ class DynamicNet(torch.nn.Module): # H is hidden dimension; D_out is output dimension. 
 N, D_in, H, D_out = 64, 1000, 100, 10
 
-# Create random Tensors to hold inputs and outputs, and wrap them in Variables
-x = Variable(torch.randn(N, D_in))
-y = Variable(torch.randn(N, D_out), requires_grad=False)
+# Create random Tensors to hold inputs and outputs.
+x = torch.randn(N, D_in)
+y = torch.randn(N, D_out)
 
 # Construct our model by instantiating the class defined above
 model = DynamicNet(D_in, H, D_out)
 
 # Construct our loss function and an Optimizer. Training this strange model with
 # vanilla stochastic gradient descent is tough, so we use momentum
-criterion = torch.nn.MSELoss(size_average=False)
+criterion = torch.nn.MSELoss(reduction='sum')
 optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
 for t in range(500):
   # Forward pass: Compute predicted y by passing x to the model
@@ -681,7 +695,7 @@ for t in range(500):
 
   # Compute and print loss
   loss = criterion(y_pred, y)
-  print(t, loss.data[0])
+  print(t, loss.item())
 
   # Zero gradients, perform a backward pass, and update the weights.
   optimizer.zero_grad()
diff --git a/README_raw.md b/README_raw.md
index b81f5c4..8c1fbed 100644
--- a/README_raw.md
+++ b/README_raw.md
@@ -11,10 +11,15 @@ will have a single hidden layer, and will be trained with gradient descent to
 fit random data by minimizing the Euclidean distance between the network output
 and the true output.
 
+**NOTE:** These examples have been updated for PyTorch 0.4, which made several
+major changes to the core PyTorch API. Most notably, prior to 0.4 Tensors had
+to be wrapped in Variable objects to use autograd; this functionality has now
+been added directly to Tensors, and Variables are now deprecated.
+
 ### Table of Contents
 - Warm-up: numpy
 - PyTorch: Tensors
-- PyTorch: Variables and autograd
+- PyTorch: Autograd
 - PyTorch: Defining new autograd functions
 - TensorFlow: Static Graphs
 - PyTorch: nn
@@ -46,24 +51,24 @@ unfortunately numpy won't be enough for modern deep learning.
 
 Here we introduce the most fundamental PyTorch concept: the **Tensor**. A PyTorch
 Tensor is conceptually identical to a numpy array: a Tensor is an n-dimensional
-array, and PyTorch provides many functions for operating on these Tensors. Like
-numpy arrays, PyTorch Tensors do not know anything about deep learning or
-computational graphs or gradients; they are a generic tool for scientific
+array, and PyTorch provides many functions for operating on these Tensors.
+Any computation you might want to perform with numpy can also be accomplished
+with PyTorch Tensors; you should think of them as a generic tool for scientific
 computing.
 
 However unlike numpy, PyTorch Tensors can utilize GPUs to accelerate their
-numeric computations. To run a PyTorch Tensor on GPU, you simply need to cast it
-to a new datatype.
+numeric computations. To run a PyTorch Tensor on GPU, you use the `device`
+argument when constructing the Tensor to place it on the GPU.
 
 Here we use PyTorch Tensors to fit a two-layer network to random data. Like the
-numpy example above we need to manually implement the forward and backward
-passes through the network:
+numpy example above, we manually implement the forward and backward
+passes through the network, using operations on PyTorch Tensors:
 
 ```python
 :INCLUDE tensor/two_layer_net_tensor.py
 ```
 
-## PyTorch: Variables and autograd
+## PyTorch: Autograd
 
 In the above examples, we had to manually implement both the forward and backward
 passes of our neural network. Manually implementing the backward pass
@@ -79,18 +84,21 @@ When using autograd, the forward pass of your network will define a
 functions that produce output Tensors from input Tensors. Backpropagating through
 this graph then allows you to easily compute gradients.
 
-This sounds complicated, it's pretty simple to use in practice. We wrap our
-PyTorch Tensors in **Variable** objects; a Variable represents a node in a
-computational graph. If `x` is a Variable then `x.data` is a Tensor, and
-`x.grad` is another Variable holding the gradient of `x` with respect to some
-scalar value.
-
-PyTorch Variables have the same API as PyTorch Tensors: (almost) any operation
-that you can perform on a Tensor also works on Variables; the difference is that
-using Variables defines a computational graph, allowing you to automatically
-compute gradients.
-
-Here we use PyTorch Variables and autograd to implement our two-layer network;
+This sounds complicated, but it's pretty simple to use in practice. If we want to
+compute gradients with respect to some Tensor, then we set `requires_grad=True`
+when constructing that Tensor. Any PyTorch operations on that Tensor will cause
+a computational graph to be constructed, allowing us to later perform backpropagation
+through the graph. If `x` is a Tensor with `requires_grad=True`, then after
+backpropagation `x.grad` will be another Tensor holding the gradient of `x` with
+respect to some scalar value.
+
+Sometimes you may wish to prevent PyTorch from building computational graphs when
+performing certain operations on Tensors with `requires_grad=True`; for example
+we usually don't want to backpropagate through the weight update steps when
+training a neural network. In such scenarios we can use the `torch.no_grad()`
+context manager to prevent the construction of a computational graph.
+
+Here we use PyTorch Tensors and autograd to implement our two-layer network;
 now we no longer need to manually implement the backward pass through the
 network:
 
@@ -108,7 +116,7 @@ with respect to that same scalar value.
 In PyTorch we can easily define our own autograd operator by defining a subclass
 of `torch.autograd.Function` and implementing the `forward` and `backward` functions.
 We can then use our new autograd operator by constructing an instance and calling it
-like a function, passing Variables containing input data.
+like a function, passing Tensors containing input data.
 
 In this example we define our own custom autograd function for performing the
 ReLU nonlinearity, and use it to implement our two-layer network:
@@ -168,8 +176,8 @@ raw computational graphs that are useful for building neural networks.
 
 In PyTorch, the `nn` package serves this same purpose. The `nn` package defines a set of
 **Modules**, which are roughly equivalent to neural network layers. A Module receives
-input Variables and computes output Variables, but may also hold internal state such as
-Variables containing learnable parameters. The `nn` package also defines a set of useful
+input Tensors and computes output Tensors, but may also hold internal state such as
+Tensors containing learnable parameters. The `nn` package also defines a set of useful
 loss functions that are commonly used when training neural networks.
In this example we use the `nn` package to implement our two-layer network: @@ -180,8 +188,8 @@ In this example we use the `nn` package to implement our two-layer network: ## PyTorch: optim -Up to this point we have updated the weights of our models by manually mutating the -`.data` member for Variables holding learnable parameters. This is not a huge burden +Up to this point we have updated the weights of our models by manually mutating +Tensors holding learnable parameters. This is not a huge burden for simple optimization algorithms like stochastic gradient descent, but in practice we often train neural networks using more sophisiticated optimizers like AdaGrad, RMSProp, Adam, etc. @@ -200,8 +208,8 @@ will optimize the model using the Adam algorithm provided by the `optim` package ## PyTorch: Custom nn Modules Sometimes you will want to specify models that are more complex than a sequence of existing Modules; for these cases you can define your own Modules by subclassing -`nn.Module` and defining a `forward` which receives input Variables and produces -output Variables using other modules or other autograd operations on Variables. +`nn.Module` and defining a `forward` which receives input Tensors and produces +output Tensors using other modules or other autograd operations on Tensors. In this example we implement our two-layer network as a custom Module subclass: diff --git a/autograd/two_layer_net_autograd.py b/autograd/two_layer_net_autograd.py index ad9f4fa..2a5bb7f 100644 --- a/autograd/two_layer_net_autograd.py +++ b/autograd/two_layer_net_autograd.py @@ -1,68 +1,65 @@ import torch -from torch.autograd import Variable """ A fully-connected ReLU network with one hidden layer and no biases, trained to predict y from x by minimizing squared Euclidean distance. This implementation computes the forward pass using operations on PyTorch -Variables, and uses PyTorch autograd to compute gradients. +Tensors, and uses PyTorch autograd to compute gradients. -A PyTorch Variable is a wrapper around a PyTorch Tensor, and represents a node -in a computational graph. If x is a Variable then x.data is a Tensor giving its -value, and x.grad is another Variable holding the gradient of x with respect to -some scalar value. - -PyTorch Variables have the same API as PyTorch tensors: (almost) any operation -you can do on a Tensor you can also do on a Variable; the difference is that -autograd allows you to automatically compute gradients. +When we create a PyTorch Tensor with requires_grad=True, then operations +involving that Tensor will not just compute values; they will also build up +a computational graph in the background, allowing us to easily backpropagate +through the graph to compute gradients of some downstream (scalar) loss with +respect to a Tensor. Concretely if x is a Tensor with x.requires_grad == True +then after backpropagation x.grad will be another Tensor holding the gradient +of x with respect to some scalar value. """ -dtype = torch.FloatTensor -# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU +device = torch.device('cpu') +# device = torch.device('cuda') # Uncomment this to run on GPU # N is batch size; D_in is input dimension; # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold input and outputs, and wrap them in Variables. -# Setting requires_grad=False indicates that we do not need to compute gradients -# with respect to these Variables during the backward pass. 
-x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False) -y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False) +# Create random Tensors to hold input and outputs +x = torch.randn(N, D_in, device=device) +y = torch.randn(N, D_out, device=device) -# Create random Tensors for weights, and wrap them in Variables. -# Setting requires_grad=True indicates that we want to compute gradients with -# respect to these Variables during the backward pass. -w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True) -w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True) +# Create random Tensors for weights; setting requires_grad=True means that we +# want to compute gradients for these Tensors during the backward pass. +w1 = torch.randn(D_in, H, device=device, requires_grad=True) +w2 = torch.randn(H, D_out, device=device, requires_grad=True) learning_rate = 1e-6 for t in range(500): - # Forward pass: compute predicted y using operations on Variables; these - # are exactly the same operations we used to compute the forward pass using - # Tensors, but we do not need to keep references to intermediate values since - # we are not implementing the backward pass by hand. + # Forward pass: compute predicted y using operations on Tensors. Since w1 and + # w2 have requires_grad=True, operations involving these Tensors will cause + # PyTorch to build a computational graph, allowing automatic computation of + # gradients. Since we are no longer implementing the backward pass by hand we + # don't need to keep references to intermediate values. y_pred = x.mm(w1).clamp(min=0).mm(w2) - # Compute and print loss using operations on Variables. - # Now loss is a Variable of shape (1,) and loss.data is a Tensor of shape - # (1,); loss.data[0] is a scalar value holding the loss. + # Compute and print loss. Loss is a Tensor of shape (), and loss.item() + # is a Python number giving its value. loss = (y_pred - y).pow(2).sum() - print(t, loss.data[0]) - - # Manually zero the gradients before running the backward pass - w1.grad.data.zero_() - w2.grad.data.zero_() + print(t, loss.item()) # Use autograd to compute the backward pass. This call will compute the - # gradient of loss with respect to all Variables with requires_grad=True. - # After this call w1.grad and w2.grad will be Variables holding the gradient + # gradient of loss with respect to all Tensors with requires_grad=True. + # After this call w1.grad and w2.grad will be Tensors holding the gradient # of the loss with respect to w1 and w2 respectively. loss.backward() - # Update weights using gradient descent; w1.data and w2.data are Tensors, - # w1.grad and w2.grad are Variables and w1.grad.data and w2.grad.data are - # Tensors. - w1.data -= learning_rate * w1.grad.data - w2.data -= learning_rate * w2.grad.data + # Update weights using gradient descent. 
For this step we just want to mutate + # the values of w1 and w2 in-place; we don't want to build up a computational + # graph for the update steps, so we use the torch.no_grad() context manager + # to prevent PyTorch from building a computational graph for the updates + with torch.no_grad(): + w1 -= learning_rate * w1.grad + w2 -= learning_rate * w2.grad + + # Manually zero the gradients after running the backward pass + w1.grad.zero_() + w2.grad.zero_() diff --git a/autograd/two_layer_net_custom_function.py b/autograd/two_layer_net_custom_function.py index ef75c09..6c768d1 100644 --- a/autograd/two_layer_net_custom_function.py +++ b/autograd/two_layer_net_custom_function.py @@ -1,12 +1,11 @@ import torch -from torch.autograd import Variable """ A fully-connected ReLU network with one hidden layer and no biases, trained to predict y from x by minimizing squared Euclidean distance. This implementation computes the forward pass using operations on PyTorch -Variables, and uses PyTorch autograd to compute gradients. +Tensors, and uses PyTorch autograd to compute gradients. In this implementation we implement our own custom autograd function to perform the ReLU function. @@ -18,62 +17,65 @@ class MyReLU(torch.autograd.Function): torch.autograd.Function and implementing the forward and backward passes which operate on Tensors. """ - def forward(self, input): + @staticmethod + def forward(ctx, x): """ - In the forward pass we receive a Tensor containing the input and return a - Tensor containing the output. You can cache arbitrary Tensors for use in the - backward pass using the save_for_backward method. + In the forward pass we receive a context object and a Tensor containing the + input; we must return a Tensor containing the output, and we can use the + context object to cache objects for use in the backward pass. """ - self.save_for_backward(input) - return input.clamp(min=0) + ctx.save_for_backward(x) + return x.clamp(min=0) - def backward(self, grad_output): + @staticmethod + def backward(ctx, grad_output): """ - In the backward pass we receive a Tensor containing the gradient of the loss - with respect to the output, and we need to compute the gradient of the loss - with respect to the input. + In the backward pass we receive the context object and a Tensor containing + the gradient of the loss with respect to the output produced during the + forward pass. We can retrieve cached data from the context object, and must + compute and return the gradient of the loss with respect to the input to the + forward function. """ - input, = self.saved_tensors - grad_input = grad_output.clone() - grad_input[input < 0] = 0 - return grad_input + x, = ctx.saved_tensors + grad_x = grad_output.clone() + grad_x[x < 0] = 0 + return grad_x -dtype = torch.FloatTensor -# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU +device = torch.device('cpu') +# device = torch.device('cuda') # Uncomment this to run on GPU # N is batch size; D_in is input dimension; # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold input and outputs, and wrap them in Variables. -x = Variable(torch.randn(N, D_in).type(dtype), requires_grad=False) -y = Variable(torch.randn(N, D_out).type(dtype), requires_grad=False) +# Create random Tensors to hold input and output +x = torch.randn(N, D_in, device=device) +y = torch.randn(N, D_out, device=device) -# Create random Tensors for weights, and wrap them in Variables. 
-w1 = Variable(torch.randn(D_in, H).type(dtype), requires_grad=True) -w2 = Variable(torch.randn(H, D_out).type(dtype), requires_grad=True) +# Create random Tensors for weights. +w1 = torch.randn(D_in, H, device=device, requires_grad=True) +w2 = torch.randn(H, D_out, device=device, requires_grad=True) learning_rate = 1e-6 for t in range(500): - # Construct an instance of our MyReLU class to use in our network - relu = MyReLU() - - # Forward pass: compute predicted y using operations on Variables; we compute - # ReLU using our custom autograd operation. - y_pred = relu(x.mm(w1)).mm(w2) - + # Forward pass: compute predicted y using operations on Tensors; we call our + # custom ReLU implementation using the MyReLU.apply function + y_pred = MyReLU.apply(x.mm(w1)).mm(w2) + # Compute and print loss loss = (y_pred - y).pow(2).sum() - print(t, loss.data[0]) - - # Manually zero the gradients before running the backward pass - w1.grad.data.zero_() - w2.grad.data.zero_() + print(t, loss.item()) # Use autograd to compute the backward pass. loss.backward() - # Update weights using gradient descent - w1.data -= learning_rate * w1.grad.data - w2.data -= learning_rate * w2.grad.data + with torch.no_grad(): + # Update weights using gradient descent + w1 -= learning_rate * w1.grad + w2 -= learning_rate * w2.grad + + # Manually zero the gradients after running the backward pass + w1.grad.zero_() + w2.grad.zero_() + diff --git a/nn/dynamic_net.py b/nn/dynamic_net.py index a2e1714..ce4b4a0 100644 --- a/nn/dynamic_net.py +++ b/nn/dynamic_net.py @@ -1,6 +1,5 @@ import random import torch -from torch.autograd import Variable """ To showcase the power of PyTorch dynamic graphs, we will implement a very strange @@ -45,16 +44,16 @@ def forward(self, x): # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs, and wrap them in Variables -x = Variable(torch.randn(N, D_in)) -y = Variable(torch.randn(N, D_out), requires_grad=False) +# Create random Tensors to hold inputs and outputs. +x = torch.randn(N, D_in) +y = torch.randn(N, D_out) # Construct our model by instantiating the class defined above model = DynamicNet(D_in, H, D_out) # Construct our loss function and an Optimizer. Training this strange model with # vanilla stochastic gradient descent is tough, so we use momentum -criterion = torch.nn.MSELoss(size_average=False) +criterion = torch.nn.MSELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9) for t in range(500): # Forward pass: Compute predicted y by passing x to the model @@ -62,7 +61,7 @@ def forward(self, x): # Compute and print loss loss = criterion(y_pred, y) - print(t, loss.data[0]) + print(t, loss.item()) # Zero gradients, perform a backward pass, and update the weights. optimizer.zero_grad() diff --git a/nn/two_layer_net_module.py b/nn/two_layer_net_module.py index feb6075..e86127e 100644 --- a/nn/two_layer_net_module.py +++ b/nn/two_layer_net_module.py @@ -1,5 +1,4 @@ import torch -from torch.autograd import Variable """ A fully-connected ReLU network with one hidden layer, trained to predict y from x @@ -22,38 +21,37 @@ def __init__(self, D_in, H, D_out): def forward(self, x): """ - In the forward function we accept a Variable of input data and we must return - a Variable of output data. We can use Modules defined in the constructor as - well as arbitrary operators on Variables. 
+ In the forward function we accept a Tensor of input data and we must return + a Tensor of output data. We can use Modules defined in the constructor as + well as arbitrary (differentiable) operations on Tensors. """ h_relu = self.linear1(x).clamp(min=0) y_pred = self.linear2(h_relu) return y_pred - # N is batch size; D_in is input dimension; # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs, and wrap them in Variables -x = Variable(torch.randn(N, D_in)) -y = Variable(torch.randn(N, D_out), requires_grad=False) +# Create random Tensors to hold inputs and outputs +x = torch.randn(N, D_in) +y = torch.randn(N, D_out) -# Construct our model by instantiating the class defined above +# Construct our model by instantiating the class defined above. model = TwoLayerNet(D_in, H, D_out) # Construct our loss function and an Optimizer. The call to model.parameters() # in the SGD constructor will contain the learnable parameters of the two # nn.Linear modules which are members of the model. -criterion = torch.nn.MSELoss(size_average=False) +loss_fn = torch.nn.MSELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(), lr=1e-4) for t in range(500): # Forward pass: Compute predicted y by passing x to the model y_pred = model(x) # Compute and print loss - loss = criterion(y_pred, y) - print(t, loss.data[0]) + loss = loss_fn(y_pred, y) + print(t, loss.item()) # Zero gradients, perform a backward pass, and update the weights. optimizer.zero_grad() diff --git a/nn/two_layer_net_nn.py b/nn/two_layer_net_nn.py index f75fa40..ec4f897 100644 --- a/nn/two_layer_net_nn.py +++ b/nn/two_layer_net_nn.py @@ -1,5 +1,4 @@ import torch -from torch.autograd import Variable """ A fully-connected ReLU network with one hidden layer, trained to predict y from x @@ -10,54 +9,64 @@ but raw autograd can be a bit too low-level for defining complex neural networks; this is where the nn package can help. The nn package defines a set of Modules, which you can think of as a neural network layer that has produces output from -input and may have some trainable weights. +input and may have some trainable weights or other state. """ +device = torch.device('cpu') +# device = torch.device('cuda') # Uncomment this to run on GPU + # N is batch size; D_in is input dimension; # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs, and wrap them in Variables. -x = Variable(torch.randn(N, D_in)) -y = Variable(torch.randn(N, D_out), requires_grad=False) +# Create random Tensors to hold inputs and outputs +x = torch.randn(N, D_in, device=device) +y = torch.randn(N, D_out, device=device) # Use the nn package to define our model as a sequence of layers. nn.Sequential # is a Module which contains other Modules, and applies them in sequence to # produce its output. Each Linear Module computes output from input using a -# linear function, and holds internal Variables for its weight and bias. +# linear function, and holds internal Tensors for its weight and bias. +# After constructing the model we use the .to() method to move it to the +# desired device. model = torch.nn.Sequential( torch.nn.Linear(D_in, H), torch.nn.ReLU(), torch.nn.Linear(H, D_out), - ) + ).to(device) # The nn package also contains definitions of popular loss functions; in this -# case we will use Mean Squared Error (MSE) as our loss function. 
-loss_fn = torch.nn.MSELoss(size_average=False) +# case we will use Mean Squared Error (MSE) as our loss function. Setting +# reduction='sum' means that we are computing the *sum* of squared errors rather +# than the mean; this is for consistency with the examples above where we +# manually compute the loss, but in practice it is more common to use mean +# squared error as a loss by setting reduction='elementwise_mean'. +loss_fn = torch.nn.MSELoss(reduction='sum') learning_rate = 1e-4 for t in range(500): # Forward pass: compute predicted y by passing x to the model. Module objects # override the __call__ operator so you can call them like functions. When - # doing so you pass a Variable of input data to the Module and it produces - # a Variable of output data. + # doing so you pass a Tensor of input data to the Module and it produces + # a Tensor of output data. y_pred = model(x) - # Compute and print loss. We pass Variables containing the predicted and true - # values of y, and the loss function returns a Variable containing the loss. + # Compute and print loss. We pass Tensors containing the predicted and true + # values of y, and the loss function returns a Tensor containing the loss. loss = loss_fn(y_pred, y) - print(t, loss.data[0]) + print(t, loss.item()) # Zero the gradients before running the backward pass. model.zero_grad() # Backward pass: compute gradient of the loss with respect to all the learnable # parameters of the model. Internally, the parameters of each Module are stored - # in Variables with requires_grad=True, so this call will compute gradients for + # in Tensors with requires_grad=True, so this call will compute gradients for # all learnable parameters in the model. loss.backward() - # Update the weights using gradient descent. Each parameter is a Variable, so + # Update the weights using gradient descent. Each parameter is a Tensor, so # we can access its data and gradients like we did before. - for param in model.parameters(): - param.data -= learning_rate * param.grad.data + with torch.no_grad(): + for param in model.parameters(): + param.data -= learning_rate * param.grad diff --git a/nn/two_layer_net_optim.py b/nn/two_layer_net_optim.py index fce735b..84a7f2e 100644 --- a/nn/two_layer_net_optim.py +++ b/nn/two_layer_net_optim.py @@ -1,5 +1,4 @@ import torch -from torch.autograd import Variable """ A fully-connected ReLU network with one hidden layer, trained to predict y from x @@ -17,9 +16,9 @@ # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 -# Create random Tensors to hold inputs and outputs, and wrap them in Variables. -x = Variable(torch.randn(N, D_in)) -y = Variable(torch.randn(N, D_out), requires_grad=False) +# Create random Tensors to hold inputs and outputs. +x = torch.randn(N, D_in) +y = torch.randn(N, D_out) # Use the nn package to define our model and loss function. model = torch.nn.Sequential( @@ -27,12 +26,12 @@ torch.nn.ReLU(), torch.nn.Linear(H, D_out), ) -loss_fn = torch.nn.MSELoss(size_average=False) +loss_fn = torch.nn.MSELoss(reduction='sum') # Use the optim package to define an Optimizer that will update the weights of # the model for us. Here we will use Adam; the optim package contains many other # optimization algoriths. The first argument to the Adam constructor tells the -# optimizer which Variables it should update. +# optimizer which Tensors it should update. 
learning_rate = 1e-4 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) for t in range(500): @@ -41,10 +40,10 @@ # Compute and print loss. loss = loss_fn(y_pred, y) - print(t, loss.data[0]) + print(t, loss.item()) # Before the backward pass, use the optimizer object to zero all of the - # gradients for the variables it will update (which are the learnable weights + # gradients for the Tensors it will update (which are the learnable weights # of the model) optimizer.zero_grad() diff --git a/tensor/two_layer_net_tensor.py b/tensor/two_layer_net_tensor.py index 0624ebd..72d27b2 100644 --- a/tensor/two_layer_net_tensor.py +++ b/tensor/two_layer_net_tensor.py @@ -13,23 +13,24 @@ The biggest difference between a numpy array and a PyTorch Tensor is that a PyTorch Tensor can run on either CPU or GPU. To run operations on the GPU, -just cast the Tensor to a cuda datatype. +just pass a different value to the `device` argument when constructing the +Tensor. """ -dtype = torch.FloatTensor -# dtype = torch.cuda.FloatTensor # Uncomment this to run on GPU +device = torch.device('cpu') +# device = torch.device('cuda') # Uncomment this to run on GPU # N is batch size; D_in is input dimension; # H is hidden dimension; D_out is output dimension. N, D_in, H, D_out = 64, 1000, 100, 10 # Create random input and output data -x = torch.randn(N, D_in).type(dtype) -y = torch.randn(N, D_out).type(dtype) +x = torch.randn(N, D_in, device=device) +y = torch.randn(N, D_out, device=device) # Randomly initialize weights -w1 = torch.randn(D_in, H).type(dtype) -w2 = torch.randn(H, D_out).type(dtype) +w1 = torch.randn(D_in, H, device=device) +w2 = torch.randn(H, D_out, device=device) learning_rate = 1e-6 for t in range(500): @@ -38,9 +39,10 @@ h_relu = h.clamp(min=0) y_pred = h_relu.mm(w2) - # Compute and print loss + # Compute and print loss; loss is a scalar, and is stored in a PyTorch Tensor + # of shape (); we can get its value as a Python number with loss.item(). loss = (y_pred - y).pow(2).sum() - print(t, loss) + print(t, loss.item()) # Backprop to compute gradients of w1 and w2 with respect to loss grad_y_pred = 2.0 * (y_pred - y)