In deep learning and object detection, faster inference often comes at a cost, and most of the time, that cost is detection accuracy. But lately, things have been changing. We can now get really good FPS (Frames Per Second) and detection accuracy at the same time. Object detection models with lighter backbones help us achieve this, and we are going to see one such example in this post. Here, we will be using SSDLite with a MobileNetV3 backbone for object detection using PyTorch and Torchvision.
In the previous post, we explored object detection using SSD300 with the VGG16 backbone. You may want to take a look at it for some background, as we are going to use almost the same coding style in this post, with only a few minor changes.
What will we cover in this post?
- We will mainly focus on using a pre-trained SSDLite object detection model with the MobileNetV3 backbone to carry out object detection.
- We will carry out inference on both images and videos and see how it performs.
A Brief About the SSDLite Model
By now, we know that we will be using a pre-trained model. It is already available as part of the torchvision module in the PyTorch framework. In fact, the complete name is ssdlite320_mobilenet_v3_large. The 320 indicates that the model internally resizes its inputs to 320×320 resolution, and it uses a MobileNetV3 Large backbone. The model has been pre-trained on the MS COCO object detection dataset.
It is also good to know the input and output formats of the model while carrying out object detection inference. We discussed these in the previous post for the SSD300 model with the VGG16 backbone, and they are the same in this case as well. So, please do take a look if you are interested in knowing more.
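As a quick recap, and as a hedged sketch of the torchvision detection API: in eval mode, the model takes a list of 3-channel float tensors with values in [0, 1] and returns one dictionary per image containing the boxes, labels, and scores. The random tensor below just stands in for a real image.

```python
import torch
import torchvision

# a minimal sketch of the input/output format, not part of the
# project code; the random tensor stands in for a real image
model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
model.eval()

x = [torch.rand(3, 480, 640)]  # list of C x H x W tensors, values in [0, 1]
with torch.no_grad():
    outputs = model(x)

print(outputs[0]['boxes'].shape)   # [N, 4] boxes as (x1, y1, x2, y2)
print(outputs[0]['labels'].shape)  # [N] COCO class indices
print(outputs[0]['scores'].shape)  # [N] scores, sorted in descending order
```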
The PyTorch Version
Now, this is a bit of an important part. To use SSDLite with the MobileNetV3 backbone for object detection, you need to have at least PyTorch version 1.9.0 installed on your system, along with the matching Torchvision 0.10.0, where this model was first added. It is not available in older versions. If you already have it (or even a higher version), then you are good to go. Else, you need to install it before going further. You can get the command to install the latest version according to your requirements from the official installation page.
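If you are not sure which versions you have, a quick check from Python will tell you:

```python
import torch
import torchvision

# the SSDLite model needs PyTorch >= 1.9.0 and Torchvision >= 0.10.0
print(torch.__version__)
print(torchvision.__version__)
```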
The Directory Structure
Let’s follow a simple and clean directory structure for this mini-experiment of ours. Take a look at the following.
```
├── input
│   ├── image_1.jpg
│   ├── image_2.jpg
│   ├── video_1.mp4
│   └── video_2.mp4
├── outputs
│   └── image_1_05.jpg
│   ...
├── coco_names.py
├── detect_image.py
├── detect_utils.py
├── detect_video.py
├── model.py
```
- We have an input folder containing 2 images and 2 videos. Out of these, we will carry out object detection inference on one image and the two videos.
- The outputs folder will contain the output images and videos after the inference is complete.
- Then there are five Python files in which we will write the required code.
You can download the source code and input test data for this post by clicking on the button below.
Let’s start with the coding part of this post now.
SSDLite with MobileNetV3 Backbone for Object Detection using PyTorch and Torchvision
From here onward, we will focus on the coding part of the post.
The MS COCO Class Names
We need to map the detection labels to the MS COCO class names after we carry out object detection in an image or video frame. For this, we need the list of MS COCO class names.
For simplicity, we will create a Python file and keep all the names in a list, which we can then import into whichever script we need. The following code will go into the coco_names.py file.
```python
COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane',
    'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
    'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
    'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
    'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
    'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
    'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator',
    'N/A', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier',
    'toothbrush'
]
```
The COCO_INSTANCE_CATEGORY_NAMES is a list containing all the class names.
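The labels that the model predicts are simply indices into this list, so mapping a label to a class name is a plain list lookup. A quick sanity check (using indices I picked from the list above):

```python
from coco_names import COCO_INSTANCE_CATEGORY_NAMES

# index 0 is reserved for the background class; detection labels
# index directly into the list
print(COCO_INSTANCE_CATEGORY_NAMES[1])   # person
print(COCO_INSTANCE_CATEGORY_NAMES[18])  # dog
```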
The SSDLite MobileNetV3 Model
As we will be using SSDLite with the MobileNetV3 backbone for object detection in both images and videos, it is better to make the model preparation a reusable module. This keeps our code much cleaner and reduces the number of lines as well.
The following code will go into the model.py file.
```python
import torchvision

def get_model(device):
    # load the model
    model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(pretrained=True)
    # load the model onto the computation device
    model = model.eval().to(device)
    return model
```
The get_model() function accepts the computation device as the input parameter. We load the model from the torchvision module, switch it to eval() mode, load it onto the computation device, and return the model.
This is all we need for the model.py file.
Helper Function and Utility Code
We need a few lines of utility code and helper functions to make our object detection work easier. Let’s write those now. Almost all of the code here is similar to the previous post, so we will not dive much into the details. Please have a look at the previous post for more details.
The following code will go into the detect_utils.py file.
```python
import torchvision.transforms as transforms
import cv2
import numpy as np
import torch

from coco_names import COCO_INSTANCE_CATEGORY_NAMES as coco_names

# this will help us create a different color for each class
COLORS = np.random.uniform(0, 255, size=(len(coco_names), 3))

# define the torchvision image transforms
transform = transforms.Compose([
    transforms.ToTensor(),
])
```
First, we are importing all the required modules and libraries, along with the COCO_INSTANCE_CATEGORY_NAMES that we defined above.
On line 9, we are defining a COLORS array which holds a different color tuple for each of the COCO classes. We can use these to add different colors to the bounding boxes and text of each class while annotating the images with OpenCV.
Then we define a simple transform to convert the input data to a tensor.
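If you are wondering what the transform produces: ToTensor() converts a PIL image or NumPy array (H x W x C, uint8, values 0 to 255) into a float32 tensor (C x H x W, values 0.0 to 1.0), which is the input format the model expects. A small standalone sketch to verify this on one of our input images:

```python
from PIL import Image
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.ToTensor()])

# PIL image -> C x H x W float32 tensor with values in [0.0, 1.0]
image = Image.open('input/image_1.jpg')
tensor = transform(image)
print(tensor.shape, tensor.dtype, tensor.min().item(), tensor.max().item())
```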
Helper Function for Forward Pass and Detection
The following is a simple function that forward propagates the input through the model and returns the relevant outputs.
```python
def predict(image, model, device, detection_threshold):
    """
    Predict the output of an image after forward pass through
    the model and return the bounding boxes, class names, and
    class labels.
    """
    # transform the image to tensor
    image = transform(image).to(device)
    # add a batch dimension
    image = image.unsqueeze(0)
    # get the predictions on the image
    with torch.no_grad():
        outputs = model(image)
    # get score for all the predicted objects
    pred_scores = outputs[0]['scores'].detach().cpu().numpy()
    # get all the predicted bounding boxes
    pred_bboxes = outputs[0]['boxes'].detach().cpu().numpy()
    # get boxes above the threshold score
    boxes = pred_bboxes[pred_scores >= detection_threshold].astype(np.int32)
    labels = outputs[0]['labels'][:len(boxes)]
    # get all the predicted class names
    pred_classes = [coco_names[i] for i in labels.cpu().numpy()]
    return boxes, pred_classes, labels
```
The predict() function accepts the image, model, device, and detection_threshold as input parameters. After the forward pass on line 27, we filter the predictions according to the threshold value. All detections with a confidence score below the given value are dropped.
The boxes containing the bounding box coordinates, pred_classes containing the class names, and labels containing the class numbers are returned on line 40.
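One subtle point: the labels slicing in the function (labels[:len(boxes)]) works because torchvision returns detections sorted by descending score, so the first len(boxes) labels line up with the boxes that survived the threshold. If you would rather not rely on that ordering, a hedged drop-in variation (my own, not the original post's code) is to reuse the same boolean mask:

```python
# inside predict(): filter boxes and labels with one boolean mask
# instead of relying on the score-sorted ordering of the outputs
mask = pred_scores >= detection_threshold
boxes = pred_bboxes[mask].astype(np.int32)
labels = outputs[0]['labels'].cpu().numpy()[mask]
pred_classes = [coco_names[i] for i in labels]
```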
Helper Function to Draw the Bounding Boxes
We have another simple function to draw the bounding boxes around the objects and put the class name text on top of these boxes.
```python
def draw_boxes(boxes, classes, labels, image):
    """
    Draws the bounding box around a detected object.
    """
    # convert to a NumPy array and swap the R and B color
    # channels for OpenCV annotation
    image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)
    for i, box in enumerate(boxes):
        color = COLORS[labels[i]]
        cv2.rectangle(
            image,
            (int(box[0]), int(box[1])),
            (int(box[2]), int(box[3])),
            color, 2
        )
        cv2.putText(image, classes[i], (int(box[0]), int(box[1] - 5)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2,
                    lineType=cv2.LINE_AA)
    return image
```
Each class gets a bounding box of a different color so it is easy to distinguish them. The final image that we return on line 57 is the fully annotated image with bounding boxes and class name text.
As noted earlier, we did not dive into a line-by-line explanation of the code. With this, we finish the utility code and helper functions.
Object Detection using SSDLite MobileNetV3 in Images
In this section, we will use the pre-trained model to detect objects in images. All the code here will go into the detect_image.py script unless otherwise specified.
Let’s start with importing all the required modules and preparing the argument parser.
```python
import torch
import argparse
import cv2
import detect_utils

from PIL import Image
from model import get_model

# construct the argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', default='input/image_1.jpg',
                    help='path to input image')
parser.add_argument('-t', '--threshold', default=0.5, type=float,
                    help='detection threshold')
args = vars(parser.parse_args())
```
- We are importing our own detect_utils module for detection and annotation.
- Along with that, we are also importing the get_model function from the model module.
- For the argument parser, we have two flags:
  - --input: the path to the input image.
  - --threshold: the detection confidence threshold below which all detections will be dropped.
The next block of code defines the computation device and the SSDLite model.
```python
# define the computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(device)
```
Depending upon the hardware you have, all the detections will run either on the CPU or the GPU. Although a GPU is preferable for deep learning based object detection, you should be good to go even with a decent CPU, as we are using the SSDLite model with the MobileNetV3 backbone.
Read the Image, Preprocess, and Detect Objects
The final few steps:
- Read the image from the disk.
- Call the predict() function from the detect_utils module while providing the required arguments.
- Draw the bounding boxes and class text on the image.
```python
# read the image
image = Image.open(args['input'])
# detect outputs
boxes, classes, labels = detect_utils.predict(image, model, device, args['threshold'])
# draw bounding boxes
image = detect_utils.draw_boxes(boxes, classes, labels, image)
save_name = f"{args['input'].split('/')[-1].split('.')[0]}_{''.join(str(args['threshold']).split('.'))}"
cv2.imshow('Image', image)
cv2.imwrite(f"outputs/{save_name}.jpg", image)
cv2.waitKey(0)
```
We are carrying out the detection on line 22 and drawing the bounding boxes on line 24. After that, we visualize the result on the screen and save it to disk using the save_name string that we define.
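To make the save_name logic concrete, here is a quick sketch of what it produces for the default arguments: the file stem, an underscore, and the threshold with its decimal point removed.

```python
# hypothetical example values mirroring the default arguments
input_path = 'input/image_1.jpg'
threshold = 0.5

save_name = f"{input_path.split('/')[-1].split('.')[0]}_{''.join(str(threshold).split('.'))}"
print(save_name)  # image_1_05 -> saved as outputs/image_1_05.jpg
```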
This completes the code for object detection in images.
Execute detect_image.py for Object Detection in Images
We will use one of the images from the input folder for object detection. Make sure that you are in the directory where all the Python scripts are present. Then execute the following command in your command line/terminal.
python detect_image.py --input input/image_1.jpg -t 0.5
Let’s see what output we are getting with a threshold of 0.5.
Okay, the model is able to detect the person and the horse alright. But it is not able to detect the dog. Now, we need to keep in mind that models like SSDLite with the MobileNetV3 backbone tend to trade off accuracy for detection speed. So, maybe using a lower confidence threshold will help. In fact, from testing, I found that the model is able to detect every object in the image when the threshold is 0.1.
python detect_image.py --input input/image_1.jpg -t 0.1
But that too comes with a wrong detection. The model detects the obvious dog as a sheep. This is one of the drawbacks of faster models with lighter backbones: they tend to make mistakes that larger models don’t.
Do try a few more images on your own and see what kind of results you are getting.
Object Detection using SSDLite MobileNetV3 in Videos
Now, we will write the code to detect objects in videos. Again, we will not go into much detail about the code, as it is very similar to what we covered in the previous post.
This code will go into the detect_video.py file.
Let’s import the modules, construct the argument parser, and define the computation device and model. This part is very similar to the object detection in images part.
```python
import cv2
import torch
import argparse
import time
import detect_utils

from model import get_model

# construct the argument parser
parser = argparse.ArgumentParser()
parser.add_argument('-i', '--input', default='input/video_1.mp4',
                    help='path to input video')
parser.add_argument('-t', '--threshold', default=0.5, type=float,
                    help='detection threshold')
args = vars(parser.parse_args())

# define the computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(device)
```
The only difference is that for the input we will provide the path to a video file instead of an image.
Read the Video File
Here, we will read the video file and complete some other preliminary things like getting the video frames’ height and width.
```python
cap = cv2.VideoCapture(args['input'])

if (cap.isOpened() == False):
    print('Error while trying to read video. Please check path again')

# get the frame width and height
frame_width = int(cap.get(3))
frame_height = int(cap.get(4))

save_name = f"{args['input'].split('/')[-1].split('.')[0]}_{''.join(str(args['threshold']).split('.'))}"

# define codec and create VideoWriter object
out = cv2.VideoWriter(f"outputs/{save_name}.mp4",
                      cv2.VideoWriter_fourcc(*'mp4v'), 30,
                      (frame_width, frame_height))

frame_count = 0  # to count total frames
total_fps = 0  # to get the final frames per second
```
We are also initializing the VideoWriter() object to save the resulting frames to disk. The frame_count and total_fps variables will keep track of the total number of frames and the cumulative FPS, respectively.
Loop Through the Video Frames and Detect Objects in Each Frame
We will use a while loop to iterate over the video frames and detect the objects in each frame.
```python
# read until end of video
while(cap.isOpened()):
    # capture each frame of the video
    ret, frame = cap.read()
    if ret == True:
        # get the start time
        start_time = time.time()
        with torch.no_grad():
            # get predictions for the current frame
            boxes, classes, labels = detect_utils.predict(frame, model, device, args['threshold'])
        # draw boxes and show current frame on screen
        image = detect_utils.draw_boxes(boxes, classes, labels, frame)
        # get the end time
        end_time = time.time()
        # get the fps
        fps = 1 / (end_time - start_time)
        # add fps to total fps
        total_fps += fps
        # increment frame count
        frame_count += 1
        # write the FPS on the current frame
        cv2.putText(image, f"{fps:.3f} FPS", (15, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
        # convert from BGR to RGB color format
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        cv2.imshow('image', image)
        out.write(image)
        # press `q` to exit
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
    else:
        break

# release VideoCapture()
cap.release()
# close all frames and video windows
cv2.destroyAllWindows()

# calculate and print the average FPS
avg_fps = total_fps / frame_count
print(f"Average FPS: {avg_fps:.3f}")
```
For each frame:
- We detect the objects.
- We draw the bounding boxes around the detected objects.
- We annotate the frame with the FPS text.
- We visualize the frame and save it to disk.
Finally, we are releasing the VideoCapture() object and destroying all OpenCV windows.
We have completed the code for object detection in videos as well. Let’s execute the script and see what kind of results we get.
Execute detect_video.py for Object Detection in Videos
Note that all the detections are happening on a laptop with 16 GB of RAM, an 8th generation i7 CPU, and a 6 GB GTX 1060 GPU. Your results will vary depending on the hardware that you use.
Let’s start with the first video in the input folder.
python detect_video.py --input input/video_1.mp4 -t 0.3
Average FPS: 22.220
We are carrying out detection with a threshold of 0.3 and getting an average FPS of 22.220. Not that bad. But interestingly, we are getting lower FPS than the SSD300 model with the VGG16 backbone that we experimented with in the last post.
The detections are decent though. The model picks up the person, the motorcycles, and the truck, but it struggles a bit to detect the persons at the far back.
Let’s try another video.
python detect_video.py --input input/video_2.mp4 -t 0.3
Average FPS: 21.201
The average FPS is 21.201. How about the detection?
The model is able to detect the persons who are nearer to the camera. There are a few detections for the cars on the left side of the video, but the bounding boxes seem to be a bit off. Again, this is the trade-off between accuracy and speed.
A Few Takeaways
- Object detection models like SSDLite with lighter backbones such as MobileNetV3 are great for deployment and edge devices.
- But they are less accurate. They trade off accuracy for speed.
- If your application can tolerate somewhat less accurate predictions and speed is of more importance, then surely go with such models.
Summary and Conclusion
In this post, we explored object detection using the SSDLite model with the MobileNetV3 backbone. We saw some of the advantages and drawbacks of such less compute-intensive models. I hope that you learned something new from this post.
If you have any doubts, thoughts, or suggestions, then please leave them in the comment section. I will surely address them.
You can contact me using the Contact section. You can also find me on LinkedIn, and Twitter.