In this tutorial, we will learn how to build Residual Neural Networks (ResNets) from scratch using PyTorch. We will create a generalized pipeline that can build all the ResNets from the paper. These include ResNet18, ResNet34, ResNet50, ResNet101, and ResNet152.
In the last two tutorials, we covered building ResNet18 from scratch and training it on the CIFAR10 dataset.
You can find the two posts here.
While we were trying to simplify creating ResNet18 from scratch using PyTorch, we did not include the code for building the other architectures from the ResNet family. In this tutorial, we will also extend the same code and build the other ResNets.
We will cover the following points in this tutorial:
- Discuss how the architectures of ResNets change from ResNet18 to ResNet152
- Discuss what kind of code changes we need.
- Write the code for building ResNets from scratch using PyTorch.
- Verify the architectures by forward passing a dummy tensor and checking the number of parameters.
If you are new to ResNets, going through the post that covers building ResNet18 from scratch will be helpful. There is also a post explaining the ResNet paper, which will help you understand the ResNet architectures.
Changes in ResNets from ResNet18 to ResNet152
When we move from ResNet18 to ResNet152, each model has its own attributes. The following are a few points to keep in mind:
- The number of Basic Blocks, which determines the number of convolutional layers.
- The expansion factor, which affects the output channels of the convolutional layers.
- The total number of layers in the entire network.
The easiest way to check out the changes is to take a look at the following table from the paper.
The above image clarifies a lot of things.
For ResNet18 and ResNet34, the Basic Block consists of two stacked 3×3 convolutional layers. Nothing else changes between these two ResNets; the only difference is the number of Basic Blocks: 8 for ResNet18 and 16 for ResNet34.
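As a quick sanity check on how the block counts add up to each model's name, here is a small sketch. It assumes the per-stage block counts from the paper's table (the same lists that the ResNet class uses later) and counts only the weighted layers on the main path: two 3×3 convolutions per Basic Block, three convolutions per Bottleneck block, plus the 7×7 stem convolution and the final fully connected layer. The 1×1 projection shortcuts are not counted, following the paper's convention. BLOCKS and depth are illustrative names, not part of resnet.py.

```python
# Per-stage block counts (conv2_x to conv5_x), from the paper's table.
BLOCKS = {
    18: [2, 2, 2, 2],
    34: [3, 4, 6, 3],
    50: [3, 4, 6, 3],
    101: [3, 4, 23, 3],
    152: [3, 8, 36, 3],
}


def depth(num_layers: int) -> int:
    # Basic Blocks hold two 3x3 convolutions; Bottleneck blocks hold three
    # (1x1, 3x3, 1x1). The "+ 2" counts the 7x7 stem conv and the final fc layer.
    convs_per_block = 2 if num_layers <= 34 else 3
    return sum(BLOCKS[num_layers]) * convs_per_block + 2


for n, blocks in BLOCKS.items():
    print(f"ResNet{n}: {sum(blocks)} blocks -> depth {depth(n)}")
# ResNet18: 8 blocks -> depth 18
# ResNet34: 16 blocks -> depth 34
# ResNet50: 16 blocks -> depth 50
# ResNet101: 33 blocks -> depth 101
# ResNet152: 50 blocks -> depth 152
```

ResNet34 and ResNet50 share the same block counts; the extra 16 layers in ResNet50 come entirely from the third convolution in every Bottleneck block.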
On the other hand, for ResNet50, ResNet101, and ResNet152, the structure of the Basic Block changes. Instead of plain stacks of 3×3 convolutions, they use a Bottleneck structure, as shown in the following figure.
For ResNets 50/101/152, the Basic Block consists of 1×1=>3×3=>1×1 convolutions. The paper adopts this Bottleneck design to keep the training cost of the deeper networks affordable: it keeps the number of parameters and the amount of computation per block low, even though each block outputs four times as many channels as its 3×3 layer works with.
The 1×1 convolutions in the Bottleneck layers help in reducing and then restoring the dimensions. The 3×3 convolutional layer acts as the Bottleneck.
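To put numbers on the saving, the sketch below counts only the convolution weights of one block (BatchNorm parameters are ignored, and all convolutions are bias-free as in our code). The channel widths are illustrative: a Bottleneck block that takes a 256-channel input and squeezes it to 64 channels, compared with a Basic Block at 64 channels and at 256 channels.

```python
def conv_weights(c_in: int, c_out: int, k: int) -> int:
    # A bias-free Conv2d stores c_out * c_in * k * k weights.
    return c_in * c_out * k * k


# Basic Block: two stacked 3x3 convolutions at a constant width.
basic_64 = 2 * conv_weights(64, 64, 3)
basic_256 = 2 * conv_weights(256, 256, 3)

# Bottleneck on a 256-channel input: 1x1 reduces to 64 channels,
# the 3x3 works on 64 channels, and 1x1 restores 256 channels.
bottleneck_256 = (
    conv_weights(256, 64, 1)
    + conv_weights(64, 64, 3)
    + conv_weights(64, 256, 1)
)

print(f"basic block, 64 channels:  {basic_64:,}")        # 73,728
print(f"basic block, 256 channels: {basic_256:,}")       # 1,179,648
print(f"bottleneck, 256 channels:  {bottleneck_256:,}")  # 69,632
```

The Bottleneck block handles a 256-channel input with slightly fewer weights than a Basic Block at 64 channels, and roughly 17 times fewer than a Basic Block working directly at 256 channels.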
It is important to keep in mind that 1×1 and 3×3 refer to the kernel size in the convolutional layers.
Code Changes that We Need to Make
When writing ResNet18 from scratch, we already wrote the code for the Basic Block module.
If you go through the official PyTorch repository, you will observe that it has a BasicBlock class for ResNets 18 and 34 and a Bottleneck class for ResNets 50/101/152.
We will not create two different classes/modules to extend our code to the other ResNets. Instead, we will simply modify the BasicBlock class that we wrote previously. As we are not trying to include other architectures like Wide ResNets and ResNeXt models, we do not need a very complex code base.
So, all the convolutional layer stackings needed for the ResNet architectures will be contained in one class only. This also makes the code easier to get started with and to understand.
Of course, we need to take care of the changes in layer shapes and outputs, which means there will be a few if-else blocks. Still, the code stays pretty simple.
One other important point is that the rest of the code remains almost exactly the same. The major changes take place in the BasicBlock class only.
With that, let’s get into the coding section of the article.
Directory Structure
The following is the directory structure for this post.
```
.
└── resnet.py
```
We have just one file in this post’s project directory.
ResNets from Scratch using PyTorch
Let’s start writing the code to build ResNets in a generalized manner. This code will build ResNet18/34/50/101/152 depending on the number of layers that we pass from the command line.
All the code here will go into the resnet.py file.
Starting with the imports and argument parser.
```python
import torch.nn as nn
import torch
import argparse

from torch import Tensor
from typing import Type

parser = argparse.ArgumentParser()
parser.add_argument(
    '-n', '--num-layers', dest='num_layers', default=18,
    type=int,
    help='number of layers to build ResNet with',
    choices=[18, 34, 50, 101, 152]
)
args = vars(parser.parse_args())
```
The argument parser defines one command line flag, --num-layers. We pass the number of layers from the command line when executing the script. The flag accepts one of the values 18, 34, 50, 101, or 152 and builds the corresponding ResNet network.
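To see how this flag behaves without running the whole script, the sketch below wraps the same add_argument call in a hypothetical build_parser helper (not part of resnet.py) and feeds it argument lists directly. The default, the type conversion, and the choices check all come from argparse itself.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper: the same --num-layers flag as resnet.py,
    # wrapped in a function so it can be exercised without a real CLI.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-n', '--num-layers', dest='num_layers', default=18, type=int,
        help='number of layers to build ResNet with',
        choices=[18, 34, 50, 101, 152]
    )
    return parser


parser = build_parser()
print(vars(parser.parse_args([])))            # {'num_layers': 18}
print(vars(parser.parse_args(['-n', '50'])))  # {'num_layers': 50}

# A value outside `choices` is rejected: argparse prints an error
# message to stderr and raises SystemExit with exit code 2.
try:
    parser.parse_args(['--num-layers', '20'])
except SystemExit as err:
    print('rejected, exit code', err.code)    # rejected, exit code 2
```

Because type=int runs before the choices check, the string '50' is converted to the integer 50 first, which is why args['num_layers'] can be compared against integers later in the script.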
The BasicBlock for ResNets
This is an important part of the entire codebase. Building any of the ResNet models will make use of the BasicBlock class. This defines all the convolutional layers that we see in figure 1.
Let’s write the code first, then we will get into the explanation.
The following code block contains the entire code for the BasicBlock class.
```python
class BasicBlock(nn.Module):
    """
    Builds the Basic Block of the ResNet model.
    For ResNet18 and ResNet34, these are stackings of 3x3=>3x3 convolutional
    layers.
    For ResNet50 and above, these are stackings of 1x1=>3x3=>1x1 (Bottleneck)
    layers.
    """
    def __init__(
        self,
        num_layers: int,
        in_channels: int,
        out_channels: int,
        stride: int = 1,
        expansion: int = 1,
        downsample: nn.Module = None
    ) -> None:
        super(BasicBlock, self).__init__()
        self.num_layers = num_layers
        # Multiplicative factor for the subsequent conv2d layer's output
        # channels.
        # It is 1 for ResNet18 and ResNet34, and 4 for the others.
        self.expansion = expansion
        self.downsample = downsample
        # 1x1 convolution for ResNet50 and above.
        if num_layers > 34:
            self.conv0 = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                bias=False
            )
            self.bn0 = nn.BatchNorm2d(out_channels)
            in_channels = out_channels
        # Common 3x3 convolution for all.
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        # 1x1 convolution for ResNet50 and above.
        if num_layers > 34:
            self.conv2 = nn.Conv2d(
                out_channels,
                out_channels*self.expansion,
                kernel_size=1,
                stride=1,
                bias=False
            )
            self.bn2 = nn.BatchNorm2d(out_channels*self.expansion)
        else:
            # 3x3 convolution for ResNet18 and ResNet34.
            self.conv2 = nn.Conv2d(
                out_channels,
                out_channels*self.expansion,
                kernel_size=3,
                padding=1,
                bias=False
            )
            self.bn2 = nn.BatchNorm2d(out_channels*self.expansion)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: Tensor) -> Tensor:
        identity = x
        # Through 1x1 convolution if ResNet50 or above.
        if self.num_layers > 34:
            out = self.conv0(x)
            out = self.bn0(out)
            out = self.relu(out)
        # Use the above output if ResNet50 and above.
        if self.num_layers > 34:
            out = self.conv1(out)
        # Else use the input to the `forward` method.
        else:
            out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        out = self.relu(out)
        return out
```
The __init__() Method
We initialize three variables in the __init__ method of BasicBlock.
- self.num_layers: The number of layers we want to build the ResNet model with.
- self.expansion: The expansion factor for the output channels of the final convolutional layer. It is 1 for ResNets 18 and 34, and 4 for ResNet50, ResNet101, and ResNet152.
- self.downsample: Either None or a PyTorch Sequential module. When it is not None, we use it to downsample the input to the forward method so that the identity (shortcut) branch matches the shape of the block's output.
If you went through the post where we built the ResNet18 network from scratch, you will find some differences here.
First, we have a conv0 layer that gets initialized only if we are building ResNet50 or above (the first if num_layers > 34 check). It is followed by a BatchNorm2d layer. When this happens, the next layer must take the output of this layer as its input, which is why we set in_channels = out_channels right after it. Note that this is a 1×1 convolutional layer; in other words, it is part of the bottleneck block.
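Next, we have the conv1 layer, which is a 3×3 convolution. This is common to all the ResNet architectures. It also receives the stride argument, so within a block this is the layer that reduces the spatial size when the stride is 2.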
After that comes another condition. If the ResNet we are building has more than 34 layers (50 and above), the final layer, conv2, is a 1×1 convolutional layer that is part of the bottleneck block. Otherwise, it is a 3×3 convolutional layer (for ResNets 18 and 34).
Notice that we use the expansion factor for the output channels of the final layer and of its BatchNorm2d layer as well.
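The branching in __init__ is easiest to follow as a list of layer shapes. The sketch below is a pure-Python mirror of that branching (the block_layers helper is hypothetical and builds no PyTorch layers); it returns the input channels, output channels, kernel size, and stride of each convolution. The second call uses the values that the first block of layer2 in ResNet50 receives: 256 input channels from layer1, 128 output channels, stride 2, and an expansion of 4.

```python
from typing import List, Tuple

# (layer name, input channels, output channels, kernel size, stride)
ConvSpec = Tuple[str, int, int, int, int]


def block_layers(
    num_layers: int,
    in_channels: int,
    out_channels: int,
    stride: int = 1,
    expansion: int = 1,
) -> List[ConvSpec]:
    # Hypothetical helper that mirrors the branching in BasicBlock.__init__.
    plan: List[ConvSpec] = []
    if num_layers > 34:
        # 1x1 reduction; its output becomes the input of conv1.
        plan.append(('conv0', in_channels, out_channels, 1, 1))
        in_channels = out_channels
    # 3x3 convolution shared by all ResNets; it carries the stride.
    plan.append(('conv1', in_channels, out_channels, 3, stride))
    # Final layer: 1x1 for the bottleneck, 3x3 otherwise.
    final_kernel = 1 if num_layers > 34 else 3
    plan.append(('conv2', out_channels, out_channels * expansion, final_kernel, 1))
    return plan


for spec in block_layers(18, 64, 64):
    print(spec)
# ('conv1', 64, 64, 3, 1)
# ('conv2', 64, 64, 3, 1)

for spec in block_layers(50, 256, 128, stride=2, expansion=4):
    print(spec)
# ('conv0', 256, 128, 1, 1)
# ('conv1', 128, 128, 3, 2)
# ('conv2', 128, 512, 1, 1)
```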
These are all the layers that we need to build the Residual Neural Networks. Let's get into the forward method now.
The forward() Method
As we can observe, the forward method also contains some if-else blocks. We pass the input tensor through the conv0 and bn0 layers only if the network we are building has more than 34 layers. In this case, the input to the next 3×3 conv1 layer is the output of the previous layer.
If we are building ResNet18 or ResNet34, the first convolution in the block is conv1 (the else branch), and its input is the same as the input to the forward method.
The downsampling and the identity (shortcut) connection are handled at the end of the method. The shortcut addition happens in every block, but the input goes through self.downsample only when that module exists. The _make_layer method of the ResNet class creates it whenever the stride is not 1 or the input channels differ from the output channels multiplied by the expansion factor. The stride condition holds for the first block of layer2, layer3, and layer4 in every ResNet. The channel condition additionally holds for the first block of layer1 in ResNet50, ResNet101, and ResNet152: the expansion factor is 4, so the 64-channel input must be projected to 256 channels before the addition. All the remaining blocks receive inputs that already have the expanded channel count, so they need no downsampling.
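The rule above can be checked without building any layers. The sketch below replays only the channel bookkeeping of _make_layer in plain Python; downsample_flags and CONFIGS are hypothetical names, and the per-stage block counts are the ones used by the ResNet class. For every stage it reports which blocks receive a downsample module.

```python
from typing import Dict, List, Tuple

# Per-stage block counts and expansion factor, as in the ResNet class.
CONFIGS: Dict[int, Tuple[List[int], int]] = {
    18: ([2, 2, 2, 2], 1),
    34: ([3, 4, 6, 3], 1),
    50: ([3, 4, 6, 3], 4),
    101: ([3, 4, 23, 3], 4),
    152: ([3, 8, 36, 3], 4),
}


def downsample_flags(num_layers: int) -> List[List[bool]]:
    # Hypothetical helper replaying the bookkeeping of ResNet._make_layer:
    # returns, per stage, whether each block receives a downsample module.
    layers, expansion = CONFIGS[num_layers]
    in_channels = 64
    flags = []
    for stage, (out_channels, blocks) in enumerate(zip([64, 128, 256, 512], layers)):
        stride = 1 if stage == 0 else 2
        # Only the first block of a stage can get a downsample module.
        first = stride != 1 or in_channels != out_channels * expansion
        flags.append([first] + [False] * (blocks - 1))
        # Later blocks (and the next stage) see the expanded channel count.
        in_channels = out_channels * expansion
    return flags


print(downsample_flags(18))
# [[False, False], [True, False], [True, False], [True, False]]
print([stage[0] for stage in downsample_flags(50)])
# [True, True, True, True]
```

This matches the printed model summaries later in the post: the ResNet18 summary shows no downsample module in the first block of layer1, while the ResNet50 summary does.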
The ResNet Module
Now, we have reached the final part of building ResNets from scratch using PyTorch. We will write the code for the ResNet class that combines everything.
Most of the things will remain the same as were in the case of building ResNet18 from scratch. A few things will be added, which we will get into.
```python
class ResNet(nn.Module):
    def __init__(
        self,
        img_channels: int,
        num_layers: int,
        block: Type[BasicBlock],
        num_classes: int = 1000
    ) -> None:
        super(ResNet, self).__init__()
        if num_layers == 18:
            # Each element of `layers` is the number of `BasicBlock`s
            # stacked in one of the four stages (layer1 to layer4).
            layers = [2, 2, 2, 2]
            self.expansion = 1
        if num_layers == 34:
            layers = [3, 4, 6, 3]
            self.expansion = 1
        if num_layers == 50:
            layers = [3, 4, 6, 3]
            self.expansion = 4
        if num_layers == 101:
            layers = [3, 4, 23, 3]
            self.expansion = 4
        if num_layers == 152:
            layers = [3, 8, 36, 3]
            self.expansion = 4

        self.in_channels = 64
        # All ResNets (18 to 152) contain a Conv2d => BN => ReLU for the first
        # three layers. Here, kernel size is 7.
        self.conv1 = nn.Conv2d(
            in_channels=img_channels,
            out_channels=self.in_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0], num_layers=num_layers)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, num_layers=num_layers)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, num_layers=num_layers)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, num_layers=num_layers)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512*self.expansion, num_classes)

    def _make_layer(
        self,
        block: Type[BasicBlock],
        out_channels: int,
        blocks: int,
        stride: int = 1,
        num_layers: int = 18
    ) -> nn.Sequential:
        downsample = None
        if stride != 1 or self.in_channels != out_channels * self.expansion:
            # This should pass from `layer2` to `layer4` or
            # when building ResNet50 and above. Section 3.3 of the paper
            # Deep Residual Learning for Image Recognition
            # (https://arxiv.org/pdf/1512.03385v1.pdf).
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    out_channels*self.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False
                ),
                nn.BatchNorm2d(out_channels * self.expansion),
            )
        layers = []
        layers.append(
            block(
                num_layers,
                self.in_channels,
                out_channels,
                stride,
                self.expansion,
                downsample
            )
        )
        self.in_channels = out_channels * self.expansion

        for i in range(1, blocks):
            layers.append(block(
                num_layers,
                self.in_channels,
                out_channels,
                expansion=self.expansion
            ))
        return nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # For a 224x224 input, the spatial dimension of the final layer's
        # feature map is (7, 7) for all ResNets.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
```
If you went through the post where we built ResNet18 from scratch, you will spot the additions right away.
- In __init__, we now have conditions that set the layers list and the expansion factor for ResNet50, 101, and 152 as well. Each element of the layers list is the number of BasicBlocks stacked in one of the four stages (layer1 to layer4).
- The other addition is in the _make_layer method. There is now an extra condition that compares the input channels with the output channels multiplied by the expansion factor. The downsampling module is created if either this check or the stride check is true.
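To see why the final fully connected layer takes 512*self.expansion input features and why the comment in forward expects a 7×7 feature map, here is a pure-Python trace of the output shape after each stage for a 224×224 input. It uses the standard Conv2d/MaxPool2d output-size formula (dilation 1, ceil_mode=False); out_size and trace_shapes are hypothetical helpers, not part of resnet.py.

```python
from typing import List, Tuple


def out_size(size: int, kernel: int, stride: int, padding: int) -> int:
    # Output size of Conv2d / MaxPool2d with dilation=1 and ceil_mode=False.
    return (size + 2 * padding - kernel) // stride + 1


def trace_shapes(num_layers: int, size: int = 224) -> List[Tuple[str, int, int]]:
    # Hypothetical helper: (stage, channels, height/width) after each stage.
    expansion = 1 if num_layers <= 34 else 4
    size = out_size(size, kernel=7, stride=2, padding=3)   # conv1
    shapes = [('conv1', 64, size)]
    size = out_size(size, kernel=3, stride=2, padding=1)   # maxpool
    shapes.append(('maxpool', 64, size))
    stages = [('layer1', 64, 1), ('layer2', 128, 2), ('layer3', 256, 2), ('layer4', 512, 2)]
    for name, out_channels, stride in stages:
        # Inside a stage, only the 3x3 conv1 of the first block is strided
        # (padding=1); the 1x1 convs and the later blocks keep the size.
        size = out_size(size, kernel=3, stride=stride, padding=1)
        shapes.append((name, out_channels * expansion, size))
    return shapes


for row in trace_shapes(50):
    print(row)
# ('conv1', 64, 112)
# ('maxpool', 64, 56)
# ('layer1', 256, 56)
# ('layer2', 512, 28)
# ('layer3', 1024, 14)
# ('layer4', 2048, 7)
```

The strided 1×1 downsample convolution (padding 0) gives the same size, for example (56 - 1) // 2 + 1 = 28, which is what allows out += identity to work in the first block of each stage.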
The rest of the code remains exactly the same as was in the case of building ResNet18 from scratch.
The last part of the code is writing the main block which builds the ResNets based on the number of layers that we pass from the command line.
```python
if __name__ == '__main__':
    tensor = torch.rand([1, 3, 224, 224])
    model = ResNet(
        img_channels=3,
        num_layers=args['num_layers'],
        block=BasicBlock,
        num_classes=1000
    )
    print(model)

    # Total parameters and trainable parameters.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")
    output = model(tensor)
```
Here, we just initialize a random tensor that we pass through the model. We get the value of num_layers from the command line argument that we pass when executing the resnet.py script. We also print the number of parameters of each model so that we can compare them with the models from Torchvision.
This is all the code that we need to build ResNets from scratch using PyTorch.
Verify the ResNet Architectures
You may execute the following commands to build each ResNet model. The blocks below show each command and its corresponding output.
Building ResNet18.
python resnet.py --num-layers 18
The output.
```
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
.
.
.
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
11,689,512 total parameters.
11,689,512 training parameters.
```
Building ResNet34.
python resnet.py --num-layers 34
The output.
```
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
.
.
.
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)
21,797,672 total parameters.
21,797,672 training parameters.
```
Building ResNet50.
python resnet.py --num-layers 50
The output.
```
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
.
.
.
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
25,557,032 total parameters.
25,557,032 training parameters.
```
Building ResNet101.
python resnet.py --num-layers 101
The output.
```
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
.
.
.
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
44,549,160 total parameters.
44,549,160 training parameters.
```
Building ResNet152.
python resnet.py --num-layers 152
The output.
```
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (conv0): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
.
.
.
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=2048, out_features=1000, bias=True)
)
60,192,808 total parameters.
60,192,808 training parameters.
```
If you compare the number of parameters of our models with the official Torchvision models, they are exactly the same. What remains is how the networks perform on classification tasks when trained from scratch, which may vary slightly from run to run because of the random initialization of the weights.
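To see where these totals come from, the sketch below recomputes them by hand in plain Python, without PyTorch. It assumes exactly the layer shapes of our BasicBlock and ResNet code: bias-free convolutions, BatchNorm2d layers that contribute a learnable weight and bias per channel, one downsample module in the first block of a stage when needed, and a 1000-class fc layer. The helper names are hypothetical.

```python
from typing import Dict, List, Tuple

CONFIGS: Dict[int, Tuple[List[int], int]] = {
    18: ([2, 2, 2, 2], 1),
    34: ([3, 4, 6, 3], 1),
    50: ([3, 4, 6, 3], 4),
    101: ([3, 4, 23, 3], 4),
    152: ([3, 8, 36, 3], 4),
}


def conv(c_in: int, c_out: int, k: int) -> int:
    # Weights of a bias-free Conv2d.
    return c_in * c_out * k * k


def bn(c: int) -> int:
    # BatchNorm2d learns a weight and a bias per channel; the running
    # mean/variance are buffers and are not returned by model.parameters().
    return 2 * c


def block_params(c_in: int, c: int, expansion: int, downsample: bool) -> int:
    if expansion == 1:  # 3x3 => 3x3
        total = conv(c_in, c, 3) + bn(c) + conv(c, c, 3) + bn(c)
    else:  # 1x1 => 3x3 => 1x1
        total = (conv(c_in, c, 1) + bn(c) + conv(c, c, 3) + bn(c)
                 + conv(c, c * expansion, 1) + bn(c * expansion))
    if downsample:  # 1x1 projection shortcut + BatchNorm2d
        total += conv(c_in, c * expansion, 1) + bn(c * expansion)
    return total


def count_params(num_layers: int, num_classes: int = 1000) -> int:
    layers, expansion = CONFIGS[num_layers]
    total = conv(3, 64, 7) + bn(64)  # 7x7 stem convolution + BatchNorm2d
    in_channels = 64
    for stage, (c, blocks) in enumerate(zip([64, 128, 256, 512], layers)):
        stride = 1 if stage == 0 else 2
        for i in range(blocks):
            needs_ds = i == 0 and (stride != 1 or in_channels != c * expansion)
            total += block_params(in_channels, c, expansion, needs_ds)
            in_channels = c * expansion
    return total + in_channels * num_classes + num_classes  # fc weight + bias


for n in CONFIGS:
    print(f"ResNet{n}: {count_params(n):,}")
# ResNet18: 11,689,512
# ResNet34: 21,797,672
# ResNet50: 25,557,032
# ResNet101: 44,549,160
# ResNet152: 60,192,808
```

Placing the stride on the 3×3 layer (as our code does) or on the first 1×1 layer of the bottleneck does not change these counts, because the stride affects only the output size, not the number of weights.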
If you carry out any classification training using the networks, please let us know about your results in the comment section.
Summary and Conclusion
In this tutorial, you learned how to build different ResNets from scratch using PyTorch. We covered ResNet18, ResNet34, ResNet50, ResNet101, and ResNet152. Hopefully, this was a good learning experience for you.
If you have any doubts, thoughts, or suggestions, please leave them in the comment section. They will surely be addressed.
You can contact me using the Contact section. You can also find me on LinkedIn and Twitter.
References
- He, K., Zhang, X., Ren, S., & Sun, J. (2015). Deep Residual Learning for Image Recognition. arXiv. https://doi.org/10.48550/arXiv.1512.03385
- PyTorch Vision. torchvision/models/resnet.py. https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py