Image Classification with Web-DINO



DINOv2 models have powered several successful downstream tasks, including image classification, semantic segmentation, and depth estimation. Recently, DINOv2 models were trained on web-scale data using the Web-SSL framework, and the resulting models are termed Web-DINO. We covered the motivation, architecture, and benchmarks of Web-DINO in our last article. In this article, we are going to use one of the Web-DINO models for image classification.

Figure 1. Web-DINO image classification workflow.

The authors of Web-DINO have released several pretrained models that are available in both pure PyTorch and Hugging Face formats. We are going to use one of the pretrained models for the image classification downstream task.

We are covering the following topics for image classification using Web-DINO:

  • The dataset that we will use for image classification.
  • Modifying pretrained Web-DINO model for image classification.
  • Training the Web-DINO model for image classification.
  • Inference using the trained model.
  • Further experiments.

Which Web-DINO Model Are We Going to Train for Image Classification?

The Web-DINO paper mentions models ranging from 1B to 7B parameters, all based on the ViT-g architecture. However, their Hugging Face and GitHub repositories contain an additional 300M parameter model. We are going with this model as it is less computationally expensive and faster to train.

Figure 2. Models available in the Web-DINO family. (source: https://github.com/facebookresearch/webssl/ )

The Web-DINO 300M is based on the ViT-L architecture and is trained on the same MC-2B dataset as other Web-DINO models with the Web-SSL framework.

The Cotton Disease Classification Dataset

We will use the cotton disease classification dataset to train the Web-DINO 300M model. This is the same dataset on which we trained the DINOv2 model for image classification in one of our earlier articles.

You can find the cotton disease classification dataset here on Kaggle. We get the following directory structure after downloading and extracting the dataset.

├── Cotton-Disease-Training
│   └── trainning
│       └── Cotton leaves - Training
├── Cotton-Disease-Validation
│   └── validation
│       └── Cotton plant disease-Validation
└── Customized Cotton Dataset-Complete
    └── content
        ├── trainning
        └── validation

We will ignore the Customized Cotton Dataset-Complete directory. Instead, we will use the Cotton-Disease-Training for training and Cotton-Disease-Validation for validation.

The dataset contains 6628 images for training and 357 images for validation. There are a total of 8 classes.

['Aphids', 'Army worm', 'Bacterial blight', 'Cotton Boll Rot', 
'Green Cotton Boll', 'Healthy', 'Powdery mildew', 'Target spot']
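
If the dataset is read with torchvision's ImageFolder (a common choice for this directory layout; the article's datasets.py may differ), labels are assigned by sorting the class directory names alphabetically. The following minimal sketch shows the index each class receives, with no dataset download needed:

```python
# Label indices under torchvision's ImageFolder convention: class directory
# names are sorted alphabetically and numbered from 0.
classes = [
    'Aphids', 'Army worm', 'Bacterial blight', 'Cotton Boll Rot',
    'Green Cotton Boll', 'Healthy', 'Powdery mildew', 'Target spot'
]
class_to_idx = {name: idx for idx, name in enumerate(sorted(classes))}
print(class_to_idx['Aphids'], class_to_idx['Target spot'])  # 0 7
```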

Let’s take a look at a few samples from the dataset.

Figure 3. Ground truth samples from the cotton disease classification dataset.

As we can see, the dataset contains varied images captured in different real-world scenarios.

Directory Structure

Let’s take a look at the project’s directory structure.

├── input
│   ├── Cotton-Disease-Training
│   │   └── trainning
│   ├── Cotton-Disease-Validation
│   │   └── validation
│   ├── Customized Cotton Dataset-Complete
│   │   └── content
│   └── inference_data
│       ├── aphids.jpg
│       ├── army_worm.jpg
│       ├── bacterial_blight.jpg
│       ├── cotton_ball_rot.jpg
│       ├── green_cotton_ball.jpg
│       ├── healthy_leaf.jpg
│       ├── powdery_mildew.jpg
│       └── target_spot.jpg
├── outputs
│   └── transfer_learning
│       ├── accuracy.png
│       ├── best_model.pth
│       ├── loss.png
│       └── model.pth
└── src
    ├── datasets.py
    ├── dinov2
    │   ├── layers
    │   └── vision_transformer.py
    ├── inference.py
    ├── model.py
    ├── train.py
    ├── utils.py
    └── weights
        └── webssl_dino300m_full2b_224.pth

  • The input directory contains the downloaded cotton disease classification dataset that we saw in the previous section. Along with that, the inference_data directory contains a few images from the validation set that we will later use for inference.
  • The outputs directory contains the trained weights and graphs.
  • We have the source code in the src directory. Along with the custom Python files, we also have a dinov2 directory. This directory is borrowed from the official Web-SSL GitHub repository: we clone the repository and then copy the dinov2 directory into our project directory. It is essential for importing the Web-DINO models. Although the directory will be available along with the downloadable codebase in this article, you are free to clone the official repository and modify it as per your requirements.
  • Finally, the weights subdirectory in the src directory contains the pretrained Web-DINO 300M backbone weights. We will later see how to download this.

All the code files, trained weights, and inference data will be available for download via the download section. If you wish to train the model yourself, you will need to download the training dataset and arrange it in the above directory structure.

Download Code

Installing Requirements

There are a few major requirements for this article.

  • The codebase is based on PyTorch 2.5.1, although older and newer versions are likely to work as well. You can install the framework from here.
  • Other requirements.
pip install tqdm matplotlib opencv-python

Training Web-DINO for Image Classification

Let’s get into the important parts of the codebase for training Web-DINO for image classification. However, before that, we need to download the pretrained Web-DINO 300M weights. You can find links to all the weights here. The table contains both Hugging Face links and the PyTorch weights. We will use the PyTorch weights. To directly download them, you can execute the following command in your terminal inside the src/weights directory.

wget https://dl.fbaipublicfiles.com/webssl/webssl_dino300m_full2b_224.pth

Modifying Web-DINO 300M to Create an Image Classification Model

The most important part, of course, is adapting the pretrained Web-DINO for image classification. Let’s do that.

The code for creating the Web-DINO image classification model is present in the model.py file in the src directory.

import torch

from collections import OrderedDict
from dinov2.vision_transformer import webssl_dino300m_full2b_224

def load_model():
    # Load model
    model = webssl_dino300m_full2b_224()

    # Load weights
    checkpoint_path = 'weights/webssl_dino300m_full2b_224.pth'
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    msg = model.load_state_dict(state_dict, strict=False)
    print(f"Checkpoint loading message: {msg}")

    return model

def build_model(num_classes=10, fine_tune=False):
    backbone_model = load_model()

    model = torch.nn.Sequential(OrderedDict([
        ('backbone', backbone_model),
        ('head', torch.nn.Linear(
            in_features=1024, out_features=num_classes, bias=True
        ))
    ]))
    
    if not fine_tune:
        for params in model.backbone.parameters():
            params.requires_grad = False

    return model

We have two functions here.

load_model():

  • The load_model function first initializes the backbone model using the webssl_dino300m_full2b_224 function that we import from the dinov2 module.
  • Then we load the downloaded checkpoint weights into the model.

build_model():

  • The build_model function accepts two parameters: num_classes, which controls the number of output classes in the final head, and fine_tune, a boolean indicating whether we want to train the backbone or not.
  • In this function, first, we initialize the Web-DINO backbone model. Then we create a Sequential model with the backbone and the classification head.
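
To see what the freezing logic in build_model achieves without loading the 300M-parameter checkpoint, here is a toy sketch that reproduces the same pattern with a stand-in backbone (the real backbone is the Web-DINO ViT-L with 1024-dimensional features):

```python
import torch
from collections import OrderedDict

# Stand-in backbone to illustrate the freeze logic from build_model().
# The real backbone is the 300M-parameter Web-DINO ViT-L.
backbone = torch.nn.Linear(1024, 1024)
model = torch.nn.Sequential(OrderedDict([
    ('backbone', backbone),
    ('head', torch.nn.Linear(in_features=1024, out_features=8, bias=True)),
]))

# fine_tune=False path: freeze every backbone parameter.
for param in model.backbone.parameters():
    param.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # 1024 * 8 + 8 = 8200, matching the head size in the training logs
```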

The file also contains a main block to check the forward pass and output shapes.

if __name__ == '__main__':
    from PIL import Image
    from torchvision import transforms
    from torchinfo import summary

    import numpy as np

    # We can give any multiple of 14.
    sample_size = 224

    # Define image transformation
    transform = transforms.Compose([
        transforms.Resize(
            sample_size, 
            interpolation=transforms.InterpolationMode.BICUBIC
        ),
        transforms.CenterCrop(sample_size),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406), 
            std=(0.229, 0.224, 0.225)
        )
    ])

    # Loading the pretrained model without classification head.
    model = load_model()

    # Total parameters and trainable parameters.
    total_params = sum(p.numel() for p in model.parameters())
    print(f"{total_params:,} total parameters.")
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad)
    print(f"{total_trainable_params:,} training parameters.")

    # Testing forward pass.
    pil_image = Image.fromarray(np.ones((sample_size, sample_size, 3), dtype=np.uint8))
    model_input = transform(pil_image).unsqueeze(0)

    summary(
        model,
        input_data=model_input,
        col_names=('input_size', 'output_size', 'num_params'),
        row_settings=['var_names']
    )

    # Manual torch forward pass.
    with torch.no_grad():
        features = model.forward_features(model_input)
        patch_features = features['x_norm_patchtokens']

    print(features.keys())
    print(f"Patch features shape: {patch_features.shape}")

    # Check the forward passes through the complete model.
    # To check what gets fed to the classification layer.
    model_cls = build_model()
    features = model_cls.backbone(model_input)
    print(f"Shape of features getting fed to classification layer: {features.shape}")

We can execute the model file using the following command.

python model.py

We get the following output.

=============================================================================================================================
Layer (type (var_name))                            Input Shape               Output Shape              Param #
=============================================================================================================================
DinoVisionTransformer (DinoVisionTransformer)      [1, 3, 224, 224]          [1, 1024]                 265,216
├─PatchEmbed (patch_embed)                         [1, 3, 224, 224]          [1, 256, 1024]            --
│    └─Conv2d (proj)                               [1, 3, 224, 224]          [1, 1024, 16, 16]         603,136
│    └─Identity (norm)                             [1, 256, 1024]            [1, 256, 1024]            --
├─ModuleList (blocks)                              --                        --                        --
│    └─BlockChunk (0)                              [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─NestedTensorBlock (0)                  [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (1)                  [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (2)                  [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (3)                  [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (4)                  [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (5)                  [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    └─BlockChunk (1)                              [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (5)                           [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (5)                           [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (5)                           [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (5)                           [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (5)                           [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (5)                           [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─NestedTensorBlock (6)                  [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (7)                  [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (8)                  [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (9)                  [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (10)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (11)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    └─BlockChunk (2)                              [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (11)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─NestedTensorBlock (12)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (13)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (14)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (15)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (16)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (17)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    └─BlockChunk (3)                              [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─Identity (17)                          [1, 257, 1024]            [1, 257, 1024]            --
│    │    └─NestedTensorBlock (18)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (19)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (20)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (21)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (22)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
│    │    └─NestedTensorBlock (23)                 [1, 257, 1024]            [1, 257, 1024]            12,616,032
├─LayerNorm (norm)                                 [1, 257, 1024]            [1, 257, 1024]            2,048
├─Identity (head)                                  [1, 1024]                 [1, 1024]                 --
=============================================================================================================================
Total params: 303,655,168
Trainable params: 303,655,168
Non-trainable params: 0
Total mult-adds (Units.MEGABYTES): 457.14
=============================================================================================================================
Input size (MB): 0.60
Forward/backward pass size (MB): 728.97
Params size (MB): 1213.56
Estimated Total Size (MB): 1943.13
=============================================================================================================================
dict_keys(['x_norm_clstoken', 'x_norm_regtokens', 'x_norm_patchtokens', 'x_prenorm', 'masks'])
Patch features shape: torch.Size([1, 256, 1024])
Checkpoint loading message: <All keys matched successfully>
Shape of features getting fed to classification layer: torch.Size([1, 1024])

The patch token features, excluding the classification token, are of shape [1, 256, 1024]. We can access this by calling the forward_features method of the backbone. This is not useful for classification. However, these features can be used for semantic segmentation.

We are interested in the pooled features of the last layer that have the shape [1, 1024]. These features will be fed to the classification head.
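
The shapes in the summary follow directly from the ViT-L patch arithmetic. A quick back-of-the-envelope check (pure Python, no model needed):

```python
# Token-count arithmetic for a 224x224 input with 14x14 patches (ViT-L, 1024-dim).
image_size, patch_size, embed_dim = 224, 14, 1024
patches_per_side = image_size // patch_size   # 224 / 14 = 16
num_patch_tokens = patches_per_side ** 2      # 16 * 16 = 256 patch tokens
seq_len = num_patch_tokens + 1                # + 1 CLS token = 257

print(num_patch_tokens, seq_len)  # 256 257
```

This accounts for the [1, 257, 1024] shapes inside the transformer blocks, the [1, 256, 1024] patch token features, and the [1, 1024] pooled output that reaches the classification head.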

Rest Of The Code

The rest of the code involves the dataset preparation, utilities to save the model, loss & accuracy graphs, and the training script.

The training script, train.py, supports several command line arguments. We will use only the ones we need.

Dataset Preparation:

We have all the dataset preparation code in the datasets.py file. We carry out the following transformations/augmentations on the images:

  • We resize all images to 256×256 resolution and center-crop them to 224×224 resolution. This is for both training and validation images.
  • Apart from the above, the training images go through horizontal flipping augmentation.
  • Furthermore, all the images are normalized according to ImageNet mean and standard deviations.

You may go through the above files if you need a better understanding of the codebase.

Executing the Training Script

All training and inference experiments were carried out on a machine with a 10GB RTX 3080 GPU, 32GB of RAM, and a 10th generation i7 processor.

We can execute the train.py script within the src directory using the following command.

python train.py -lr 0.0005 --epochs 20 --batch 32 --out-dir transfer_learning --scheduler 10

  • We are training the model for 20 epochs.
  • The initial learning rate is 0.0005, and we are reducing it by a factor of 10 after 10 epochs (--scheduler).
  • The batch size is 32.
  • All the results will be saved in the outputs/transfer_learning directory.
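
The scheduler behavior can be sketched with PyTorch's MultiStepLR; whether train.py uses exactly this class is an assumption, but the effect described above is the same:

```python
import torch

# Sketch of the schedule: LR starts at 5e-4 and drops by a factor of 10
# after epoch 10 (mirroring --scheduler 10; MultiStepLR is an assumption).
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.0005)
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[10], gamma=0.1
)

lrs = []
for epoch in range(20):
    lrs.append(optimizer.param_groups[0]['lr'])
    # ... one epoch of training and validation would run here ...
    scheduler.step()

print(lrs[9], lrs[10])  # 0.0005 5e-05
```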

Following is the truncated output from the terminal.

Namespace(epochs=20, learning_rate=0.0005, batch_size=32, save_name='model', fine_tune=False, out_dir='transfer_learning', scheduler=[10])
[INFO]: Number of training images: 6628
[INFO]: Number of validation images: 357
[INFO]: Classes: ['Aphids', 'Army worm', 'Bacterial blight', 'Cotton Boll Rot', 'Green Cotton Boll', 'Healthy', 'Powdery mildew', 'Target spot']
Computation device: cuda
Learning rate: 0.0005
Epochs to train for: 20
.
.
.
Checkpoint loading message: <All keys matched successfully>
Sequential(
  (backbone): DinoVisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
      (norm): Identity()
    )
.
.
.
    )
    (norm): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
    (head): Identity()
  )
  (head): Linear(in_features=1024, out_features=8, bias=True)
)
303,663,368 total parameters.
8,200 training parameters.
.
.
.
[INFO]: Epoch 1 of 20
Training
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 208/208 [01:21<00:00,  2.55it/s]
Validation
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:04<00:00,  2.42it/s]
Training loss: 0.517, training acc: 85.667
Validation loss: 0.220, validation acc: 94.398

Best validation loss: 0.21986528082440296

Saving best model for epoch: 1

--------------------------------------------------
LR for next epoch: [0.0005]
[INFO]: Epoch 2 of 20
Training
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 208/208 [01:20<00:00,  2.58it/s]
Validation
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:04<00:00,  2.54it/s]
Training loss: 0.180, training acc: 97.013
Validation loss: 0.162, validation acc: 95.238

Best validation loss: 0.1622757208533585

Saving best model for epoch: 2

--------------------------------------------------
.
.
.

As we can see, we are training just 8,200 parameters of the classification head. The entire backbone is frozen.

The model reached its highest validation accuracy of 96.9% in epoch 9, along with the lowest validation loss of 0.115 in the same epoch. We save the best model according to the validation loss, and this is the checkpoint we will use for inference.
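
The "save the best model" logic is simple to mirror: track the lowest validation loss seen so far and checkpoint only when it improves. A minimal sketch with dummy loss values (the real logic lives in utils.py):

```python
# Best-checkpoint tracking by validation loss. The loss values below are
# dummy placeholders, not the actual training losses.
best_valid_loss = float('inf')
saved_epochs = []

dummy_losses = [0.50, 0.30, 0.35, 0.20, 0.25]
for epoch, valid_loss in enumerate(dummy_losses, start=1):
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        saved_epochs.append(epoch)   # torch.save(model.state_dict(), ...) would go here

print(saved_epochs)  # [1, 2, 4]
```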

Figure 4. Accuracy graph after training the Web-DINO model on the image classification task.
Figure 5. Loss graph after training the Web-DINO model on the image classification task.

The validation accuracy seems to plateau and then deteriorate after around 7 epochs, which indicates overfitting. We can see a similar trend in the loss graph as well. Adding more augmentations may help the model keep improving for longer.

Inference Using The Trained Model

Let’s use the best trained weights to run inference on some of the validation images.

The inference code is present in the inference.py file. We can pass the path to the best trained weights using the --weights command line argument. We apply the same image transformations as during training.

Let’s execute the code.

python inference.py --weights ../outputs/transfer_learning/best_model.pth

The results will be saved in the outputs/inference_results directory. The file names indicate the ground truth classes, and the predictions are annotated on the images. Let’s check the results.

Figure 6. Web DINO image classification inference results.

The model predicted one of the images wrongly: instead of Cotton Boll Rot, it predicted Green Cotton Boll.

Further Experiments

  • We can expand this experiment considerably. A natural next step is applying such models to real-life use cases like plant pathology and visualizing attention maps before and after training.
  • A comparison of fine-tuning versus transfer learning for such Vision Transformer models is also worthwhile.

Summary and Conclusion

In this article, we modified the Web-DINO 300M model for a simple image classification experiment. We covered the changes necessary to the model, along with the training and inference results. We will try to cover more in-depth experiments in future articles.

If you have any questions, thoughts, or suggestions, please leave them in the comment section. I will surely address them.

You can contact me using the Contact section. You can also find me on LinkedIn, and X.
