Background Replacement Using BiRefNet



In the last article, we covered the introduction to BiRefNet along with a small background removal demo. We discussed how BiRefNet’s unique architecture positions it to better segment the foreground from the background. In this article, we will take a step further. We will create a simple background replacement application using BiRefNet.

Figure 1. Background replacement using BiRefNet – Gradio app demo.

With BiRefNet’s strong dichotomous segmentation properties and adding a bit of image processing, we can achieve excellent results for background replacement. A small demo is shown in the above GIF.

What will we cover in this article?

  • Setting up the BiRefNet codebase along with the new background replacement code files.
  • Creating a simple Jupyter Notebook for background replacement.
  • Integrating the above code to create a Gradio application.
  • Testing with images to check how well our background replacement using BiRefNet works.

Setting Up the Codebase for BiRefNet Background Replacement

Let’s set up the codebase first. This is going to be similar to the last article, where we covered the introduction to BiRefNet and used the model for background removal in images.

First, we will clone the official BiRefNet repository.

git clone https://github.com/ZhengPeng7/BiRefNet.git

Second, we will add our directories for images, backgrounds, a Jupyter Notebook, and a Gradio application for background replacement.
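Assuming we are inside the cloned BiRefNet/ directory, the new data directories can be created with a quick sketch like the following (the notebook and Gradio files from the download section then go alongside them):

```shell
# Run from inside the cloned BiRefNet/ directory (assumption).
mkdir -p images backgrounds
```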

Final Project Directory Structure

The final project directory structure looks like the following:

BiRefNet/
├── backgrounds
│   ├── bg_1.jpg
│   └── bg_2.jpg
├── evaluation
│   └── metrics.py
├── images
│   ├── image_1.jpg
│   ├── image_2.jpg
│   └── image_3.jpg
├── models
│   ├── backbones
│   ├── modules
│   ├── refinement
│   └── birefnet.py
├── predictions
│   ├── image_1-mask.png
│   ...
│   └── image_3-subject.png
├── tutorials
│   ├── birefnet_background_removal.ipynb
│   ...
│   └── BiRefNet_pth2onnx.ipynb
├── weights
│   └── BiRefNet-general-epoch_244.pth
├── birefnet_background_replacement.ipynb
├── config.py
├── dataset.py
├── eval_existingOnes.py
├── gen_best_ep.py
├── gradio_app.py
├── image_proc.py
├── inference.py
├── LICENSE
├── loss.py
├── make_a_copy.sh
├── README.md
├── requirements.txt
├── rm_cache.sh
├── sub.sh
├── test.sh
├── train.py
├── train.sh
├── train_test.sh
└── utils.py

In the above tree structure, BiRefNet (the top directory) is the cloned repository.

  • The images and backgrounds directories contain the data that we will use for inference and background replacement.
  • birefnet_background_replacement.ipynb is a Jupyter Notebook containing the code for background replacement.
  • And gradio_app.py is the Gradio script that adds a UI on top of this code for a user-friendly experience.

The images and backgrounds directories, along with the Jupyter Notebook and the Gradio app file, are provided in the download section. Please transfer them to the cloned BiRefNet directory after extracting the contents.

Apart from these, all other files and directories are part of the official repository.

Download Code

Additional Setup Steps

Before moving forward with the code, ensure that you create a weights directory inside the cloned folder. Next, download the pretrained BiRefNet weights from the following link and keep it in the weights directory.

Link to download pretrained BiRefNet weights.

Installing the Requirements

Next, install the requirements.

pip install -r requirements.txt

We are done with all the necessary setup.

Using BiRefNet for Background Replacement

We will start with the Jupyter Notebook code first. Next, we will integrate the code into a Gradio application for background replacement.

Jupyter Notebook Code

All the code that we will discuss here is present in the birefnet_background_replacement.ipynb notebook. We will mostly discuss the new parts of the code in detail while skipping what was covered in the last article. The code up to the background removal step remains mostly unchanged. The only new introduction is the background replacement code after we obtain the segmented foreground object.

The Import Statements

Let’s start with the import statements.

import torch
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np

from PIL import Image
from torchvision import transforms
from IPython.display import display
from models.birefnet import BiRefNet
from utils import check_state_dict
from glob import glob
from image_proc import refine_foreground

We import all the necessary modules from the BiRefNet codebase, and those required for image processing as well.

Setting Up the Device, Loading the Model, and Transforms

Next, we define the computation device, load the model, and define the necessary transforms.

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# device = 'cpu'

def load_model():
    birefnet = BiRefNet(bb_pretrained=False)
    state_dict = torch.load('weights/BiRefNet-general-epoch_244.pth', map_location=device)
    state_dict = check_state_dict(state_dict)
    birefnet.load_state_dict(state_dict)
    
    # use TF32 matmul precision where available for faster inference
    torch.set_float32_matmul_precision(['high', 'highest'][0])
    
    birefnet.to(device)
    birefnet.eval()
    # half precision roughly halves VRAM usage (assumes a CUDA device)
    birefnet.half()
    print('BiRefNet is ready to use.')

    return birefnet

# Choose from: 'BiRefNet', 'BiRefNet_HR', 'BiRefNet_HR-matting'
model_name = 'BiRefNet_HR'
model = load_model()

transform_image = transforms.Compose([
    transforms.Resize((1024, 1024) if '_HR' not in model_name else (2048, 2048)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

Ensure that you have at least 10GB of VRAM to run these experiments. As we are choosing high-resolution transforms via BiRefNet_HR, this will resize the images to 2048×2048 resolution. Although this gives high-resolution segmentation maps, it also needs more VRAM compared to choosing the BiRefNet option.
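As a rough sanity check, the high-resolution input tensor alone is 4× larger than the 1024×1024 one. Note that this is strictly a lower bound; the model's intermediate activations dominate actual VRAM usage:

```python
# Rough size of a single fp16 input tensor. Actual VRAM usage is far
# higher because of the model's intermediate activations.
def input_tensor_mb(side, channels=3, bytes_per_elem=2):  # fp16 = 2 bytes
    return channels * side * side * bytes_per_elem / 1024 ** 2

print(input_tensor_mb(1024))  # 6.0 MB for the 'BiRefNet' transform
print(input_tensor_mb(2048))  # 24.0 MB for the '_HR' transform (4x larger)
```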

Function To Remove Background

Next, we have a function that, given an image path, will remove its background using the BiRefNet model.

def run_inference(image_path):
    image = Image.open(image_path)
    input_images = transform_image(image).unsqueeze(0).to(device)
    input_images = input_images.half()

    # Prediction
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Save Results
    file_ext = os.path.splitext(image_path)[-1]
    pred_pil = transforms.ToPILImage()(pred)
    pred_pil = pred_pil.resize(image.size)
    pred_pil.save(image_path.replace(src_dir, dst_dir).replace(file_ext, '-mask.png'))
    image_masked = refine_foreground(image, pred_pil)
    image_masked.putalpha(pred_pil)
    image_masked.save(image_path.replace(src_dir, dst_dir).replace(file_ext, '-subject.png'))

    # Visualize the last sample:
    # Scale proportionally with max length to 1024 for faster showing
    scale_ratio = 1024 / max(image.size)
    scaled_size = (int(image.size[0] * scale_ratio), int(image.size[1] * scale_ratio))
    
    plt.figure(figsize=(15, 12))
    plt.subplot(1, 3, 1)
    plt.imshow(image)
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.imshow(image_masked)
    plt.axis('off')
    plt.subplot(1, 3, 3)
    plt.imshow(pred_pil, cmap='gray')
    plt.axis('off')

    plt.show()

It saves the segmented foreground alpha channel images and the segmentation masks in the predictions directory.
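To make the saving logic concrete, here is the filename mapping that run_inference performs, isolated as a small sketch (because dst_dir ends with a slash, the resulting paths contain a harmless double slash):

```python
import os

# Mirrors the globals used by run_inference.
src_dir = 'images'
dst_dir = 'predictions/'

def output_paths(image_path):
    # Mirrors the replace-based path logic inside run_inference.
    file_ext = os.path.splitext(image_path)[-1]
    mask_path = image_path.replace(src_dir, dst_dir).replace(file_ext, '-mask.png')
    subject_path = image_path.replace(src_dir, dst_dir).replace(file_ext, '-subject.png')
    return mask_path, subject_path

print(output_paths('images/image_1.jpg'))
# ('predictions//image_1-mask.png', 'predictions//image_1-subject.png')
```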

Let’s create the predictions directory and loop over the images in the images directory.

src_dir = 'images'
image_paths = glob(os.path.join(src_dir, '*'))
dst_dir = 'predictions/'
os.makedirs(dst_dir, exist_ok=True)

for image_path in image_paths:
    print('Processing {} ...'.format(image_path))
    run_inference(image_path)

The following is a result of the above visualization.

Figure 2. Examples after the background removal step using BiRefNet – (from left to right) original image, alpha channel foreground image, and binary mask image.

Function to Replace Background

Next, we have the function to replace the background in images.

def apply_new_background(subject_binary, foreground, background):
    """
    :param subject_binary: binary mask containing the foreground pixels
    :param foreground: the extracted alpha channel foreground image
    :param background: the new background image to composite in
    """
    # normalize the mask, keeping values between 0 and 1
    subject_binary = subject_binary / 255.0
    # keep only the subject pixels in the foreground image
    foreground = cv2.multiply(subject_binary.astype(np.float32), foreground.astype(np.float32))
    # resize the background to match the foreground image
    background = cv2.resize(background, (foreground.shape[1], foreground.shape[0]))
    background = background.astype(np.float32)
    # zero out the subject region in the new background
    background = cv2.multiply(1.0 - subject_binary.astype(np.float32), background.astype(np.float32))
    # add the foreground and the new background image
    new_image = cv2.add(foreground.astype(np.float32), background.astype(np.float32))
    return new_image / 255.

The apply_new_background function accepts the binary mask image, the alpha channel foreground image, and the background that we want to add as parameters.

After doing all the processing, it returns the resulting image with the new background.
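At its core, the function implements the standard alpha compositing formula, composite = foreground × mask + background × (1 − mask). A tiny NumPy-only sketch of the same arithmetic:

```python
import numpy as np

# 2x2 toy example: mask is 1 (subject) in the left column, 0 in the right.
mask = np.array([[1.0, 0.0],
                 [1.0, 0.0]], dtype=np.float32)
fg = np.full((2, 2), 200.0, dtype=np.float32)  # subject pixel intensities
bg = np.full((2, 2), 50.0, dtype=np.float32)   # new background intensities

composite = fg * mask + bg * (1.0 - mask)
print(composite)
# [[200.  50.]
#  [200.  50.]] -- subject kept on the left, new background on the right
```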

Let’s read the respective image and call the above function.

subject = cv2.imread('predictions/image_3-subject.png')
subject = cv2.cvtColor(subject, cv2.COLOR_BGR2RGB)
print(subject.shape)

subject_binary = cv2.imread('predictions/image_3-mask.png')

bg = cv2.imread('backgrounds/bg_2.jpg')
bg = cv2.cvtColor(bg, cv2.COLOR_BGR2RGB)
print(bg.shape)
result = apply_new_background(subject_binary, subject, bg)

plt.figure(figsize=(15, 12))
plt.imshow(result)
plt.axis('off')
plt.show()

We call the function using the segmented jellyfish images. Let’s take a look at the results.

Figure 3. Result of background replacement using BiRefNet.

Looks like our function is working well. Most of the credit goes to the BiRefNet model for producing such high-quality foreground and binary masks.

Gradio Application for Background Replacement using BiRefNet

Now, we will jump into the code to create a Gradio interface for the above background replacement application.

The code for this is present in the gradio_app.py file. It will be a slightly modified version of the code that we saw above.

Importing the Modules

Starting with importing the necessary modules.

import torch
import os
import cv2
import gradio as gr
import numpy as np
from PIL import Image
from torchvision import transforms

from models.birefnet import BiRefNet
from utils import check_state_dict

Function to Refine the Foreground and Apply a New Background

We have one function to refine the foreground and add an alpha channel to it, and another to apply a new background to the segmented image.

# Helper Functions.
def refine_foreground(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Applies the mask to the image to get the subject with a transparent background."""
    mask = mask.convert('L')
    image_rgba = image.convert('RGBA')
    image_rgba.putalpha(mask)
    return image_rgba

def apply_new_background(
    original_image: Image.Image, 
    mask_image: Image.Image, 
    background_image: Image.Image
) -> Image.Image:
    """
    Pastes the subject from the original image onto a new background using a mask.
    This function is based on the provided logic using OpenCV.
    """
    # Convert PIL Images to NumPy arrays
    foreground_np = np.array(original_image.convert('RGB'))
    background_np = np.array(background_image.convert('RGB'))
    mask_np = np.array(mask_image.convert('L'))

    # Resize background to match the original image's dimensions
    bg_h, bg_w, _ = background_np.shape
    fg_h, fg_w, _ = foreground_np.shape
    if bg_h != fg_h or bg_w != fg_w:
        print(f"Resizing background from {bg_w}x{bg_h} to {fg_w}x{fg_h}")
        background_np = cv2.resize(background_np, (fg_w, fg_h), interpolation=cv2.INTER_AREA)

    # Normalize mask to be in the [0, 1] range
    mask_normalized = mask_np / 255.0

    # Expand mask to 3 channels for multiplication
    mask_3d = np.stack([mask_normalized] * 3, axis=-1)

    # Blend the foreground and background
    # foreground * mask + background * (1 - mask)
    foreground_part = cv2.multiply(mask_3d.astype(np.float32), foreground_np.astype(np.float32))
    background_part = cv2.multiply((1.0 - mask_3d).astype(np.float32), background_np.astype(np.float32))
    
    # Combine parts and convert back to uint8
    composite_np = cv2.add(foreground_part, background_part)
    composite_np = np.clip(composite_np, 0, 255).astype(np.uint8)

    # Convert final NumPy array back to a PIL Image
    return Image.fromarray(composite_np)

Define Computation Device and Load the Model

Next, define the CUDA device and load the BiRefNet model for background replacement.

# Global Setup (Load Model Once).
print("Setting up device...")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

def load_model():
    """Loads the BiRefNet model and sets it to evaluation mode."""
    print("Loading BiRefNet model...")
    model = BiRefNet(bb_pretrained=False)
    
    model_path = 'weights/BiRefNet-general-epoch_244.pth'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model weights not found at: {model_path}. Please download them.")
        
    state_dict = torch.load(model_path, map_location=device)
    state_dict = check_state_dict(state_dict)
    model.load_state_dict(state_dict)
    
    torch.set_float32_matmul_precision(['high', 'highest'][0])
    
    model.to(device)
    model.eval()
    
    if device.type == 'cuda':
        model.half()
        
    print("BiRefNet model is ready.")
    return model

model = load_model()
# Choose from: 'BiRefNet', 'BiRefNet_HR', 'BiRefNet_HR-matting'
model_name = 'BiRefNet_HR'

transform_image = transforms.Compose([
    transforms.Resize((1024, 1024) if '_HR' not in model_name else (2048, 2048)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

We also define the model transforms along with that.
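For reference, Normalize applies (x − mean) / std independently per channel after ToTensor has scaled pixels to [0, 1]. The same arithmetic in plain NumPy:

```python
import numpy as np

# ImageNet statistics used by the transform above.
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

# Dummy 3-channel 2x2 "image" already scaled to [0, 1] like ToTensor output.
img = np.full((3, 2, 2), 0.5)
normalized = (img - mean) / std
print(normalized[:, 0, 0])  # per-channel values of one pixel
```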

Inference Function to Create the Segmented and Final Image

The following process_image function creates the binary mask, the foreground image, and the image with the changed background.

# The Core Inference Function.
def process_image(input_pil_image: Image.Image, background_pil_image: Image.Image = None):
    """
    Takes a user-uploaded image, runs inference, and optionally replaces the background.
    """
    if input_pil_image is None:
        return None, None, None
        
    print("Processing new image...")
    original_size = input_pil_image.size
    
    # Prepare the image for the model
    input_tensor = transform_image(input_pil_image).unsqueeze(0).to(device)
    
    if device.type == 'cuda':
        input_tensor = input_tensor.half()

    # Run Prediction
    with torch.no_grad():
        preds = model(input_tensor)[-1].sigmoid()

    # Process prediction back to a PIL Image
    pred_tensor = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred_tensor.cpu())
    
    # Resize mask to original image size
    mask_pil = pred_pil.resize(original_size, Image.Resampling.LANCZOS)
    
    # Get the subject with a transparent background
    subject_pil = refine_foreground(input_pil_image, mask_pil)
    
    # Handle background replacement
    if background_pil_image:
        print("Applying new background...")
        composite_image = apply_new_background(input_pil_image, mask_pil, background_pil_image)
    else:
        # If no background is provided, the composite result is just the transparent subject
        composite_image = subject_pil

    print("Processing complete.")
    return composite_image, subject_pil, mask_pil

Creating the Gradio UI and Running the App

Finally, we create the Gradio UI and run the application.

# Gradio UI Definition.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange")) as demo:
    gr.Markdown(
        """
        # ✂️ BiRefNet: Background Removal & Replacement
        Upload an image to isolate the main subject. You can also provide a new background to create a composite image.
        This demo uses the BiRefNet model for high-resolution segmentation.
        """
    )
    
    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        background_image = gr.Image(type="pil", label="New Background (Optional)")
    
    process_btn = gr.Button("Process Image", variant="primary")
    
    with gr.Tabs():
        with gr.TabItem("Final Result"):
            output_composite = gr.Image(type="pil", label="Image with New Background")
        with gr.TabItem("Transparent Subject"):
            output_subject = gr.Image(type="pil", label="Subject (Transparent PNG)", format="png")
        with gr.TabItem("Generated Mask"):
             output_mask = gr.Image(type="pil", label="Generated Mask")
        
    gr.Examples(
        examples=[
            [os.path.join(os.path.dirname(__file__), "images/image_1.jpg")],
            [os.path.join(os.path.dirname(__file__), "images/image_2.jpg")],
            [os.path.join(os.path.dirname(__file__), "images/image_3.jpg")],
        ],
        inputs=input_image,
        outputs=[output_composite, output_subject, output_mask],
        fn=process_image,
        cache_examples=True,
    )

    process_btn.click(
        fn=process_image, 
        inputs=[input_image, background_image], 
        outputs=[output_composite, output_subject, output_mask]
    )

# Launch the App.
if __name__ == "__main__":
    demo.launch(share=True)

We can execute the following command to start the application.

python gradio_app.py

This is what the UI looks like when we run the application for one image.

Figure 4. Resulting image with Gradio UI after running the background replacement process once.

We can upload an image and a background of our choice and run the application to get the final image. The UI also provides separate tabs for visualizing the transparent foreground and the binary mask.

The following is a video showing one full run through the application.

Video 1. Video showing complete workflow of background replacement using BiRefNet with Gradio UI.

You can upload the images of your choice and try out how it works for complex scenarios.

Summary and Conclusion

In this article, we created a simple application for background replacement using the BiRefNet model. Such models open up possibilities for real-world applications in image and video editing apps. We may tackle such an application in a future article.

If you have any questions, thoughts, or suggestions, please leave them in the comment section. I will surely address them.

You can contact me using the Contact section. You can also find me on LinkedIn, and X.

