SSDLite MobileNetV3 Backbone Object Detection with PyTorch and Torchvision

In deep learning based object detection, faster inference usually comes at the cost of detection accuracy. But lately, things have been changing. We can now get really good FPS (Frames Per Second) and good detection accuracy at the same time, and object detection models with lighter backbones help us achieve this. We are going to see one such example in this post: SSDLite with a MobileNetV3 backbone for object detection using PyTorch and Torchvision.

Figure 1. An example of object detection using the SSDLite MobileNetV3 backbone.

In the previous post, we explored object detection using SSD300 with the VGG16 backbone. You may want to take a look at it for some background, as we are going to use almost the same coding style in this post, with only a few minor changes.

What will we cover in this post?

  • We will mainly focus on using a pre-trained SSDLite object detection model with the MobileNetV3 backbone to carry out object detection.
  • We will carry out inference on both images and videos and see how it performs.

A Brief About the SSDLite Model

By now, we know that we will be using a pre-trained model. It is already available as a part of the torchvision module in the PyTorch framework.

In fact, the complete name of the model is ssdlite320_mobilenet_v3_large. The 320 indicates that it internally resizes the inputs to 320×320 resolution, and it uses the MobileNetV3 Large model as the backbone. The model has been pre-trained on the MS COCO object detection dataset.

It is also good to know the input and output formats of the model while carrying out object detection inference. We discussed them in the previous post for the SSD300 model with the VGG16 backbone, and they are the same in this case as well. Please do take a look if you are interested in knowing more.
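As a quick reference, here is a minimal sketch of that format (the dummy input below is just for illustration; in eval mode, torchvision detection models take a list of image tensors and return one dictionary per image):

import torch
import torchvision

# load the pre-trained model and switch to eval mode
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
model.eval()

# a dummy 3-channel image tensor with values in [0, 1]
dummy = [torch.rand(3, 320, 320)]
with torch.no_grad():
    outputs = model(dummy)

# one dict per input image containing 'boxes', 'scores', and 'labels'
print(outputs[0].keys())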

The PyTorch Version

Now, this is a bit of an important part. To use the SSDLite with the MobileNetV3 backbone for object detection, you need to have at least PyTorch version 1.9.0 installed on your system.

It is not available in older versions. If you already have version 1.9.0 (or a higher one), then you are good to go. Otherwise, you need to install it before going further. You can get the command to install the latest version according to your requirements from the official installation page.
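If you are unsure which versions you have installed, a quick check (PyTorch 1.9.0 ships with torchvision 0.10.0, which is where this model was added):

import torch
import torchvision

# print the installed versions; you need at least 1.9.0 / 0.10.0
print(torch.__version__)
print(torchvision.__version__)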

The Directory Structure

Let’s follow a simple and clean directory structure for this mini-experiment of ours. Take a look at the following.

├── input
│   ├── image_1.jpg
│   ├── image_2.jpg
│   ├── video_1.mp4
│   └── video_2.mp4
├── outputs
│   └── image_1_05.jpg
│   ...
├── coco_names.py
├── detect_image.py
├── detect_utils.py
├── detect_video.py
└── model.py

  • We have an input folder containing 2 images and 2 videos. Out of these, we will carry out object detection inference on one image and two of the videos.
  • The outputs folder will contain the output images and videos after the inference is complete.
  • Then there are five Python files in which we will write the required code.

You can download the source code and input test data for this post by clicking on the button below.

Let’s start with the coding part of this post now.

SSDLite with MobileNetV3 Backbone for Object Detection using PyTorch and Torchvision

From here onward, we will focus on the coding part of the post.

The MS COCO Class Names

We need to map the detection labels to the MS COCO class names after we carry out object detection in an image or video frame. For this, we need the list of MS COCO class names.

For simplicity, we will create a Python file and keep all the names in a list, which we can then import into whichever script needs it. The following code will go into the coco_names.py file.

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

COCO_INSTANCE_CATEGORY_NAMES is a list containing all the class names, indexed by the label IDs that the model predicts.
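A label ID predicted by the model indexes directly into this list. For example:

from coco_names import COCO_INSTANCE_CATEGORY_NAMES as coco_names

# label 1 is 'person' and label 18 is 'dog' in the COCO numbering
print(coco_names[1])   # person
print(coco_names[18])  # dog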

The SSDLite MobileNetV3 Model

As we will be using the SSDLite with MobileNetV3 backbone for object detection in both images and videos, it is better to make it a reusable module. This makes our code much cleaner while reducing the lines of code as well.

The following code will go into the model.py file.

import torchvision

def get_model(device):
    # load the model 
    model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
    # load the model onto the computation device
    model = model.eval().to(device)
    return model

The get_model() function accepts the computation device as the input parameter. We load the model from the torchvision module, switch it to eval() mode, load it onto the computation device and return the model.

This is all we need for the model.py file.
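For reference, this is how the detection scripts later in the post will use it (a minimal usage sketch):

import torch
from model import get_model

# pick the computation device and load the model onto it
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(device)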

Helper Function and Utility Code

We need a few lines of utility code and helper functions to make our object detection work easier. Let's write those now. Almost all of the code here is similar to the previous post, so we will not dive much into the details. Please have a look at the previous post for more details.

The following code will go into the detect_utils.py file.

import torchvision.transforms as transforms
import cv2
import numpy as np
import torch

from coco_names import COCO_INSTANCE_CATEGORY_NAMES as coco_names

# this will help us create a different color for each class
COLORS = np.random.uniform(0, 255, size=(len(coco_names), 3))

# define the torchvision image transforms
transform = transforms.Compose([
    transforms.ToTensor(),
])

First, we are importing all the required modules and libraries along with the COCO_INSTANCE_CATEGORY_NAMES that we defined above.

Next, we define a COLORS array that holds a different color tuple for each of the COCO classes. We can use these to draw the bounding boxes and class text in a distinct color for each class while annotating the images with OpenCV.

Then we define a simple transform to convert the input data to tensor.

Helper Function for Forward Pass and Detection

The following is a simple function that forward propagates the input through the model and returns the relevant and required outputs.

def predict(image, model, device, detection_threshold):
    """
    Predict the output of an image after forward pass through
    the model and return the bounding boxes, class names, and 
    class labels. 
    """
    # transform the image to tensor
    image = transform(image).to(device)
    # add a batch dimension
    image = image.unsqueeze(0) 
    # get the predictions on the image
    with torch.no_grad():
        outputs = model(image) 

    # get score for all the predicted objects
    pred_scores = outputs[0]['scores'].detach().cpu().numpy()

    # get all the predicted bounding boxes
    pred_bboxes = outputs[0]['boxes'].detach().cpu().numpy()
    # get boxes above the threshold score
    boxes = pred_bboxes[pred_scores >= detection_threshold].astype(np.int32)
    # scores are sorted in descending order, so the first len(boxes)
    # labels correspond to the boxes that survived the threshold
    labels = outputs[0]['labels'][:len(boxes)]
    # get all the predicted class names
    pred_classes = [coco_names[i] for i in labels.cpu().numpy()]

    return boxes, pred_classes, labels

The predict() function accepts the image, model, device, and detection_threshold as input parameters. After the forward pass, we filter the predictions according to the threshold value. All detections with a confidence score below the given value are dropped.

Finally, we return boxes containing the bounding box coordinates, pred_classes containing the class names, and labels containing the class numbers.
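To see why the labels[:len(boxes)] slicing works, here is a toy illustration with made-up scores (torchvision returns detections sorted by score in descending order):

import numpy as np

# made-up scores, already sorted in descending order like torchvision's output
pred_scores = np.array([0.95, 0.80, 0.40, 0.10])
keep_mask = pred_scores >= 0.5
print(keep_mask)        # [ True  True False False]
# two boxes survive, so the first two labels are the ones we keep
print(keep_mask.sum())  # 2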

Helper Function to Draw the Bounding Boxes

We have another simple function to draw the bounding boxes around the objects and put the class name text on top of these boxes.

def draw_boxes(boxes, classes, labels, image):
    """
    Draws the bounding box around a detected object.
    """
    # convert to a NumPy array and swap the color channels for OpenCV
    image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
    for i, box in enumerate(boxes):
        color = COLORS[labels[i]]
        cv2.rectangle(
            image,
            (int(box[0]), int(box[1])),
            (int(box[2]), int(box[3])),
            color, 2
        )
        cv2.putText(image, classes[i], (int(box[0]), int(box[1]-5)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2, 
                    lineType=cv2.LINE_AA)
    return image

Each class will have a different color bounding box to easily distinguish them. The final image that we return is the fully annotated image with bounding boxes and class text.

As noted earlier, we did not dive into the line-by-line explanation of the code. With this, we finish the utility code and helper functions.

Object Detection using SSDLite MobileNetV3 in Images

In this section, we will use the pre-trained model to detect objects in images. All the code here will go into the detect_image.py script unless otherwise specified.

Let’s start with importing all the required modules and preparing the argument parser.

import torch
import argparse
import cv2
import detect_utils

from PIL import Image
from model import get_model

# construct the argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', default='input/image_1.jpg', 
                    help='path to the input image')
parser.add_argument('-t', '--threshold', default=0.5, type=float,
                    help='detection threshold')
args = vars(parser.parse_args())
  • We are importing our own detect_utils module for detection and annotation.
  • Along with that, we are also importing the get_model function from the model module.
  • For the argument parser, we have two flags:
    • --input: for the input image path.
    • --threshold: to specify the detection confidence threshold below which all the detections will be dropped.

The next block of code defines the computation device and the SSDLite model.

# define the computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(device)

Depending upon the hardware you have, all the detections will run either on the CPU or the GPU. Although a GPU is preferable for deep learning based object detection, you should be good to go even with a decent CPU, as we are using the lightweight SSDLite model with the MobileNetV3 backbone.

Read the Image, Preprocess, and Detect Objects

The final few steps:

  • Read the image from the disk.
  • Call the predict() function from the detect_utils module while providing the required arguments.
  • Draw the bounding boxes and class text on the image.

# read the image
image = Image.open(args['input'])
# detect outputs
boxes, classes, labels = detect_utils.predict(image, model, device, args['threshold'])
# draw bounding boxes
image = detect_utils.draw_boxes(boxes, classes, labels, image)
# e.g. input/image_1.jpg with a threshold of 0.5 gives save_name 'image_1_05'
save_name = f"{args['input'].split('/')[-1].split('.')[0]}_{''.join(str(args['threshold']).split('.'))}"
cv2.imshow('Image', image)
cv2.imwrite(f"outputs/{save_name}.jpg", image)
cv2.waitKey(0)

We carry out the detection with detect_utils.predict() and draw the bounding boxes with detect_utils.draw_boxes(). After that, we visualize the result on the screen and save it to disk using the save_name string that we define.

This completes the code for object detection in images.

Execute detect_image.py for Object Detection in Images

We will use one of the images from the input folder for object detection. Make sure that you are in the directory where all the Python scripts are present. Execute the following command in your command line/terminal.

python detect_image.py --input input/image_1.jpg -t 0.5 

Let’s see what output we are getting with a threshold of 0.5.

Figure 2. With a threshold of 0.5, the model is able to detect the person and the horse, but not the dog.

Okay, the model is able to detect the person and the horse alright. But it is not able to detect the dog. We need to keep in mind that models like SSDLite with the MobileNetV3 backbone tend to trade off accuracy for detection speed. So, maybe using a lower confidence threshold will help. In fact, from testing I found that the model detects every object in the image when the threshold is 0.1.

python detect_image.py --input input/image_1.jpg -t 0.1
Figure 3. With a threshold of 0.1, the model detects the dog, but wrongly labels it as a sheep.

But that detection is wrong. The model is labeling the obvious dog as a sheep. This is one of the drawbacks of faster models with lighter backbones: they tend to make mistakes that larger models don't.

Do try a few more images on your own and see what kind of results you are getting.

Object Detection using SSDLite MobileNetV3 in Videos

Now, we will write the code to detect objects in videos. Again, we will not go into much detail about the code as it is almost similar to what we covered in the previous post.

This code will go into the detect_video.py file.

Let’s import the modules, construct the argument parser, and define the computation device and model. This part is very similar to the object detection in images part.

import cv2
import torch
import argparse
import time
import detect_utils

from model import get_model

# construct the argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', default='input/video_1.mp4', 
                    help='path to input video')
parser.add_argument('-t', '--threshold', default=0.5, type=float,
                    help='detection threshold')
args = vars(parser.parse_args())

# define the computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(device)

The only difference is that for the input we will provide the path to a video file instead of an image.

Read the Video File

Here, we will read the video file and complete some other preliminary things like getting the video frames’ height and width.

cap = cv2.VideoCapture(args['input'])

if not cap.isOpened():
    print('Error while trying to read video. Please check path again')

# get the frame width and height
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

save_name = f"{args['input'].split('/')[-1].split('.')[0]}_{''.join(str(args['threshold']).split('.'))}"
# define codec and create VideoWriter object 
out = cv2.VideoWriter(f"outputs/{save_name}.mp4", 
                      cv2.VideoWriter_fourcc(*'mp4v'), 30, 
                      (frame_width, frame_height))

frame_count = 0 # to count total frames
total_fps = 0 # to get the final frames per second

We are also initializing the VideoWriter() object to save the resulting frames to disk. The frame_count and total_fps variables will keep track of the total number of frames and the cumulative FPS respectively.
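Note that we hardcode the output FPS to 30. If your source video plays at a different frame rate, an optional tweak (not part of the original script) is to read the FPS from the capture object instead:

# optional: use the source video's frame rate instead of hardcoding 30
src_fps = int(cap.get(cv2.CAP_PROP_FPS))
out = cv2.VideoWriter(f"outputs/{save_name}.mp4",
                      cv2.VideoWriter_fourcc(*'mp4v'), src_fps,
                      (frame_width, frame_height))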

Loop Through the Video Frames and Detect Objects in Each Frame

We will use a while loop to loop through the video frames and detect the objects in each frame.

# read until end of video
while cap.isOpened():
    # capture each frame of the video
    ret, frame = cap.read()
    if ret:
        # get the start time
        start_time = time.time()
        with torch.no_grad():
            # get predictions for the current frame
            boxes, classes, labels = detect_utils.predict(frame, model, device, args['threshold'])
        
        # draw boxes and show current frame on screen
        image = detect_utils.draw_boxes(boxes, classes, labels, frame)

        # get the end time
        end_time = time.time()
        # get the fps
        fps = 1 / (end_time - start_time)
        # add fps to total fps
        total_fps += fps
        # increment frame count
        frame_count += 1
        # write the FPS on the current frame
        cv2.putText(image, f"{fps:.3f} FPS", (15, 30), cv2.FONT_HERSHEY_SIMPLEX,
                    1, (0, 255, 0), 2)
        # draw_boxes() swapped the color channels, so swap back to BGR
        # for OpenCV display and saving
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        cv2.imshow('image', image)
        out.write(image)
        # press `q` to exit
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    else:
        break

# release VideoCapture()
cap.release()
# close all frames and video windows
cv2.destroyAllWindows()

# calculate and print the average FPS
avg_fps = total_fps / frame_count
print(f"Average FPS: {avg_fps:.3f}")

For each frame:

  • We are detecting the objects.
  • Drawing the bounding boxes around the detected objects.
  • Annotating the frame with the FPS text.
  • Visualizing it and saving it to disk.

Finally, we are releasing the VideoCapture() object and destroying all OpenCV windows.

We have completed the code for object detection in videos as well. Let’s execute the script and see what kind of results we get.

Execute detect_video.py for Object Detection in Videos

Note that all the detections were run on a laptop with 16 GB of RAM, an 8th generation i7 CPU, and a GTX 1060 GPU with 6 GB of memory. Your results will vary depending on the hardware that you use.

Let’s start with the first video in the input folder.

python detect_video.py --input input/video_1.mp4 -t 0.3
Average FPS: 22.220

We are carrying out detection with a threshold of 0.3 and getting an average FPS of 22.220. Not bad at all. But interestingly, we are getting a lower FPS than we did with the SSD300 model with the VGG16 backbone in the last post.

Clip 1. Object detection in video using the SSDLite model with MobileNetV3 backbone. The model is able to detect most of the persons and vehicles on the road correctly.

The detections are decent though. The model detects the person, the motorcycles, and the truck, although it struggles a bit with the persons far in the back.

Let’s try another video.

python detect_video.py --input input/video_2.mp4 -t 0.3
Average FPS: 21.201

The average FPS is 21.201. How about the detection?

Clip 2. Object detection in video using the SSDLite model with MobileNetV3 backbone. The model is suffering a bit while detecting the cars that are far away.

The model is able to detect the persons that are nearer to the camera. There are a few detections for the cars on the left side of the video, but the bounding boxes seem a bit off. Again, this is the trade-off between accuracy and speed.

A Few Takeaways

  • Object detection models like SSDLite with lighter backbones such as MobileNetV3 are great for deployment and edge devices.
  • But they are less accurate; they trade off accuracy for speed.
  • If your application can tolerate slightly less accurate predictions and speed matters more, then such models are a good choice.

Summary and Conclusion

In this post, we explored object detection using the SSDLite model with the MobileNetV3 backbone. We saw some of the advantages and drawbacks of such less compute-intensive models. I hope that you learned something new from this post.

If you have any doubts, thoughts, or suggestions, then please leave them in the comment section. I will surely address them.

You can contact me using the Contact section. You can also find me on LinkedIn and Twitter.
