Fine-Tuning Gemma 4 for Vision


In the last two articles, we covered Gemma 4 inference for various tasks and carried out fine-tuning for transcription and translation of audio. In this article, we will be fine-tuning Gemma 4 for a vision task.

Figure 1. Gemma 4 E2B vision fine-tuning evaluation sample results.

We will focus on a medical use case. Specifically, we will fine-tune the Gemma 4 E2B model on a radiology VQA dataset. We will discuss the details of the dataset later in the article. All the training will happen via the Unsloth library.

What will we cover while fine-tuning Gemma 4 for the vision task?

  • Discussing the radiology VQA dataset.
  • The dataset preparation, training, and evaluation of the model.
  • Inference using the fine-tuned Gemma 4 vision model.

The VQA RAD Dataset

The VQA RAD dataset is a radiology visual question-answering dataset. It contains naturally occurring questions that clinicians asked about radiology images, along with reference answers.

You can download the dataset from OSF.

The following is the directory structure after downloading and extracting the dataset locally.

osfstorage-archive
├── VQA_RAD Image Folder  [315 entries exceeds filelimit, not opening dir]
├── Readme.docx
├── VQA_RAD Dataset Public.json
├── VQA_RAD Dataset Public.xlsx
└── VQA_RAD Dataset Public.xml

There are 315 images in total. The Readme.docx file contains additional information about the dataset that is worth reading. The XML, JSON, and XLSX files contain the same question and reference answer samples in different formats. Going forward, we will use VQA_RAD Dataset Public.json to prepare the dataset.

Let’s take a look at one of the samples from the JSON file along with its image.

{
  "qid": "0",
  "phrase_type": "freeform",
  "qid_linked_id": "03f451ca-de62-4617-9679-e836026a7642",
  "image_case_url": "https://medpix.nlm.nih.gov/case?id=48e1dd0e-8552-46ad-a354-5eb55be86de6",
  "image_name": "synpic54610.jpg",
  "image_organ": "HEAD",
  "evaluation": "not evaluated",
  "question": "Are regions of the brain infarcted?",
  "question_rephrase": "NULL",
  "question_relation": "NULL",
  "question_frame": "NULL",
  "question_type": "PRES",
  "answer": "Yes",
  "answer_type": "CLOSED"
}

The above is the first sample from the JSON file. The following is the corresponding image.

Figure 2. Image sample from the VQA RAD dataset.

We have several key-value pairs. Some of the important ones are:

  • image_organ: Type of organ, e.g., HEAD, CHEST, ABD (abdomen).
  • question: Question about the image.
  • question_rephrase: A rephrased form of the question, if available.
  • question_type: The type of question, e.g., MODALITY, PLANE, ORGAN (Organ System), etc.
  • answer: Answer to the question.
  • answer_type: Type of answer, whether closed-ended or open-ended.
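To get a feel for how these fields are distributed, a small sketch like the one below can count their values. It assumes the JSON file holds a list of records shaped like the sample above; `summarize_fields` is a hypothetical helper, not part of the dataset tooling.

```python
from collections import Counter


def summarize_fields(records, fields=('image_organ', 'question_type', 'answer_type')):
    """Count how often each value appears for the selected VQA-RAD fields."""
    return {field: Counter(record.get(field) for record in records) for field in fields}


# On the full dataset (assuming the file holds a JSON list of records):
# import json
# with open('input/osfstorage-archive/VQA_RAD Dataset Public.json') as f:
#     records = json.load(f)
# Here we reuse the single sample shown above so the snippet runs on its own.
records = [{
    'qid': '0',
    'image_name': 'synpic54610.jpg',
    'image_organ': 'HEAD',
    'question': 'Are regions of the brain infarcted?',
    'question_rephrase': 'NULL',
    'question_type': 'PRES',
    'answer': 'Yes',
    'answer_type': 'CLOSED',
}]

summary = summarize_fields(records)
for field, counts in summary.items():
    print(field, dict(counts))
```

On the full file, the answer_type counts give a quick view of how many questions are closed-ended versus open-ended.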

Going through the DOCX file will give you additional information.

In one of the previous articles, we fine-tuned Qwen3.5 on the same dataset. You can go through that article later to compare how Gemma 4 fares against Qwen3.5 under a similar setup.

Although there are only 315 images, there are 2248 question-answer samples, because the same image can have multiple questions.
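A quick way to see this one-to-many relation is to group the questions by `image_name`. The sketch below uses hypothetical toy records rather than real dataset entries; `questions_per_image` is an illustrative helper.

```python
from collections import defaultdict


def questions_per_image(records):
    """Map each image file name to the list of questions asked about it."""
    grouped = defaultdict(list)
    for record in records:
        grouped[record['image_name']].append(record['question'])
    return dict(grouped)


# Hypothetical toy records; on the real JSON this should give one key per image.
toy_records = [
    {'image_name': 'image_a.jpg', 'question': 'Is there a fracture?'},
    {'image_name': 'image_a.jpg', 'question': 'What plane is this image in?'},
    {'image_name': 'image_b.jpg', 'question': 'Is this an MRI?'},
]

grouped = questions_per_image(toy_records)
print(len(toy_records), 'questions over', len(grouped), 'images')
```

Grouping like this is also handy if you want all questions about one image to land in the same split; a random split over individual QA samples, like the one used later, can place the same image in both the training and evaluation sets.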

Later in the article, we will see how we structure the question and answer pairs for the Gemma 4 fine-tuning process.

Project Directory Structure

The following is the project directory structure that we are using:

├── gemma4_e2b_lora
│   ├── adapter_config.json
│   ├── adapter_model.safetensors
│   ├── chat_template.jinja
│   ├── processor_config.json
│   ├── README.md
│   ├── tokenizer_config.json
│   └── tokenizer.json
├── hf_eval_dataset
│   ├── data-00000-of-00001.arrow
│   ├── dataset_info.json
│   └── state.json
├── input
│   ├── osfstorage-archive
│   └── osfstorage-archive.zip
├── annotations.csv
├── gemma4_e2b_ft.ipynb
├── gemma4_e2b_vqa_rad_fine_tuned.ipynb
├── README.md
└── requirements.txt

  • The input directory contains the dataset we discussed in the previous section. Download the dataset, extract it, and place it in the input directory.
  • There are two Jupyter Notebooks: gemma4_e2b_ft.ipynb for fine-tuning and gemma4_e2b_vqa_rad_fine_tuned.ipynb for inference.
  • The hf_eval_dataset directory contains the evaluation split, which we save separately for easier inference and evaluation.
  • The gemma4_e2b_lora directory contains the trained adapter.
  • The annotations.csv file is the CSV version of the annotations that the fine-tuning notebook creates.

The article comes with a downloadable zip file containing the Jupyter Notebooks, evaluation dataset, and the trained adapters. You can download the zip file, extract it, and arrange the input directory in the above format to start the fine-tuning process.

Installing Dependencies

Although the notebooks contain the installation commands, you can install all the dependencies beforehand using the requirements file.

pip install -r requirements.txt

This is all the setup we need. Let’s move on to discuss the code now.

We can also fine-tune Gemma 4 for audio transcription and translation. The linked article covers all the code and processes in detail, along with dataset preparation.

Fine-Tuning Gemma 4 for Vision Question Answering

All the training code is present in the gemma4_e2b_ft.ipynb Jupyter Notebook. Let’s cover that in detail here.

Imports

The first code block imports all the necessary libraries and modules.

from unsloth import FastVisionModel
from datasets import load_dataset
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
from transformers import TextStreamer
from datasets import Dataset

import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import pprint

The above code block imports everything we need for loading the model, preparing the dataset, and fine-tuning it.

Loading the Model and Preparing for Fine-Tuning

The next block loads the Gemma 4 E2B model from Unsloth and prepares it for fine-tuning.

model, tokenizer = FastVisionModel.from_pretrained(
    model_name='unsloth/gemma-4-E2B-it',
    dtype=None,
    max_seq_length=512,
    load_in_4bit=False,
    full_finetuning=False,
)

model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers=True,
    finetune_language_layers=True,
    finetune_attention_modules=True,
    finetune_mlp_modules=True,

    r=32,           
    lora_alpha=64, 
    lora_dropout=0,
    bias='none',
    random_state=3407,
    use_rslora=False,  
    loftq_config=None,
    # target_modules='all-linear',
)

One important thing to note in the above code block is that we are not loading the model in a quantized (4-bit) format, which is the current recommendation from Unsloth. Loading the model this way requires around 10GB of VRAM.

Furthermore, we use a rank of 32 and an alpha of 64 for the fine-tuning process. As this is a relatively complex task, a higher rank gives the LoRA adapters more capacity to learn from the data.
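With `use_rslora=False`, standard LoRA scales the low-rank update by alpha / r, so r=32 and lora_alpha=64 give a factor of 2, and each adapted linear layer gains r * (d_in + d_out) trainable parameters. The sketch below only does that arithmetic; the helper names are hypothetical, and the 2048-wide layer is an illustrative size, not Gemma 4 E2B's actual dimension.

```python
import math


def lora_scaling(r, alpha, use_rslora=False):
    """Factor applied to the low-rank update (B @ A) before adding it to the frozen weight."""
    return alpha / math.sqrt(r) if use_rslora else alpha / r


def lora_param_count(d_in, d_out, r):
    """Extra trainable parameters for one d_in -> d_out linear layer: A is r x d_in, B is d_out x r."""
    return r * (d_in + d_out)


print(lora_scaling(r=32, alpha=64))                   # standard LoRA, as configured above → 2.0
print(lora_scaling(r=32, alpha=64, use_rslora=True))  # what rsLoRA would use instead
print(lora_param_count(d_in=2048, d_out=2048, r=32))  # hypothetical 2048 x 2048 layer → 131072
```

Doubling the rank doubles the adapter parameters per layer, while keeping alpha at twice the rank keeps the scaling factor fixed at 2.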

Dataset Preparation

The next step is crucial. We need to format the dataset so that each sample combines all the information the model needs to learn the answer patterns of the dataset.

Let’s load the annotation JSON file. We also save it as a CSV file so that we can load it with the Hugging Face Datasets library.

root_dir = 'input/osfstorage-archive'
image_folder = 'VQA_RAD Image Folder'
annotation_file = 'VQA_RAD Dataset Public.json'

annotations = pd.read_json(f"{root_dir}/{annotation_file}")
annotations.head()

# Convert dataframe to HF dataset format.
# To CSV.
annotations.to_csv('annotations.csv', index=False)
dataset = load_dataset('csv', data_files='annotations.csv')['train']

# Shuffle the dataset.
dataset = dataset.shuffle(seed=3407)

The next step is to prepare the training and validation sets.

# Split into train and eval.
dataset = dataset.train_test_split(test_size=0.1, seed=3407)
train_dataset = dataset['train']
eval_dataset = dataset['test']

print(f"Train dataset size: {len(train_dataset)}")
print(f"Eval dataset size: {len(eval_dataset)}")

We have 2023 samples for training and 225 samples for evaluation.
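These counts follow from how a fractional test_size is turned into sample counts. The sketch assumes Hugging Face Datasets rounds the test count up (the same convention scikit-learn uses); under that assumption, 2023 + 225 implies 2248 samples in total. `split_sizes` is a hypothetical helper.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a fractional test_size, rounding the test count up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test


print(split_sizes(2248, 0.1))  # → (2023, 225)
```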

The model needs to see all the relevant information in a structured way, along with the ground truth answer. For that, we are using the following format.

instruction = """Answer the following question based on the given image.
Here is some additional information about the type of question that you will encounter:
Type of question:
MODALITY
PLANE
ORGAN (Organ System)
ABN (Abnormality)
PRES (Object/Condition Presence)
POS (Positional Reasoning)
COLOR
SIZE
ATTRIB (Attribute Other)
COUNT (Counting)
Other
"""

def convert_to_conversation(sample):
    image_name = sample['image_name']
    image = f"{root_dir}/{image_folder}/{image_name}"
    # Question to model.
    question = sample['question']
    # Managing model's answers.
    answer = str(sample['answer'])
    rephrased_question = sample['question_rephrase'] if sample['question_rephrase'] is not None else ''
    organ = sample['image_organ'] if sample['image_organ'] is not None else ''
    question_type = sample['question_type'] if sample['question_type'] is not None else ''
    if rephrased_question == '':
        final_answer = f"This is question about {organ} and the question type is {question_type}. The answer is {answer}."
    else:
        final_answer = f"This is question about {organ} and the question type is {question_type}. The question can also be rephrased as: {rephrased_question}. The answer is {answer}."

    # print("Question:", question)
    # print("Rephrased Question:", rephrased_question)
    # print("Answer:", answer)

    conversation = [
        { 'role': 'user',
          'content' : [
            {'type' : 'text',  'text'  : instruction + "QUESTION: " + question},
            {'type' : 'image', 'image' : image} ]
        },
        { 'role' : 'assistant',
          'content' : [
            {'type' : 'text',  'text'  : final_answer} ]
        },
    ]
    return { 'messages' : conversation }

converted_dataset_train = [convert_to_conversation(sample) for sample in train_dataset]
converted_dataset_eval = [convert_to_conversation(sample) for sample in eval_dataset]

We use a general instruction that lists all the question types the model can encounter. convert_to_conversation then builds each sample in a fixed format. The user turn contains the instruction followed by the question, along with the image. The assistant turn states the organ and the question type, then the rephrased question if one exists, and finally the answer. There are certainly better ways to construct the QA pairs, but this structure gives the model as much context as possible about each image.
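In the raw JSON, missing rephrasings are stored as the string "NULL", as in the sample shown earlier. The CSV round trip most likely turns these into missing values, since pandas treats "NULL" as NaN by default, which is why the `is not None` checks in convert_to_conversation work. If you build samples directly from the JSON instead, a small normalizer like the hypothetical `clean_field` below keeps "NULL" out of the training targets.

```python
import math

# Strings that should be treated as "no value" (assumption based on the raw JSON sample).
MISSING_MARKERS = {'', 'NULL'}


def clean_field(value):
    """Return the value as a stripped string, or '' if it represents a missing entry."""
    if value is None:
        return ''
    if isinstance(value, float) and math.isnan(value):
        return ''
    text = str(value).strip()
    return '' if text in MISSING_MARKERS else text


# Possible drop-in use inside convert_to_conversation:
# rephrased_question = clean_field(sample['question_rephrase'])
print(repr(clean_field('NULL')))
print(repr(clean_field(float('nan'))))
print(repr(clean_field(' Are the biliary ducts dilated? ')))
```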

The following is a sample ground truth pair after the dataset conversion.

{'messages': [{'role': 'user',
               'content': [{'type': 'text',
                            'text': 'Answer the following question based on '
                                    'the given image.\n'
                                    'Here is some additional information about '
                                    'the type of question that you will '
                                    'encounter:\n'
                                    'Type of question:\n'
                                    'MODALITY\n'
                                    'PLANE\n'
                                    'ORGAN (Organ System)\n'
                                    'ABN (Abnormality)\n'
                                    'PRES (Object/Condition Presence)\n'
                                    'POS (Positional Reasoning)\n'
                                    'COLOR\n'
                                    'SIZE\n'
                                    'ATTRIB (Attribute Other)\n'
                                    'COUNT (Counting)\n'
                                    'Other\n'
                                    'QUESTION: Is there biliary duct '
                                    'dilation?'},
                           {'type': 'image',
                            'image': 'input/osfstorage-archive/VQA_RAD Image '
                                     'Folder/synpic33889.jpg'}]},
              {'role': 'assistant',
               'content': [{'type': 'text',
                            'text': 'This is question about ABD and the '
                                    'question type is SIZE. The question can '
                                    'also be rephrased as: Are the biliary '
                                    'ducts dilated?. The answer is Yes.'}]}]}

Inference Before Fine-Tuning

Let’s do a sample inference run on the evaluation set before fine-tuning so that we can compare the results later.

FastVisionModel.for_inference(model) # Enable for inference!

text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)


for i in range(10):
    image = converted_dataset_eval[i]['messages'][0]['content'][1]['image'] # Extract the image.
    prompt = converted_dataset_eval[i]['messages'][0]['content'][0]['text'] # Extract the prompt text (a new name avoids overwriting the global `instruction`).

    messages = [
        {'role': 'user', 'content': [
            {'type': 'image', 'image': image},
            {'type': 'text', 'text': prompt}
        ]}
    ]
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors='pt',
    ).to('cuda')

    plt.imshow(plt.imread(image))
    plt.axis('off')
    plt.show()

    print(f"INSTRUCTION: {prompt}")

    print('MODEL OUTPUT:')
    _ = model.generate(
        **inputs, 
        streamer=text_streamer, 
        max_new_tokens=256,
        use_cache=True, 
        temperature=1.0, 
        top_p=0.95, 
        top_k=64
    )

    print('GROUND TRUTH:')
    print(converted_dataset_eval[i]['messages'][1]['content'][0]['text'])
    print()
    print('*' * 50)

The following is the first sample image along with the model’s response:

Figure 3. VQA RAD sample image to check Gemma 4 E2B results before vision fine-tuning.
MODEL OUTPUT:
Based on the provided axial CT image of the abdomen:

1.  **Locate the kidneys:** The kidneys are located in the retroperitoneum, typically in the flank area.
2.  **Examine the left kidney:** Look specifically at the left kidney on this cross-section.

In the image, the left kidney appears generally intact, and there is no obvious, well-defined cyst visible within the parenchyma of the left kidney.

**Answer:** **No**
GROUND TRUTH:
This is question about ABD and the question type is PRES. The question can also be rephrased as: Is a cystic cavity present in the left kidney on this image?. The answer is No.

The pretrained model’s output is verbose, with a thinking-like, step-by-step process that does not follow the dataset’s response structure. You can go through the other output samples to check whether they are right or wrong. After fine-tuning, we expect the model to learn the response structure and, hopefully, answer most of the questions correctly.

Fine-Tuning the Model

The following code block contains all the code to start the fine-tuning process.

FastVisionModel.for_training(model) # Enable for training!

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    data_collator=UnslothVisionDataCollator(model, tokenizer),
    train_dataset=converted_dataset_train,
    eval_dataset=converted_dataset_eval,
    args=SFTConfig(
        per_device_train_batch_size=16,
        per_device_eval_batch_size=4,
        gradient_accumulation_steps=2,
        warmup_steps=5,
        num_train_epochs=4,
        learning_rate=2e-4,
        logging_steps=50,
        eval_steps=50,
        eval_strategy='steps',
        do_eval=True,
        optim='adamw_8bit',
        weight_decay=0.001,
        lr_scheduler_type='linear',
        seed=3407,
        output_dir='outputs',
        report_to='none',

        # You MUST put the below items for vision finetuning:
        remove_unused_columns=False,
        dataset_text_field='',
        dataset_kwargs={'skip_prepare_dataset': True},
        max_length=2048,
    ),
)

trainer_stats = trainer.train()

  • We are training the model for 4 epochs.
  • The per-device training batch size is 16 and the evaluation batch size is 4. With 2 gradient accumulation steps, the effective training batch size is 32.
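As a rough sanity check on the schedule, the sketch below estimates the number of optimizer steps and evaluations for this configuration. It assumes a single GPU and one optimizer step per full group of accumulated batches; Trainer versions differ slightly in how they treat a leftover partial accumulation, so treat the totals as approximate. `training_schedule` is a hypothetical helper.

```python
import math


def training_schedule(n_train, batch_size, grad_accum, epochs, eval_steps):
    """Approximate optimizer steps and evaluation count for a Trainer-style loop."""
    batches_per_epoch = math.ceil(n_train / batch_size)        # last batch may be partial
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)  # one update per grad_accum batches
    total_steps = steps_per_epoch * epochs
    return {
        'effective_batch_size': batch_size * grad_accum,
        'batches_per_epoch': batches_per_epoch,
        'steps_per_epoch': steps_per_epoch,
        'total_steps': total_steps,
        'evaluations': total_steps // eval_steps,
    }


schedule = training_schedule(n_train=2023, batch_size=16, grad_accum=2, epochs=4, eval_steps=50)
print(schedule)
```

With these settings, this gives an effective batch size of 32, about 250 optimizer steps, and roughly five evaluation points at eval_steps=50.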

With the current setup, the fine-tuning process requires ~20GB VRAM. The training was done on a 24GB NVIDIA L4 GPU, and it ran for around 1 hour.

The following are the fine-tuning logs.

Figure 4. Fine-tuning logs for Gemma 4 vision.

Saving the Model Locally

Finally, we save the trained LoRA adapter and the tokenizer locally.

model.save_pretrained('gemma4_e2b_lora')
tokenizer.save_pretrained('gemma4_e2b_lora')

Note that we are saving the model from the final training step, which is slightly overfit, rather than the checkpoint with the best validation loss. In such scenarios, an overfit model might capture the response patterns better. However, this approach is not ideal for all kinds of datasets.
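If you would rather keep the checkpoint with the lowest validation loss, the standard Transformers TrainingArguments options (which SFTConfig inherits) can handle that. The sketch below collects them in a dict to unpack into the SFTConfig(...) call. save_steps is kept equal to the existing eval_steps=50 because load_best_model_at_end needs a checkpoint saved at each evaluation step. This is an untested variant of the setup above, not the configuration used for the results in this article.

```python
EVAL_STEPS = 50  # matches eval_steps in the SFTConfig above

# Extra keyword arguments for SFTConfig(...). eval_strategy and eval_steps are
# already set in the original config, so they are not repeated here.
best_checkpoint_kwargs = dict(
    save_strategy='steps',
    save_steps=EVAL_STEPS,          # save at every evaluation step
    save_total_limit=2,             # cap checkpoints on disk; the best one is still retained
    load_best_model_at_end=True,    # reload the best checkpoint when training ends
    metric_for_best_model='eval_loss',
    greater_is_better=False,        # lower validation loss is better
)

# Usage (inside the training cell):
# args=SFTConfig(..., **best_checkpoint_kwargs)
print(best_checkpoint_kwargs)
```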

Running Inference using the Fine-Tuned Model

The gemma4_e2b_vqa_rad_fine_tuned.ipynb Jupyter Notebook contains all the code to load the fine-tuned model and run inference. Let’s run through the code quickly.

Imports and Loading the Model

Let’s import the necessary modules and load the Gemma 4 Vision fine-tuned model.

from unsloth import FastVisionModel
from PIL import Image
from transformers import TextStreamer
from datasets import load_from_disk

import matplotlib.pyplot as plt

# Load model.
def load_model(model_name='gemma4_e2b_lora'):
    model, tokenizer = FastVisionModel.from_pretrained(
        model_name,
        dtype=None,
        max_seq_length=512,
        load_in_4bit=False,
        full_finetuning=False,
    )

    return model, tokenizer

model, tokenizer = load_model(
    model_name='gemma4_e2b_lora'
)

FastVisionModel.for_inference(model)

Load the Evaluation Dataset and Run Inference

We saved the evaluation set locally earlier. Next, we load it and run inference on 10 samples.

# Load eval dataset in HF format.
eval_dataset = load_from_disk('hf_eval_dataset')

FastVisionModel.for_inference(model)

text_streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)


for i in range(10):
    image = eval_dataset[i]['messages'][0]['content'][1]['image'] # Extract the image.
    instruction = eval_dataset[i]['messages'][0]['content'][0]['text'] # Extract the instruction text.

    messages = [
        {'role': 'user', 'content': [
            {'type': 'image', 'image': image},
            {'type': 'text', 'text': instruction}
        ]}
    ]
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    inputs = tokenizer(
        image,
        input_text,
        add_special_tokens=False,
        return_tensors='pt',
    ).to('cuda')

    plt.imshow(plt.imread(image))
    plt.axis('off')
    plt.show()

    print(f"INSTRUCTION: {instruction}")

    print('MODEL OUTPUT:')
    _ = model.generate(
        **inputs, 
        streamer=text_streamer, 
        max_new_tokens=256,
        use_cache=True, 
        temperature=1.0, 
        top_p=0.95, 
        top_k=64
    )

    print('GROUND TRUTH:')
    print(eval_dataset[i]['messages'][1]['content'][0]['text'])
    print()
    print('*' * 50)

Analyzing the Results

Compared to the pretrained model, the responses now match the dataset structure much more closely. However, in some scenarios, the model still refuses to answer. For example, the first sample asks about the size of the lesion, and the model does not provide an answer.

Figure 5. Inference sample 1 after fine-tuning Gemma 4 E2B Vision.

This might be due to the model’s safety alignment, which can make it decline questions it is not confident about. Such refusals might be more common in medical imaging Q&A use cases.

The next figure shows an instance where the model gives reasoning along with the answer.

Figure 6. Inference sample 4 after fine-tuning Gemma 4 E2B Vision.

However, we cannot confirm that this reasoning is entirely correct.

There are other instances as well where the model gives a reasoning-like trace.

Figure 7. Inference sample 10 after fine-tuning Gemma 4 E2B Vision.

However, we cannot verify whether the reasoning is correct or not, as the ground truth answer does not contain that.
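Because every training target ends with "The answer is X.", a rough automatic check is still possible when the model adds reasoning: extract the text after the last "The answer is" in both the prediction and the ground truth, and compare them. The sketch below uses hypothetical helpers that are not part of the notebooks. To use it, you would need the generated text as a string, for example by decoding the generate output with tokenizer.batch_decode instead of only streaming it. Exact match is strict and will undercount open-ended answers that are phrased differently.

```python
import re

ANSWER_PREFIX = 'the answer is'


def extract_answer(text):
    """Return the lower-cased text after the last 'The answer is', or None if absent."""
    idx = text.lower().rfind(ANSWER_PREFIX)
    if idx == -1:
        return None
    tail = text[idx + len(ANSWER_PREFIX):].strip()
    # Cut at the first sentence-ending period; decimals such as '2.5 cm' survive.
    tail = re.split(r'\.(?:\s|$)', tail, maxsplit=1)[0]
    return tail.strip(' *').lower() or None


def exact_match_accuracy(predictions, references):
    """Fraction of pairs whose extracted answers match (missing answers count as wrong)."""
    if not references:
        return 0.0
    hits = 0
    for pred, ref in zip(predictions, references):
        pred_answer = extract_answer(pred)
        hits += pred_answer is not None and pred_answer == extract_answer(ref)
    return hits / len(references)


reference = 'This is question about ABD and the question type is PRES. The answer is No.'
prediction = 'This is question about ABD and the question type is PRES. The answer is no.'
print(extract_answer(reference))                        # → no
print(exact_match_accuracy([prediction], [reference]))  # → 1.0
```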

Key Takeaways

Although we fine-tuned Gemma 4 for a vision use case here, there are several caveats when training on medical imaging data like this.

Firstly, this article should be taken only as a fine-tuning tutorial, not as a recipe for building a VLM suitable for real medical use.

Secondly, the model did learn some of the output structure. However, it did not consistently learn to state the organ, to rephrase the question, or to give single-word answers. In fact, the long reasoning traces make it even more difficult to validate whether an answer is correct.

Summary and Conclusion

In this article, we fine-tuned the Gemma 4 E2B model for a medical imaging VQA task. We started with the dataset discussion, followed by fine-tuning and inference, and finally discussed the nuances and caveats of the results.

If you have any questions, thoughts, or suggestions, please leave them in the comment section. I will surely address them.

You can contact me using the Contact section. You can also find me on LinkedIn and X.

Liked it? Take a second to support Sovit Ranjan Rath on Patreon!