For the remainder of this course, we will use the PyTorch library to implement virtually all of our models. The library is chosen for its automatic differentiation capabilities, which are essential for implementing neural network architectures. It also contains the `nn.Module` class, a convenient abstraction for building self-contained layers that can easily enumerate their parameters.
Automatic differentiation is a key feature of PyTorch, facilitated by the use of tensors. Tensors are the fundamental data structures in PyTorch, capable of representing multidimensional arrays. They can be created using functions such as `torch.zeros()`, which creates a tensor filled with zeros, or `torch.randn()`, which creates a tensor with entries drawn from a standard normal distribution. The dimensions of the tensor are specified by passing the sizes to these functions. For example, `torch.randn(3, 4)` creates a 3x4 matrix with entries drawn from a standard normal distribution.
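To make this concrete, here is a minimal sketch of these tensor-creation calls (the shapes are arbitrary choices for illustration):

```python
import torch

Z = torch.zeros(3, 4)    # 3x4 tensor filled with zeros
R = torch.randn(3, 4)    # 3x4 tensor with standard normal entries
print(Z.shape, R.shape)  # torch.Size([3, 4]) torch.Size([3, 4])
```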
To enable gradient accumulation, a tensor's `requires_grad` attribute must be set to `True`. This can be done by setting `requires_grad=True` when creating a tensor or by calling the `requires_grad_()` method on an existing tensor. In PyTorch, methods with a trailing underscore, such as `requires_grad_()`, modify the tensor in place.
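Both forms are illustrated in the short sketch below (the shapes are arbitrary):

```python
import torch

A = torch.randn(5, 4, requires_grad=True)  # set at creation time
B = torch.randn(5, 4)
B.requires_grad_()                         # set in place on an existing tensor
print(A.requires_grad, B.requires_grad)    # True True
```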
When a tensor is marked to require gradients, it is flagged to accumulate gradients through operations. For example, if tensor `A` requires gradients and tensor `W` does not, and they are multiplied element-wise and then summed to produce a scalar `f`, the gradient of `f` with respect to `A` can be computed. This is done by calling the `backward()` method on `f`, which populates the `.grad` attribute of all tensors in the computational graph that have `requires_grad` set to `True`.
The gradient of `A` after calling `f.backward()` will be equal to `W`, assuming `A` and `W` were multiplied element-wise and then summed. This is because the gradient of the sum with respect to each element of `A` is simply the corresponding element of `W`.
The following code illustrates this usage of automatic differentiation within PyTorch.
```python
A = torch.randn(5, 4).requires_grad_()
W = torch.randn(5, 4)
f = (A * W).sum()
f.backward()
print(A.grad - W)  # all zeros, since A.grad equals W
```
In PyTorch, functions are implemented using a class called `Function` from the `torch.autograd` library. This class allows users to define custom operations with both forward and backward passes. The forward pass computes the operation, while the backward pass computes the gradients through the vector-Jacobian product.
To create a custom function, one must define a class that inherits from `Function` and implement two static methods: `forward` and `backward`. The `forward` method computes the output of the function given the inputs, and the `backward` method computes the gradient of the function with respect to its inputs, given the gradient of the output.
This mechanism is essential for building complex neural networks with custom operations, as it seamlessly integrates with PyTorch’s automatic differentiation system. However, most common functions already have their forward and backward passes implemented in PyTorch, so users rarely need to define their own unless they require a specialized or more efficient version of an operation.
The following code implements a basic matrix multiplication function in PyTorch.
```python
from torch.autograd import Function

class MatMul(Function):
    @staticmethod
    def forward(context, X, Y):
        # X: m x n, Y: n x p
        context.save_for_backward(X, Y)
        return X @ Y

    @staticmethod
    def backward(context, grad):
        # grad: m x p
        X, Y = context.saved_tensors
        return grad @ Y.T, X.T @ grad
```
The forward function takes two arguments, `X` and `Y`, along with a `context` variable. The context is used to store additional information required for the computation graph. Specifically, the `context.save_for_backward` method is used to save the inputs for later use in the backward pass.
The backward function also takes the context and the incoming backward gradient (`grad`) as arguments. It retrieves the saved tensors `X` and `Y` from the context and computes the vector-Jacobian products with respect to `X` and `Y`, as derived in the previous set of notes.
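For reference, writing $F = XY$ and letting $\bar{F}$ denote the incoming gradient `grad` with respect to the scalar loss $\ell$ (this bar shorthand is introduced here only for brevity), the two quantities returned by `backward` correspond to

$$
\frac{\partial \ell}{\partial X} = \bar{F}\, Y^\top,
\qquad
\frac{\partial \ell}{\partial Y} = X^\top \bar{F}.
$$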
Gradient checking is a crucial step in verifying the correctness of the gradients computed by the backward function. PyTorch provides a utility function called `gradcheck` for this purpose. It is recommended to use double-precision inputs when using `gradcheck`, to avoid errors due to finite differencing in single precision.
The `gradcheck` function can be used as follows:
```python
from torch.autograd import gradcheck

A = torch.randn(5, 4, dtype=torch.double).requires_grad_()
B = torch.randn(4, 3, dtype=torch.double).requires_grad_()
gradcheck(MatMul.apply, (A, B))
```
If the gradients are correct, `gradcheck` will return `True`. Otherwise, it will raise an error indicating that the gradients are incorrect.
In practice, PyTorch has built-in functions for common operations, and the need to define custom forward and backward functions arises only when dealing with new functions not yet implemented in PyTorch. When implementing new functions, it is essential to ensure that the forward and backward passes are correctly defined and that the gradients are accurate. The `gradcheck` utility is invaluable in this process, providing a means to numerically verify the correctness of the gradients.
PyTorch modules are a key concept for encapsulating computations and parameters. They are particularly useful for defining neural network layers with learnable parameters. The `Module` class from `torch.nn` is the base class for all neural network modules, and it is common to subclass it to create custom layers or models.
A linear layer, also known as a fully connected layer, is one of the simplest and most commonly used types of layers in neural networks. It applies a linear transformation to the incoming data. In PyTorch, this can be implemented by subclassing the `Module` class.
The linear layer is defined by two main parameters: the input dimension (`in_dim`) and the output dimension (`out_dim`). Additionally, a bias term can be included; the `bias` argument defaults to `True`. The weights of the layer are represented by a matrix, and the bias, if used, is a vector.
The initialization of the weights is an important step that can significantly affect the performance of the neural network. A common practice is to initialize the weights as random normal variables scaled by the square root of an initialization factor divided by the input dimension (i.e., multiplied by `sqrt(init_factor / in_dim)`). This factor is often set to `2.0` for networks using ReLU activation functions.
Here is an example of how to define a custom linear layer in PyTorch:
```python
import numpy as np
from torch.nn import Module, Parameter

class Linear(Module):
    def __init__(self, in_dim, out_dim, bias=True, init_factor=2.0):
        super().__init__()
        self.weight = Parameter(torch.randn(in_dim, out_dim) * np.sqrt(init_factor / in_dim))
        if bias:
            self.bias = Parameter(torch.zeros(out_dim))
        else:
            self.bias = None

    def forward(self, X):
        out = X @ self.weight
        if self.bias is not None:
            out += self.bias[..., :]
        return out
```
In this implementation, the `Parameter` class is used to wrap the tensors that should be considered parameters of the layer. This allows PyTorch to track gradients for these tensors during the training process.
The `forward` method applies the linear transformation to the input `X` using the weight matrix and adds the bias if it is present.
By defining the forward pass and initializing the parameters, the backward pass (gradient computation) is automatically handled by PyTorch’s autograd system, provided that the operations used in the forward pass are differentiable and supported by autograd.
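For instance, the following short sketch (with arbitrarily chosen dimensions, reusing the `Linear` class above) confirms that gradients for the layer's parameters are populated automatically:

```python
layer = Linear(4, 3)
X = torch.randn(5, 4)
loss = layer(X).sum()
loss.backward()
# Autograd has populated the .grad attributes of the layer's parameters
print(layer.weight.grad.shape, layer.bias.grad.shape)  # torch.Size([4, 3]) torch.Size([3])
```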
The ReLU (Rectified Linear Unit) layer is implemented as a module without any parameters, and thus does not require an initialization method. The forward pass takes an input `X` and returns the element-wise maximum between `X` and a zero tensor, effectively applying the ReLU activation function.
```python
class ReLU(Module):
    def forward(self, X):
        return torch.maximum(X, torch.tensor(0.))
```
The cross-entropy loss is implemented as a module with a forward pass that takes predictions `H` and target labels `Y`. The loss is computed by indexing into the predictions `H` using the target labels `Y` and negating those entries, then adding the log-sum-exp of the predictions. The mean of the loss over the batch is returned.
```python
class CrossEntropyLoss(Module):
    def forward(self, H, Y):
        return -H[torch.arange(len(Y)), Y].mean() + torch.logsumexp(H, -1).mean()
```
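For reference, this corresponds to the standard softmax cross-entropy loss over a batch of $B$ examples,

$$
\ell(H, Y) = \frac{1}{B}\sum_{i=1}^{B}\Big(-H_{i, y_i} + \log\sum_{j}\exp(H_{i,j})\Big),
$$

which matches the two terms in the return statement above.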
Finally, we can now define a two-layer neural network as a `Module` with an initialization method that sets up two linear layers (`self.linear1` and `self.linear2`) and a ReLU activation (`self.relu`). The first linear layer maps the input dimension to the hidden dimension, while the second linear layer maps the hidden dimension to the output dimension. The ReLU activation does not require any parameters.
The forward pass of the two-layer neural network takes an input `X` and applies the first linear layer, followed by the ReLU activation, and then the second linear layer. The output of the second linear layer is returned as the final output of the network.
```python
class TwoLayerNN(Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.linear1 = Linear(in_dim, hidden_dim)
        self.linear2 = Linear(hidden_dim, out_dim, init_factor=1.0)
        self.relu = ReLU()

    def forward(self, X):
        return self.linear2(self.relu(self.linear1(X)))
```
There are a few new items to highlight in this implementation:

- The `Parameter` utility class is used to wrap tensors that should be considered parameters of a layer, automatically setting `requires_grad` to `True`.
- The `Module` class from PyTorch's `torch.nn` is used as the base class for defining custom layers and networks.
- The `parameters()` method of a module returns a generator that can be converted to a list containing all the parameters of the module, which are tensors with `requires_grad` set to `True` (see the short sketch after this list).
- We avoid most of the `nn` library functions, except for `Module` and a few others, in order to demonstrate the underlying mechanics.
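As a quick illustration of the `parameters()` method (a sketch, assuming the `TwoLayerNN` class defined above):

```python
model = TwoLayerNN(784, 100, 10)
params = list(model.parameters())
# Four parameter tensors in total: two weight matrices and two bias vectors,
# all with requires_grad set to True.
print(len(params))                           # 4
print([p.shape for p in params])
print(all(p.requires_grad for p in params))  # True
```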
While PyTorch has a number of built-in optimizers, to make this element more explicit we can define our own class to perform SGD as follows, which roughly mirrors the structure of the optimizers within PyTorch:
```python
class SGD:
    def __init__(self, params, lr=1.0):
        self.params = list(params)
        self.lr = lr

    def step(self):
        with torch.no_grad():
            for param in self.params:
                param -= self.lr * param.grad

    def zero_grad(self):
        with torch.no_grad():
            for param in self.params:
                if param.grad is not None:
                    param.grad.zero_()
```
The optimizer takes an iterable of parameters to be updated (and converts it to a list, to handle generator objects). The `step()` method is wrapped in a `with torch.no_grad()` block to ensure that the parameter updates do not track gradients, which would unnecessarily increase memory usage and computation. This is crucial for efficient training, as it prevents the entire history of parameter updates from being stored in the computation graph. The `zero_grad()` method sets all the gradients of the parameters updated by the optimizer to zero, which is important to call because PyTorch accumulates gradients each time you call `.backward()`.
We now have all the elements we need to implement an MNIST classifier in PyTorch. We will use the same method for iterating over data as before, though later we will use PyTorch wrappers for this. A function that performs a single epoch of training can be written as follows:
```python
def epoch(model, X_full, Y_full, opt=None, batch_size=100):
    mean_err, mean_loss, batches = 0., 0., 0
    for X, Y in zip(X_full.split(batch_size), Y_full.split(batch_size)):
        H = model(X)
        loss = CrossEntropyLoss()(H, Y)
        mean_err += loss_01(H, Y).item()  # loss_01: the 0-1 classification error from earlier notes
        mean_loss += loss.item()
        batches += 1
        if opt:
            opt.zero_grad()
            loss.backward()
            opt.step()
    return mean_err / batches, mean_loss / batches
```
The `epoch` function is defined to take several parameters: a model, full datasets `X_full` and `Y_full`, and an optimizer. The function is designed to be flexible, allowing for the possibility of not updating parameters if no optimizer is passed, by leaving the optimizer set to `None`. This feature is particularly useful for computing test error without training the network.
The training process begins by initializing error and loss accumulators to zero. The function then iterates over the entire dataset in batches, applying the model to each batch to generate hypotheses. For each batch, the cross-entropy loss is computed between the model’s output and the true labels. Additionally, the 0-1 loss, which represents the mean error, is also computed and accumulated.
If an optimizer is provided, the gradients are zeroed out before backpropagation with `loss.backward()`. After backpropagation, the optimizer's `step` function is called to update the model's parameters. This is a common usage pattern in PyTorch for training models.
After processing all batches, the function returns the accumulated error and loss, both divided by the number of batches, giving the average performance over the dataset. Note that the last batch may be smaller than the others, so its examples are weighted slightly more heavily in these averages; this detail is not addressed in the current implementation.
We can use this code to train a two-layer neural network in PyTorch. The network is trained on a training dataset (`X_train` and `Y_train`) for a specified number of epochs. After training, the network is evaluated on a test dataset (`X_test` and `Y_test`) without an optimizer to compute the test error and loss.
```python
from torchvision.datasets import MNIST

dataset = MNIST('.', train=True, download=True)
X_train = dataset.data.reshape(60000, 784) / 255.
Y_train = dataset.targets

dataset_test = MNIST('.', train=False, download=True)
X_test = dataset_test.data.reshape(10000, 784) / 255.
Y_test = dataset_test.targets

model = TwoLayerNN(784, 100, 10)
opt = SGD(model.parameters(), lr=0.5)
for i in range(20):
    epoch(model, X_train, Y_train, opt)
```
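Evaluating on the test set, as described above, then amounts to calling `epoch` without an optimizer; a minimal sketch:

```python
# No optimizer is passed, so no parameter updates are performed.
test_err, test_loss = epoch(model, X_test, Y_test)
print(test_err, test_loss)
```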
If we don't want to define a new subclass for every new network architecture, a convenient tool is the `Sequential` module, which simplifies the process of feeding one layer's output as the input to the next. We can implement it ourselves as follows:
```python
from torch.nn import ModuleList

class Sequential(Module):
    def __init__(self, *layers):
        super().__init__()
        self.layers = ModuleList(layers)

    def forward(self, X):
        out = X
        for layer in self.layers:
            out = layer(out)
        return out
```
In this definition, `*layers` allows an arbitrary number of layers to be passed as positional arguments, which are collected into a tuple. This is a common Python feature (the counterpart of argument unpacking), used here to pass multiple layer instances to the sequential module. The forward pass of a sequential module is implemented by iterating over each layer and applying it to the input, so the output of one layer becomes the input to the next. In this code snippet, `X` is the input to the first layer, and `out` is the output of the last layer, which is returned as the final output of the sequential module.
```python
model = Sequential(
    Linear(784, 100),
    ReLU(),
    Linear(100, 10)
)
```
There is a subtlety when using the sequential module as defined above. To ensure that the layers are properly registered as submodules, so that their parameters are recognized by PyTorch's infrastructure, we have to use the `ModuleList` class; the elements of `self.layers` are then recognized as modules (and will thus have their parameters propagated to the parameters of the `Sequential` object).
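As a quick check (a sketch reusing the `model` defined just above), the parameters of all layers are indeed visible to the `Sequential` object:

```python
# Four parameter tensors (two weights, two biases), registered via ModuleList
print(len(list(model.parameters())))  # 4
```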