In the last article, we introduced BiRefNet and ran a small background removal demo. We discussed how BiRefNet’s unique architecture lets it segment the foreground from the background more accurately. In this article, we will take it a step further and create a simple background replacement application using BiRefNet.
By combining BiRefNet’s strong dichotomous segmentation with a bit of image processing, we can achieve excellent results for background replacement. The GIF above shows a small demo.
What will we cover in this article?
- Setting up the BiRefNet codebase along with the new background replacement code files.
- Creating a simple Jupyter Notebook for background replacement.
- Integrating the above code to create a Gradio application.
- Testing with images to check how well our background replacement using BiRefNet works.
Setting Up the Codebase for BiRefNet Background Replacement
Let’s set up the codebase first. This is going to be similar to the last article, where we covered the introduction to BiRefNet and used the model for background removal in images.
First, we will clone the official BiRefNet repository.
git clone https://github.com/ZhengPeng7/BiRefNet.git
Second, we will add the images and backgrounds directories, a Jupyter Notebook, and a Gradio application file for background replacement.
Final Project Directory Structure
The final project directory structure looks like the following:
BiRefNet/
├── backgrounds
│   ├── bg_1.jpg
│   └── bg_2.jpg
├── evaluation
│   └── metrics.py
├── images
│   ├── image_1.jpg
│   ├── image_2.jpg
│   └── image_3.jpg
├── models
│   ├── backbones
│   ├── modules
│   ├── refinement
│   └── birefnet.py
├── predictions
│   ├── image_1-mask.png
│   ...
│   └── image_3-subject.png
├── tutorials
│   ├── birefnet_background_removal.ipynb
│   ...
│   └── BiRefNet_pth2onnx.ipynb
├── weights
│   └── BiRefNet-general-epoch_244.pth
├── birefnet_background_replacement.ipynb
├── config.py
├── dataset.py
├── eval_existingOnes.py
├── gen_best_ep.py
├── gradio_app.py
├── image_proc.py
├── inference.py
├── LICENSE
├── loss.py
├── make_a_copy.sh
├── README.md
├── requirements.txt
├── rm_cache.sh
├── sub.sh
├── test.sh
├── train.py
├── train.sh
├── train_test.sh
└── utils.py
In the above tree structure, BiRefNet (the top directory) is the cloned repository.
- The images and backgrounds directories contain the data that we will use for inference and background replacement.
- birefnet_background_replacement.ipynb is a Jupyter Notebook containing the code for background replacement.
- gradio_app.py is the Gradio application that adds a UI to the project for a user-friendly experience.
The images and backgrounds directories, along with the Jupyter Notebook and the Gradio app file, are provided in the download section. After extracting the content, please copy them into the cloned BiRefNet directory.
Apart from these, all other files and directories are part of the official repository.
Download Code
Additional Setup Steps
Before moving forward with the code, create a weights directory inside the cloned folder. Next, download the pretrained BiRefNet weights from the following link and place the file (BiRefNet-general-epoch_244.pth) in the weights directory.
Link to download pretrained BiRefNet weights.
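Before moving on, a quick way to confirm that everything is in place is to check for the expected paths. The following is a minimal sketch using only the Python standard library. The helper name missing_project_files and the EXPECTED_PATHS list are our own assumptions based on the directory tree above; they are not part of the BiRefNet repository.

```python
from pathlib import Path

# Hypothetical list of paths this tutorial expects inside the cloned BiRefNet/ folder.
EXPECTED_PATHS = [
    "images",
    "backgrounds",
    "weights/BiRefNet-general-epoch_244.pth",
    "birefnet_background_replacement.ipynb",
    "gradio_app.py",
]


def missing_project_files(root: str) -> list[str]:
    """Return the expected paths that do not exist under `root`."""
    root_path = Path(root)
    return [p for p in EXPECTED_PATHS if not (root_path / p).exists()]


if __name__ == "__main__":
    missing = missing_project_files("BiRefNet")
    if missing:
        print("Missing:", ", ".join(missing))
    else:
        print("Project layout looks complete.")
```

Running it from the directory that contains BiRefNet/ lists anything that still needs to be copied or downloaded.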
Installing the Requirements
Next, install the requirements.
pip install -r requirements.txt
We are done with all the necessary setup.
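As an optional sanity check after installation, the following sketch reports which of the main dependencies can be found in the current environment. The module list is an assumption based on the imports used in this article; note that some import names differ from their pip package names (cv2 comes from opencv-python, PIL from Pillow). If gradio shows up as missing, it can be installed separately with pip install gradio.

```python
import importlib.util

# Import names assumed for the main dependencies used in this article.
REQUIRED_MODULES = ["torch", "torchvision", "cv2", "PIL", "numpy", "matplotlib", "gradio"]


def check_imports(modules=REQUIRED_MODULES) -> dict:
    """Map each module name to True if it can be found in the current environment."""
    return {name: importlib.util.find_spec(name) is not None for name in modules}


if __name__ == "__main__":
    for name, ok in check_imports().items():
        print(f"{name:12s} {'OK' if ok else 'MISSING'}")
```

find_spec only locates the module without executing it, so the check stays fast even for heavy packages like torch.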
Using BiRefNet for Background Replacement
We will start with the Jupyter Notebook code first. Next, we will integrate the code into a Gradio application for background replacement.
Jupyter Notebook Code
All the code that we will discuss here is present in birefnet_background_replacement.ipynb. We will mostly focus on the new parts and skip the details covered in the last article. The code up to background removal remains mostly unchanged; the only new addition is the background replacement code that runs after we obtain the segmented foreground object.
The Import Statements
Let’s start with the import statements.
import torch
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np

from PIL import Image
from torchvision import transforms
from IPython.display import display
from models.birefnet import BiRefNet
from utils import check_state_dict
from glob import glob
from image_proc import refine_foreground
We import all the necessary modules from the BiRefNet codebase, and those required for image processing as well.
Setting Up the Device, Loading the Model, and Transforms
Next, we define the computation device, load the model, and define the necessary transforms.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# device = torch.device('cpu')

def load_model():
    birefnet = BiRefNet(bb_pretrained=False)
    state_dict = torch.load('weights/BiRefNet-general-epoch_244.pth', map_location=device)
    state_dict = check_state_dict(state_dict)
    birefnet.load_state_dict(state_dict)

    # Move the model to the device and switch to evaluation mode
    torch.set_float32_matmul_precision(['high', 'highest'][0])
    birefnet.to(device)
    birefnet.eval()
    # Use half precision only on the GPU; float16 is poorly supported on the CPU
    if device.type == 'cuda':
        birefnet.half()
    print('BiRefNet is ready to use.')
    return birefnet

# Choose from: 'BiRefNet', 'BiRefNet_HR', 'BiRefNet_HR-matting'
model_name = 'BiRefNet_HR'
model = load_model()

transform_image = transforms.Compose([
    transforms.Resize((1024, 1024) if '_HR' not in model_name else (2048, 2048)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
Ensure that you have at least 10GB of VRAM to run these experiments. Because we choose the high-resolution transforms via BiRefNet_HR, the images are resized to 2048×2048. This gives high-resolution segmentation maps, but it also needs more VRAM than the 1024×1024 resolution used by the BiRefNet option.
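Most of that memory goes to the network's intermediate feature maps rather than to the input tensor, but the input size gives a feel for the scaling. The sketch below (a hypothetical helper, plain arithmetic) computes the size of a float16 input batch: going from 1024×1024 to 2048×2048 quadruples the number of pixels, and spatial feature maps grow roughly in proportion.

```python
def input_tensor_bytes(side: int, channels: int = 3, batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Size in bytes of a (batch, channels, side, side) input tensor.

    bytes_per_elem=2 corresponds to float16, which is what .half() produces.
    """
    return batch * channels * side * side * bytes_per_elem


base = input_tensor_bytes(1024)  # BiRefNet transform size
hr = input_tensor_bytes(2048)    # BiRefNet_HR transform size
print(f"1024x1024 input: {base / 2**20:.0f} MiB")
print(f"2048x2048 input: {hr / 2**20:.0f} MiB")
print(f"Pixel ratio: {hr // base}x")
```

The input alone is only 6 MiB versus 24 MiB, so the 4x factor, applied across every feature map in the network, is what makes the HR setting expensive.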
Function To Remove Background
Next, we have a function that, given an image path, will remove its background using the BiRefNet model.
def run_inference(image_path):
    # Convert to RGB so that grayscale or RGBA inputs match the 3-channel Normalize transform
    image = Image.open(image_path).convert('RGB')
    input_images = transform_image(image).unsqueeze(0).to(device)
    if device.type == 'cuda':
        input_images = input_images.half()

    # Prediction
    with torch.no_grad():
        preds = model(input_images)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Save Results
    file_ext = os.path.splitext(image_path)[-1]
    pred_pil = transforms.ToPILImage()(pred)
    pred_pil = pred_pil.resize(image.size)
    pred_pil.save(image_path.replace(src_dir, dst_dir).replace(file_ext, '-mask.png'))
    image_masked = refine_foreground(image, pred_pil)
    image_masked.putalpha(pred_pil)
    image_masked.save(image_path.replace(src_dir, dst_dir).replace(file_ext, '-subject.png'))

    # Visualize the last sample:
    # Scale proportionally with max length to 1024 for faster showing
    scale_ratio = 1024 / max(image.size)
    scaled_size = (int(image.size[0] * scale_ratio), int(image.size[1] * scale_ratio))
    plt.figure(figsize=(15, 12))
    plt.subplot(1, 3, 1)
    plt.imshow(image.resize(scaled_size))
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.imshow(image_masked.resize(scaled_size))
    plt.axis('off')
    plt.subplot(1, 3, 3)
    plt.imshow(pred_pil.resize(scaled_size), cmap='gray')
    plt.axis('off')
    plt.show()
It saves the segmentation masks and the foreground images with an alpha channel in the predictions directory.
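The saved mask is not strictly binary: the model's last output goes through a sigmoid, giving a soft probability map that is then converted to an 8-bit image. The NumPy sketch below mimics that conversion for a few logit values, assuming torchvision's to_pil_image behavior of multiplying floating-point tensors by 255 and casting to uint8 (which truncates rather than rounds). The helper name logits_to_mask is hypothetical.

```python
import numpy as np


def logits_to_mask(logits: np.ndarray) -> np.ndarray:
    """Mimic preds.sigmoid() followed by transforms.ToPILImage() on a float tensor."""
    probs = 1.0 / (1.0 + np.exp(-logits))   # sigmoid -> values in (0, 1)
    return (probs * 255).astype(np.uint8)   # truncating cast to 8-bit


logits = np.array([[-10.0, 0.0], [2.0, 10.0]], dtype=np.float32)
print(logits_to_mask(logits))
```

Confident background pixels land near 0, confident foreground pixels near 255, and uncertain edge pixels (such as hair or translucent jellyfish tentacles) keep intermediate values, which is what makes smooth blending possible later.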
Let’s create the predictions directory and loop over the images in the images directory.
src_dir = 'images'
image_paths = glob(os.path.join(src_dir, '*'))
dst_dir = 'predictions/'
os.makedirs(dst_dir, exist_ok=True)

for image_path in image_paths:
    print('Processing {} ...'.format(image_path))
    run_inference(image_path)
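run_inference builds its output names with str.replace, which rewrites every occurrence of the source directory name and the file extension in the path (for example, a file named a.jpg.b.jpg). The sketch below is a hypothetical, more defensive alternative that derives the same mask and subject file names with os.path helpers; it is not part of the original notebook.

```python
import os


def output_paths(image_path: str, dst_dir: str) -> tuple[str, str]:
    """Build the mask and subject PNG paths for one input image.

    Only the final extension of the file name is stripped, and the source
    directory is replaced as a whole rather than by substring.
    """
    stem = os.path.splitext(os.path.basename(image_path))[0]
    mask_path = os.path.join(dst_dir, f"{stem}-mask.png")
    subject_path = os.path.join(dst_dir, f"{stem}-subject.png")
    return mask_path, subject_path


print(output_paths("images/image_1.jpg", "predictions"))
```

Swapping the two .save(...) calls to use these paths would produce the same files as the directory tree above for ordinary file names.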
The following is a result of the above visualization.

Function to Replace Background
Next, we have the function to replace the background in images.
def apply_new_background(subject_binary, foreground, background):
    """
    :param subject_binary: segmentation mask (0-255) with the same shape as foreground
    :param foreground: extracted foreground image (RGB)
    :param background: new background image (RGB); resized to the foreground size
    :return: composited image as float32 in the [0, 1] range
    """
    # Normalize the mask so that its values lie between 0 and 1
    subject_binary = subject_binary / 255.0
    # Weight the foreground pixels by the mask
    foreground = cv2.multiply(subject_binary.astype(np.float32), foreground.astype(np.float32))
    # Resize the background to match the foreground image
    background = cv2.resize(background, (foreground.shape[1], foreground.shape[0]))
    background = background.astype(np.float32)
    # Weight the background pixels by the inverted mask
    background = cv2.multiply(1.0 - subject_binary.astype(np.float32), background.astype(np.float32))
    # Add the weighted foreground and the new background
    new_image = cv2.add(foreground.astype(np.float32), background.astype(np.float32))
    return new_image / 255.
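The apply_new_background function accepts the mask, the extracted foreground image, and the new background as parameters. Because the mask is a soft map with values between 0 and 255 rather than a strictly binary one, the function effectively performs alpha blending.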
After doing all the processing, it returns the resulting image with the new background.
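Per pixel, the blend is output = α·F + (1 − α)·B with α = mask / 255. The NumPy sketch below expresses the same computation without OpenCV, which makes the formula easy to test on tiny arrays. It also accepts a single-channel mask by adding a trailing axis so it broadcasts over the RGB channels; the notebook version instead relies on cv2.imread loading the mask PNG as a 3-channel image. The function name composite is our own.

```python
import numpy as np


def composite(mask: np.ndarray, foreground: np.ndarray, background: np.ndarray) -> np.ndarray:
    """Alpha-blend foreground over background using a 0-255 mask.

    mask: (H, W) or (H, W, 3) uint8; foreground/background: (H, W, 3) uint8 of the same size.
    Returns a float32 image in [0, 1], like apply_new_background above.
    """
    alpha = mask.astype(np.float32) / 255.0
    if alpha.ndim == 2:
        alpha = alpha[..., np.newaxis]  # (H, W) -> (H, W, 1) so it broadcasts over RGB
    blended = alpha * foreground.astype(np.float32) + (1.0 - alpha) * background.astype(np.float32)
    return blended / 255.0
```

Unlike apply_new_background, this sketch expects the background to already match the foreground size.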
Let’s read the respective images and call the above function.
subject = cv2.imread('predictions/image_3-subject.png')
subject = cv2.cvtColor(subject, cv2.COLOR_BGR2RGB)
print(subject.shape)
subject_binary = cv2.imread('predictions/image_3-mask.png')
bg = cv2.imread('backgrounds/bg_2.jpg')
bg = cv2.cvtColor(bg, cv2.COLOR_BGR2RGB)
print(bg.shape)
result = apply_new_background(subject_binary, subject, bg)
plt.figure(figsize=(15, 12))
plt.imshow(result)
plt.axis('off')
plt.show()
We call the function using the segmented jellyfish image. Let’s take a look at the results.
Looks like our function is working well. Most of the credit goes to the BiRefNet model for producing such high-quality foreground and binary masks.
Gradio Application for Background Replacement using BiRefNet
Now, we will jump into the code to create a Gradio interface for the above background replacement application.
The code for this is present in the gradio_app.py file. It will be a slightly modified version of the code that we saw above.
Importing the Modules
We start by importing the necessary modules.
import torch
import os
import cv2
import gradio as gr
import numpy as np

from PIL import Image
from torchvision import transforms
from models.birefnet import BiRefNet
from utils import check_state_dict
Function to Refine the Foreground and Apply a New Background
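We define two helper functions. The first, refine_foreground, attaches the mask to the image as an alpha channel. Unlike the notebook, which imports refine_foreground from image_proc, this is a simple local version that only applies the mask. The second, apply_new_background, pastes the segmented subject onto a new background.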
# Helper Functions.
def refine_foreground(image: Image.Image, mask: Image.Image) -> Image.Image:
    """Applies the mask to the image to get the subject with a transparent background."""
    mask = mask.convert('L')
    image_rgba = image.convert('RGBA')
    image_rgba.putalpha(mask)
    return image_rgba

def apply_new_background(
    original_image: Image.Image,
    mask_image: Image.Image,
    background_image: Image.Image
) -> Image.Image:
    """
    Pastes the subject from the original image onto a new background using a mask.
    Uses the same OpenCV blending logic as the notebook version.
    """
    # Convert PIL Images to NumPy arrays
    foreground_np = np.array(original_image.convert('RGB'))
    background_np = np.array(background_image.convert('RGB'))
    mask_np = np.array(mask_image.convert('L'))

    # Resize background to match the original image's dimensions
    bg_h, bg_w, _ = background_np.shape
    fg_h, fg_w, _ = foreground_np.shape
    if bg_h != fg_h or bg_w != fg_w:
        print(f"Resizing background from {bg_w}x{bg_h} to {fg_w}x{fg_h}")
        background_np = cv2.resize(background_np, (fg_w, fg_h), interpolation=cv2.INTER_AREA)

    # Normalize mask to be in the [0, 1] range
    mask_normalized = mask_np / 255.0
    # Expand mask to 3 channels for multiplication
    mask_3d = np.stack([mask_normalized] * 3, axis=-1)

    # Blend the foreground and background
    # foreground * mask + background * (1 - mask)
    foreground_part = cv2.multiply(mask_3d.astype(np.float32), foreground_np.astype(np.float32))
    background_part = cv2.multiply((1.0 - mask_3d).astype(np.float32), background_np.astype(np.float32))

    # Combine parts and convert back to uint8
    composite_np = cv2.add(foreground_part, background_part)
    composite_np = np.clip(composite_np, 0, 255).astype(np.uint8)

    # Convert final NumPy array back to a PIL Image
    return Image.fromarray(composite_np)
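apply_new_background resizes the background to the exact size of the input image with cv2.resize, which stretches it whenever the two aspect ratios differ. One optional, hypothetical pre-step (not part of the original app) is to first center-crop the background to the input image's aspect ratio. The sketch below only computes the crop box, so it needs nothing beyond plain Python.

```python
def center_crop_box(bg_w: int, bg_h: int, fg_w: int, fg_h: int) -> tuple[int, int, int, int]:
    """Largest centered (left, top, right, bottom) box in the background whose
    aspect ratio matches the foreground image."""
    target_ratio = fg_w / fg_h
    if bg_w / bg_h > target_ratio:
        # Background is too wide: keep the full height, trim the sides.
        crop_w = round(bg_h * target_ratio)
        left = (bg_w - crop_w) // 2
        return left, 0, left + crop_w, bg_h
    # Background is too tall (or already matches): keep the full width, trim top/bottom.
    crop_h = round(bg_w / target_ratio)
    top = (bg_h - crop_h) // 2
    return 0, top, bg_w, top + crop_h
```

Since PIL's Image.crop takes a (left, upper, right, lower) box and Image.size is (width, height), the box could be applied as background_pil_image.crop(center_crop_box(*background_pil_image.size, *original_image.size)) before calling apply_new_background; the later resize then only scales, without distortion.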
Define Computation Device and Load the Model
Next, we define the computation device and load the BiRefNet model for background replacement.
# Global Setup (Load Model Once).
print("Setting up device...")
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

def load_model():
    """Loads the BiRefNet model and sets it to evaluation mode."""
    print("Loading BiRefNet model...")
    model = BiRefNet(bb_pretrained=False)
    model_path = 'weights/BiRefNet-general-epoch_244.pth'
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model weights not found at: {model_path}. Please download them.")

    state_dict = torch.load(model_path, map_location=device)
    state_dict = check_state_dict(state_dict)
    model.load_state_dict(state_dict)

    torch.set_float32_matmul_precision(['high', 'highest'][0])
    model.to(device)
    model.eval()
    if device.type == 'cuda':
        model.half()
    print("BiRefNet model is ready.")
    return model

model = load_model()

# Choose from: 'BiRefNet', 'BiRefNet_HR', 'BiRefNet_HR-matting'
model_name = 'BiRefNet_HR'
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024) if '_HR' not in model_name else (2048, 2048)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
Along with that, we also define the image transforms, which are the same as in the notebook.
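ToTensor scales 8-bit pixel values to [0, 1], and Normalize then subtracts the per-channel ImageNet mean and divides by the per-channel standard deviation. The small sketch below (a hypothetical helper, plain Python) applies those two steps to a single RGB pixel to show the resulting value range; the Resize step only changes the spatial size and is not modeled here.

```python
# ImageNet channel statistics used by transforms.Normalize in transform_image.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb: tuple[int, int, int]) -> tuple[float, ...]:
    """Apply ToTensor (scale 0-255 to 0-1) then Normalize ((x - mean) / std) to one RGB pixel."""
    return tuple((v / 255.0 - m) / s for v, m, s in zip(rgb, MEAN, STD))


print(normalize_pixel((255, 255, 255)))  # white
print(normalize_pixel((0, 0, 0)))        # black
```

The result for every channel lies roughly between -2.1 and 2.6, which is the range the ImageNet-pretrained backbone expects. It also explains why the input must have exactly three channels: the mean and std tuples have one entry per RGB channel.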
Inference Function to Create the Segmented and Final Image
The following process_image function creates the binary mask, the foreground image, and the image with the changed background.
# The Core Inference Function.
def process_image(input_pil_image: Image.Image, background_pil_image: Image.Image = None):
    """
    Takes a user-uploaded image, runs inference, and optionally replaces the background.
    """
    if input_pil_image is None:
        return None, None, None

    print("Processing new image...")
    # Ensure a 3-channel input; uploaded PNGs may be RGBA or grayscale
    input_pil_image = input_pil_image.convert('RGB')
    original_size = input_pil_image.size

    # Prepare the image for the model
    input_tensor = transform_image(input_pil_image).unsqueeze(0).to(device)
    if device.type == 'cuda':
        input_tensor = input_tensor.half()

    # Run Prediction
    with torch.no_grad():
        preds = model(input_tensor)[-1].sigmoid()

    # Process prediction back to a PIL Image
    pred_tensor = preds[0].squeeze()
    pred_pil = transforms.ToPILImage()(pred_tensor.cpu())

    # Resize mask to original image size
    mask_pil = pred_pil.resize(original_size, Image.Resampling.LANCZOS)

    # Get the subject with a transparent background
    subject_pil = refine_foreground(input_pil_image, mask_pil)

    # Handle background replacement
    if background_pil_image is not None:
        print("Applying new background...")
        composite_image = apply_new_background(input_pil_image, mask_pil, background_pil_image)
    else:
        # If no background is provided, the composite result is just the transparent subject
        composite_image = subject_pil

    print("Processing complete.")
    return composite_image, subject_pil, mask_pil
Creating the Gradio UI and Running the App
Finally, we create the Gradio UI and run the application.
# Gradio UI Definition.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange")) as demo:
    gr.Markdown(
        """
        # ✂️ BiRefNet: Background Removal & Replacement
        Upload an image to isolate the main subject. You can also provide a new background to create a composite image.
        This demo uses the BiRefNet model for high-resolution segmentation.
        """
    )

    with gr.Row():
        input_image = gr.Image(type="pil", label="Input Image")
        background_image = gr.Image(type="pil", label="New Background (Optional)")

    process_btn = gr.Button("Process Image", variant="primary")

    with gr.Tabs():
        with gr.TabItem("Final Result"):
            output_composite = gr.Image(type="pil", label="Image with New Background")
        with gr.TabItem("Transparent Subject"):
            output_subject = gr.Image(type="pil", label="Subject (Transparent PNG)", format="png")
        with gr.TabItem("Generated Mask"):
            output_mask = gr.Image(type="pil", label="Generated Mask")

    gr.Examples(
        examples=[
            [os.path.join(os.path.dirname(__file__), "images/image_1.jpg")],
            [os.path.join(os.path.dirname(__file__), "images/image_2.jpg")],
            [os.path.join(os.path.dirname(__file__), "images/image_3.jpg")],
        ],
        inputs=input_image,
        outputs=[output_composite, output_subject, output_mask],
        fn=process_image,
        cache_examples=True,
    )

    process_btn.click(
        fn=process_image,
        inputs=[input_image, background_image],
        outputs=[output_composite, output_subject, output_mask]
    )

# Launch the App.
if __name__ == "__main__":
    demo.launch(share=True)
We can execute the following command to start the application.
python gradio_app.py
This is what the UI looks like when we run the application for one image.
We can upload an image and a background of our choice and run the application to get the final image. The UI also has separate tabs to view the transparent subject and the generated mask.
The following is a video showing one full run through the application.
You can upload the images of your choice and try out how it works for complex scenarios.
Summary and Conclusion
In this article, we created a simple background replacement application using the BiRefNet model. Models like this open up real-world applications in image and video editing apps, which we may tackle in a future article.
If you have any questions, thoughts, or suggestions, please leave them in the comment section. I will surely address them.
You can contact me using the Contact section. You can also find me on LinkedIn and X.




