Surgical Tool Recognition using PyTorch and Deep Learning

Deep learning and computer vision have many applications in the field of medical imaging. In fact, many deep learning based approaches are already in use, ranging from medical image recognition to disease detection and segmentation of diseased areas. But that’s not all. We can also apply deep learning to surgical images, and recognizing surgical tools is a big part of such an approach. In this article, we will carry out surgical tool recognition using PyTorch and deep learning.

Class activation map of surgical tool recognition.
Figure 1. Example output – Class activation map of surgical tool recognition.

The above image will give you some perspective into what we are going to do in this article. In simple words, we are going to recognize different types of surgical tools that are used during surgery. We will discuss all the details later when exploring the dataset.

For now, let’s keep in mind that this article is entirely for educational purposes and exploring the world of medical imaging from a deep learning perspective. So, if you find anything wrong or inaccurate, please feel free to point that out in the comment section.

Here are the topics that we will cover in this article:

  • We will start with a discussion of the dataset. This includes the source of the dataset, the types of images in the dataset, the number of samples, and classes.
  • Then we will discuss the training strategy and the deep learning model to be used for training. As this project contains quite a lot of code, we will only cover a limited amount of it here. Instead, we will focus on the training strategy and the results.
  • Next, we will train the model and test it on the available test dataset.
  • Finally, we will check out the class activation maps of the predictions.

The Surgical Tool Recognition Dataset

We will use the Cholec-Tinytools dataset from Kaggle for surgical tool recognition in this article. There is a lot of background around this dataset, but we will only cover the basics here.

The dataset contains tooltip images of four different types of surgical equipment. These images have been sourced from 80 laparoscopic cholecystectomy videos. It is a part of a very large dataset called Cholec80 which you can find on the CAMMA website.

But we are not going to use this huge dataset. Instead, we will use the cholec-tinytools dataset from Kaggle. It is a small and simple dataset, well suited for prototyping deep learning and computer vision approaches on surgical tool images.

One big difference apart from the dataset size is that the original datasets on the CAMMA website come with several types of annotations. We can use such annotations and apply deep learning models and approaches for:

  • Surgical tools classification
  • Surgical phase recognition
  • Surgeon skill classification

But we are going to use a toned-down dataset for image classification only.

Please go ahead and download the dataset from Kaggle. After extracting it, you should have a structure similar to the following:

cholec-tinytools/
├── LICENSE
├── README.txt
├── test
│   ├── clipper
│   ├── grasper
│   ├── hook
│   └── scissor
├── train
│   ├── clipper
│   ├── grasper
│   ├── hook
│   └── scissor
└── validation
    ├── clipper
    ├── grasper
    ├── hook
    └── scissor

The images are inside the cholec-tinytools directory. We have three splits for the surgical tool recognition dataset: train, test, and validation.

Exploring the Cholec-Tinytools Surgical Recognition Dataset

As we can see from the above directory structure, there are 4 classes in the dataset:

  • Clipper
  • Grasper
  • Hook
  • Scissor

There are 1200 images in the training set, 200 images in the validation set, and 599 images in the test set. Obviously, this is not a huge dataset but there are enough samples to start experimenting.
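
If you want to verify these numbers on your own system, a small sketch like the following (assuming the extracted cholec-tinytools directory is in the current working directory) prints the image count for each split and class:

from pathlib import Path

# Hypothetical helper to verify the dataset structure after extraction.
# Change `root` to wherever you extracted the cholec-tinytools directory.
root = Path('cholec-tinytools')

for split in ['train', 'validation', 'test']:
    for class_dir in sorted((root / split).iterdir()):
        if class_dir.is_dir():
            num_images = len(list(class_dir.glob('*')))
            print(f"{split}/{class_dir.name}: {num_images} images")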

Here are a few examples from the dataset along with their annotations.

Surgical tool recognition ground truth images.
Figure 2. Ground truth images from the surgical tool recognition dataset.

The images are snapshots from surgical procedures and contain the tooltips of the surgical equipment. The deep learning model that we are going to build should be able to recognize the surgical tool from such images.

In a real-world scenario, only recognizing the surgical tool may not be enough. We also need surgical tool detection and segmentation to create an application that’s worthwhile. But recognition is the first step, and it will also give us more insight into where a deep learning model may fail.

Also, as we will visualize the class activation maps on the test set later, we will know where the model is focusing when recognizing a certain surgical tool.

The Project Directory Structure

Before moving into the training experiments, let’s take a look at the project directory structure.

├── input
│   └── cholec-tinytools
├── outputs
│   ├── cam_results
│   ├── test_results
│   ├── accuracy.png
│   ├── best_model.pth
│   ├── loss.png
│   └── model.pth
└── src
    ├── cam.py
    ├── datasets.py
    ├── model.py
    ├── test.py
    ├── train.py
    └── utils.py
  • The input directory contains the cholec-tinytools dataset which contains the data splits as we saw above.
  • We have all the training, testing, and class activation map outputs in the outputs directory.
  • The src directory contains 5 Python files. Although we will not discuss the code in detail, we will explore the necessary parts of any script when needed.

This is all the information that we need regarding the dataset. In the next section, we can move on to the training experiments.

The zip file that comes with the dataset provides the best trained model and the training scripts. If you want to train your own deep learning model, you just need to download the dataset and structure it accordingly.

The PyTorch Version

The code provided in this article has been developed with torch 1.12.0 and torchvision 0.13.0. As we are using the latest API to load the pretrained weights, you will need version 1.12.0 or above.
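
You can quickly check which versions you have installed with a couple of lines:

import torch
import torchvision

# The code in this article expects torch >= 1.12.0 and torchvision >= 0.13.0.
print(torch.__version__)
print(torchvision.__version__)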

You can install/upgrade PyTorch from the official website.

Surgical Tool Recognition using Deep Learning and MobileNet

Let’s start with the training experiments and explore the technical necessities from here onward.


The Deep Learning Model

We will use the MobileNetV3 Small model for surgical tool recognition. The MobileNet models in general are very efficient in terms of speed, even the SSD models with MobileNet backbones.

Moreover, from experiments, I found that we don’t even need the large MobileNetV3 model. The MobileNetV3 Small gives very good results and the training iterations are much faster on the surgical tool recognition dataset.

For that reason, we will use this model. As PyTorch already provides the pretrained MobileNetV3 Small model, we just need to load it and change the classification head. Here is how we can do it.

from torchvision import models
import torch.nn as nn

def build_model(fine_tune=True, num_classes=10):
    # Load the MobileNetV3 Small model with pretrained ImageNet weights.
    model = models.mobilenet_v3_small(weights='DEFAULT')
    if fine_tune:
        print('[INFO]: Fine-tuning all layers...')
        for params in model.parameters():
            params.requires_grad = True
    if not fine_tune:
        print('[INFO]: Freezing hidden layers...')
        for params in model.parameters():
            params.requires_grad = False
    # Replace the final Linear layer according to the number of classes.
    model.classifier[3] = nn.Linear(
        in_features=1024, out_features=num_classes, bias=True
    )
    return model

The build_model function takes in parameters for fine tuning the hidden layers and the number of classes.

Keep in mind that we will be fine-tuning all the layers to get the best results.

Also, in terms of modification, we only change the final Linear layer according to the number of classes in the dataset.
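
For this dataset, building the model boils down to a single call with the number of classes set to 4:

# Build the MobileNetV3 Small model for the 4 surgical tool classes.
model = build_model(fine_tune=True, num_classes=4)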

Dataset Preparation

The datasets.py contains all the code for preparing the datasets and data loaders.

We resize all the images to 224×224 resolution. By default, all the images in the dataset are 128×128 in size. Also, we use 4 parallel workers for the dataset preprocessing.

We apply different augmentations to the training dataset to avoid overfitting. Here are all the augmentations that we apply.

  • RandomHorizontalFlip
  • RandomRotation
  • RandomAdjustSharpness

For a model like MobileNetV3 Small, these are enough augmentations to train for 40 epochs without overfitting.
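
To give a concrete idea, here is a minimal sketch of what the training transforms and data loader may look like. The rotation angle and the ImageNet normalization statistics are assumptions for illustration; the actual datasets.py may use slightly different values.

import torch
from torchvision import datasets, transforms

# Training transforms: resize to 224x224, apply the augmentations,
# convert to tensor, and normalize (ImageNet statistics assumed here).
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=35),
    transforms.RandomAdjustSharpness(sharpness_factor=2, p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    ),
])

# ImageFolder picks up the class names from the directory structure.
train_dataset = datasets.ImageFolder(
    '../input/cholec-tinytools/train', transform=train_transform
)
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=True, num_workers=4
)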

The following figure shows what the images look like after applying the augmentations.

Surgical tool recognition dataset images after applying the augmentations.
Figure 3. Surgical tool recognition dataset images after applying the augmentations. The neural network will get such augmented images as input.

It is clear that applying the augmentations adds enough variety to the images, which will help the model learn more robust features.

Helper Functions and the Training Script

The utils.py contains a few helper functions and classes. These include the class to save the best model. Whenever the current epoch’s validation loss is the lowest so far, the weights after that epoch are saved.

It also contains the code for saving the accuracy and loss graphs.
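
The following is a minimal sketch of how such a best-model-saving helper can be written. The class name, the checkpoint path, and what exactly gets stored in the checkpoint are assumptions; the utils.py in the downloadable code may differ in these details.

import torch

class SaveBestModel:
    # Keep track of the lowest validation loss seen so far and save the
    # model weights whenever the current epoch improves on it.
    def __init__(self, best_valid_loss=float('inf')):
        self.best_valid_loss = best_valid_loss

    def __call__(self, current_valid_loss, epoch, model,
                 out_path='../outputs/best_model.pth'):
        if current_valid_loss < self.best_valid_loss:
            self.best_valid_loss = current_valid_loss
            print(f"Best validation loss: {self.best_valid_loss}")
            print(f"Saving best model for epoch: {epoch}")
            torch.save(model.state_dict(), out_path)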

Coming to the training script, train.py is the driver script that initiates the training process. It has a few command line argument flags that let us control some of the training parameters. Let’s discuss these parameters first.

  • --epochs: This flag lets us set the number of epochs that we want to train for.
  • --learning-rate: To control the learning rate of the optimizer. Its default value is 0.001.
  • --batch-size: The batch size for the data loader, which is 32 by default.
  • --fine-tune: It is a boolean flag. If we pass this argument then all the layers of the MobileNetV3 Small model will be trained.
  • --save-name: This accepts a string name which is used to save the model weights. If we pass the value as new, the model weight file name will be new.pth.

Other than this, the train.py file also contains the training and validation functions.
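
The flags above roughly map to an argparse setup like the following sketch. The default epochs and save name here are assumptions; the learning rate and batch size defaults follow the article.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=10,
                    help='number of epochs to train for')
parser.add_argument('--learning-rate', type=float, default=0.001,
                    help='learning rate for the optimizer')
parser.add_argument('--batch-size', type=int, default=32,
                    help='batch size for the data loaders')
parser.add_argument('--fine-tune', action='store_true',
                    help='train all layers of the model')
parser.add_argument('--save-name', type=str, default='model',
                    help='file name for saving the trained model weights')
args = parser.parse_args()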

Training the MobileNetV3 Small Model on the Surgical Tool Recognition Dataset

We have everything set up. We can start the training now.

All the Python files are inside the src directory. Also, we will execute all the commands in the terminal within the src directory. To start the training, you can execute the following command in the terminal.

python train.py --epochs 40 --fine-tune

We are training the MobileNetV3 Small model on the Surgical Tool Recognition dataset for 40 epochs. This may seem like a lot, but for the MobileNetV3 Small model it is not, as we will see later. Also, we are passing the --fine-tune flag which tells the training script to train all the layers of the model.

Even if you train on a mid-range GPU, it should not take long. Here are the truncated outputs from the terminal.

python train.py --epochs 40 --fine-tune
[INFO]: Number of training images: 1200
[INFO]: Number of validation images: 200
[INFO]: Classes: ['clipper', 'grasper', 'hook', 'scissor']
Computation device: cuda
Learning rate: 0.001
Epochs to train for: 40

[INFO]: Fine-tuning all layers...
MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
.
.
.
  (avgpool): AdaptiveAvgPool2d(output_size=1)
  (classifier): Sequential(
    (0): Linear(in_features=576, out_features=1024, bias=True)
    (1): Hardswish()
    (2): Dropout(p=0.2, inplace=True)
    (3): Linear(in_features=1024, out_features=4, bias=True)
  )
)
1,521,956 total parameters.
1,521,956 training parameters.
Adjusting learning rate of group 0 to 1.0000e-03.
[INFO]: Epoch 1 of 40
Training
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:04<00:00,  9.25it/s]
Validation
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 22.78it/s]
Training loss: 1.215, training acc: 54.000
Validation loss: 1.079, validation acc: 70.500

Best validation loss: 1.078602569443839

Saving best model for epoch: 1
.
.
.
[INFO]: Epoch 40 of 40
Training
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 38/38 [00:02<00:00, 13.43it/s]
Validation
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 27.12it/s]
Training loss: 0.026, training acc: 99.417
Validation loss: 0.169, validation acc: 94.500

Best validation loss: 0.16853357345930167

Saving best model for epoch: 40

--------------------------------------------------
TRAINING COMPLETE

As we can see from the above output block, with 4 classes, the MobileNetV3 Small Model contains only 1.5 million parameters.

Yet, we are getting the best results even in the last epoch. This tells us that the model has not overfit yet. It may still be able to learn more without overfitting.

Further, it is worthwhile to note that on the last epoch, the validation accuracy is 94.5% which is quite impressive for such a small model.

Right now, we can take a look at the accuracy and loss graphs which will tell us even more about the training procedure.

Accuracy after training the MobileNetV3 Small model on the dataset.
Figure 4. Accuracy after training the MobileNetV3 Small model on the dataset.
Loss after training the MobileNetV3 Small model on the dataset.
Figure 5. Loss after training the MobileNetV3 Small model on the dataset.

The accuracy plots look pretty good. The training accuracy seems to still be improving. This tells us that if the learning rate scheduler kicks in and reduces the learning rate, the validation accuracy and loss may improve further.

Testing the Trained MobileNetV3 Small Model on the Test Set

We have the test.py script which uses the best trained model to evaluate the accuracy on the test set.

As the dataset already has a test split, we can easily check how the model performs on unseen surgical tool images.
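
Internally, the evaluation boils down to a loop like the following sketch. It assumes that model, test_loader, and device are set up as in the earlier sections and that the checkpoint stores the model’s state_dict directly; the actual test.py may differ in these details.

import torch

# Load the best weights and evaluate on the test set.
model.load_state_dict(
    torch.load('../outputs/best_model.pth', map_location=device)
)
model.eval()

correct, total = 0, 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
print(f"Test accuracy: {100.0 * correct / total:.3f}%")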

Following is the command to run the test.

python test.py

Here are the outputs.

[INFO]: Freezing hidden layers...
Testing model
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 599/599 [00:06<00:00, 96.43it/s]
Test accuracy: 95.326%

The model achieves 95.32% accuracy on the Surgical Tool Recognition test set. This is slightly higher than the best validation accuracy.

Following are some of the images where the model made correct predictions.

Correct predictions made by the MobileNetV3 Small model on the test set of surgical tool recognition dataset.
Figure 6. Correct predictions made by the MobileNetV3 Small model on the test set of surgical tool recognition dataset.

And here are some wrong predictions.

Wrong predictions made by the MobileNetV3 Small model on the test set.
Figure 7. Wrong predictions made by the MobileNetV3 Small model on the test set of surgical tool recognition dataset.

It is a bit difficult to explain why the model made certain wrong predictions. Most probably, taking a look at the class activation maps will clear up our confusion.

Class Activation Maps (CAM) for the Surgical Tool Recognition Test Set

The cam.py script contains the code to visualize the class activation maps on the test set. It is almost the same as the test.py script but adds the class activation maps on the predictions.
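
To give an idea of how such maps can be generated, here is a minimal Grad-CAM style sketch. It assumes model is the trained MobileNetV3 Small from earlier and image is a single preprocessed tensor of shape [1, 3, 224, 224]; the actual cam.py may implement the visualization differently.

import torch
import torch.nn.functional as F

def grad_cam(model, image, target_layer):
    # Collect the activations and gradients of the target layer via hooks.
    activations, gradients = [], []
    fwd = target_layer.register_forward_hook(
        lambda module, inp, out: activations.append(out)
    )
    bwd = target_layer.register_full_backward_hook(
        lambda module, grad_in, grad_out: gradients.append(grad_out[0])
    )

    model.eval()
    outputs = model(image)
    class_idx = outputs.argmax(dim=1).item()
    model.zero_grad()
    outputs[0, class_idx].backward()
    fwd.remove()
    bwd.remove()

    # Weight each activation map by the mean of its gradients and combine.
    weights = gradients[0].mean(dim=(2, 3), keepdim=True)
    cam = F.relu((weights * activations[0]).sum(dim=1, keepdim=True))
    cam = F.interpolate(
        cam, size=image.shape[2:], mode='bilinear', align_corners=False
    )
    cam = cam.squeeze().detach().cpu().numpy()
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    return cam, class_idx

# Example usage with the last convolutional block as the target layer.
# cam_map, predicted_class = grad_cam(model, image, model.features[-1])

The returned map can then be overlaid on the input image as a heat map, which is the kind of visualization you see in the figures below.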

We can run the script using the following command.

python cam.py

Now, let’s take a look at the same four images.

Class activation maps of the wrong predictions.
Figure 8. Class activation maps of the wrong predictions.

In the first and second images, the model predicts the clipper as a grasper and the grasper as a scissor. Surely, all three have some similarities, and therefore, when the images are not that clear, it is wrongly predicting the tool. For the same reason, it may also be predicting the scissor as a clipper. What’s more interesting is that while making all these predictions, the model is directly focusing on the area around the tool.

But the model is also predicting the hook as a scissor while focusing on the bent part of the hook. If you observe, the scissors have a slight bend to their tooltips. This may be one of the reasons for the wrong predictions.

Summary and Conclusion

We covered the recognition of surgical tools in this article on a small dataset and using the MobileNetV3 Small model. We did not get state-of-the-art results and certainly did not become surgical experts. But we got to know a good deal about how even small models can perform pretty well given a proper dataset and decent training pipeline. We also saw where the models may fail while doing predictions. I hope that this article was useful to you.

If you have any doubts, thoughts, or suggestions, please leave them in the comment section. I will surely address them.

You can contact me using the Contact section. You can also find me on LinkedIn and Twitter.

