Fine-Tuning GPT-Neo for Medical Q&A

Author: Astik Dahal

How can small language models like GPT-Neo be tailored for specialized applications such as medical question answering? This project arose from a straightforward yet critical goal: to customize lightweight language models for practical, real-world tasks. By fine-tuning GPT-Neo using a domain-specific dataset, I aimed to assess its performance, understand its limitations, and explore ways to enhance it further.


Summary

This project fine-tuned GPT-Neo for medical Q&A using the MedQuAD dataset, achieving a token overlap score of 0.50 and BERTScore F1 of 0.5812. Results highlight strong semantic understanding but room for improvement in precision.

The Motivation

Small language models like GPT-Neo are ideal for resource-constrained applications, especially in healthcare where accurate, context-sensitive Q&A systems are crucial. They offer several key advantages:

  • Efficiency: Small models are computationally less demanding, enabling deployment in environments with limited resources.
  • Customizability: Fine-tuning smaller models with focused datasets makes them ideal for specialized tasks.
  • Accessibility: Compared to larger models requiring extensive infrastructure, GPT-Neo is far more accessible to researchers and developers.

The Dataset: Building the Knowledge Base

To train GPT-Neo for this project, I used the cancer subset of the MedQuAD dataset, a repository of medical Q&A pairs drawn from reliable healthcare sources. Parsing the XML files yielded 729 cancer-related Q&A pairs.
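The original write-up does not show the parsing step, but a minimal sketch looks roughly like the following, assuming the standard MedQuAD layout in which each XML document contains QAPair elements with Question and Answer children; the directory path is a hypothetical example.

import xml.etree.ElementTree as ET
from pathlib import Path

def load_qa_pairs(xml_dir):
    """Collect question/answer pairs from all MedQuAD XML files in a folder."""
    pairs = []
    for xml_file in Path(xml_dir).glob("*.xml"):
        root = ET.parse(xml_file).getroot()
        for qa in root.iter("QAPair"):
            question = qa.findtext("Question", default="").strip()
            answer = qa.findtext("Answer", default="").strip()
            if question and answer:
                pairs.append({"question": question, "answer": answer})
    return pairs

qa_pairs = load_qa_pairs("./MedQuAD/CancerGov_QA")  # hypothetical path to the cancer subset
print(len(qa_pairs))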


Example:

Q: “What are the stages of uterine sarcoma?”

A: “Stage I: Cancer is in the uterus only. Stage II: Cancer has spread within the uterus.”

These structured entries formed the basis for model training and evaluation.


Comparing Training Modes

To evaluate the adaptability of GPT-Neo, I conducted experiments with two different training durations. The model setup included:

  • Model Architecture: GPT-Neo (125M parameters)
  • Training Parameters:
    • Learning Rate: 1e-4
    • Batch Size: 8
    • Optimizer: AdamW
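The Trainer calls below also rely on a tokenizer, a model instance, tokenized train/eval splits, and a data collator that the original write-up does not show. Here is a minimal sketch of that setup, reusing the qa_pairs list from the parsing sketch above; the prompt format, maximum sequence length, and 90/10 split are assumptions rather than the project's exact choices.

import torch
from datasets import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling)

model_name = "EleutherAI/gpt-neo-125M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-Neo ships without a pad token
model = AutoModelForCausalLM.from_pretrained(model_name)

# Turn each Q&A pair into a single training string (assumed prompt format).
texts = [f"Question: {p['question']}\nAnswer: {p['answer']}" for p in qa_pairs]

def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, max_length=512)

dataset = Dataset.from_dict({"text": texts}).map(
    tokenize, batched=True, remove_columns=["text"]
)
split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split["train"], split["test"]

# Causal LM collator: labels are the input ids, shifted inside the model.
data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False)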

Experiment 1: Low-Epoch Training

Using epochs = 2
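The code for this run is not shown in the write-up; assuming it mirrors the high-epoch configuration below with only the epoch count changed, it would look roughly like this sketch.

from transformers import Trainer, TrainingArguments

# Sketch: same setup as the high-epoch run below, but num_train_epochs=2.
low_epoch_args = TrainingArguments(
    output_dir="./gpt_neo_medical_qa_2ep",
    num_train_epochs=2,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="steps",
    eval_steps=50,
    logging_steps=50,
    learning_rate=1e-4,
    report_to="none",
)
trainer_low_epoch = Trainer(
    model=model,
    args=low_epoch_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)
trainer_low_epoch.train()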

Step | Training Loss | Validation Loss
50 | 1.133000 | 0.759739
100 | 0.639200 | 0.636311

Experiment 2: High-Epoch Training

import torch
from transformers import Trainer, TrainingArguments

# Define training arguments
training_args = TrainingArguments(
    output_dir="./gpt_neo_medical_qa",
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="steps",
    eval_steps=50,
    save_steps=50,
    logging_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_steps=5,
    logging_dir="./logs",
    push_to_hub=False,
    fp16=torch.cuda.is_available(),
    report_to="none"
)

trainer_high_epoch = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)
trainer_high_epoch.train()
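After training, the fine-tuned weights and tokenizer can be persisted so that the evaluation step can reload them later; a small sketch, with the output path being an assumption:

# Persist the fine-tuned weights and tokenizer for later generation and evaluation.
trainer_high_epoch.save_model("./gpt_neo_medical_qa/final")
tokenizer.save_pretrained("./gpt_neo_medical_qa/final")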

Using epochs = 10

Step | Training Loss | Validation Loss
50 | 0.347100 | 0.578609
100 | 0.286200 | 0.563492
150 | 0.261600 | 0.551096
200 | 0.183100 | 0.576036
250 | 0.159600 | 0.575546
300 | 0.149300 | 0.585517
350 | 0.117100 | 0.586230
400 | 0.106000 | 0.601184
450 | 0.101600 | 0.597330
500 | 0.087200 | 0.613437
550 | 0.079800 | 0.623454
600 | 0.075200 | 0.627863
650 | 0.069300 | 0.627968
700 | 0.063900 | 0.637445

Key Insight

Both runs delivered reasonable results, and the longer run produced more fluent, detailed answers. Note, however, that validation loss bottoms out around step 150 (~0.55) and then climbs steadily while training loss keeps falling, a sign that the model begins to overfit the small dataset well before the end of the run. The extra epochs also come at the cost of substantially more computation.


Testing and Evaluation

To measure the effectiveness of the fine-tuned model, I used three evaluation metrics, chosen to cover both lexical overlap and semantic similarity so that performance is assessed from multiple perspectives (a sketch of how they can be computed follows the list):

  1. Token Overlap Ratio: Evaluates the proportion of unique tokens in the reference answer found in the model’s prediction.
  2. ROUGE Scores: Measures n-gram overlaps between the generated answers and reference texts.
  3. BERTScore: Assesses semantic similarity by comparing contextual embeddings.
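As a rough illustration of how these metrics can be computed, here is a sketch using the Hugging Face evaluate library for ROUGE and BERTScore; the whitespace tokenization in the token-overlap helper is an assumption about how the original metric was implemented.

import evaluate

rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")

def token_overlap(reference, prediction):
    # Share of unique reference tokens that also appear in the prediction.
    ref_tokens = set(reference.lower().split())
    pred_tokens = set(prediction.lower().split())
    return len(ref_tokens & pred_tokens) / max(len(ref_tokens), 1)

def evaluate_answers(references, predictions):
    overlap = sum(token_overlap(r, p) for r, p in zip(references, predictions)) / len(references)
    rouge_scores = rouge.compute(predictions=predictions, references=references)
    bert_scores = bertscore.compute(predictions=predictions, references=references, lang="en")
    return {
        "avg_token_overlap": overlap,
        **rouge_scores,
        "bertscore_precision": sum(bert_scores["precision"]) / len(predictions),
        "bertscore_recall": sum(bert_scores["recall"]) / len(predictions),
        "bertscore_f1": sum(bert_scores["f1"]) / len(predictions),
    }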

Results Summary

I tested the model on sample questions from the dataset and compared its outputs with the reference answers. The generation step looked roughly like the sketch below; two examples follow.
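This is a minimal sketch of generating an answer with the fine-tuned model; the prompt format and decoding parameters are assumptions.

def generate_answer(question, max_new_tokens=128):
    # Prompt mirrors the assumed training format: "Question: ...\nAnswer:"
    prompt = f"Question: {question}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Strip the prompt and return only the newly generated text.
    return tokenizer.decode(
        output_ids[0][inputs["input_ids"].shape[1]:],
        skip_special_tokens=True,
    ).strip()

print(generate_answer("What are the stages of Uterine Sarcoma?"))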

Sample 1

  • Question: What research (or clinical trials) is being done for Childhood Brain Stem Glioma?
  • Reference Answer: New types of treatment are being tested in clinical trials. Patients may want to think about taking part in a clinical trial.
  • Generated Answer: Clinical trials are studies that test new treatments for diseases like Childhood Brain Stem Glioma. Many standard treatments used today ...
  • Token Overlap: 0.43

Sample 2

  • Question: What are the stages of Uterine Sarcoma?
  • Reference Answer: Stage I: Cancer is in the uterus only. Stage II: Cancer has spread within the uterus.
  • Generated Answer: Significant changes in the way the uterus looks and acts are being made. The stages are used to detect, diagnose...
  • Token Overlap: 0.20

Metrics Summary

The following results are from the high-epoch training run:

Metric | Value
Avg Token Overlap | 0.50
ROUGE-1 | 0.3256
ROUGE-2 | 0.1546
ROUGE-L | 0.2228
ROUGE-L SUM | 0.2365
BERTScore Precision (avg) | 0.5283
BERTScore Recall (avg) | 0.6463
BERTScore F1 (avg) | 0.5812

Reflections and Findings

Strengths

  • Semantic Understanding: The model achieves a BERTScore F1 of 0.5812, showing a reasonable ability to capture the context and meaning of the reference answers. Recall (0.6463) is noticeably higher than precision (0.5283), which suggests the model usually covers the reference content but pads its answers with additional text.
  • Token Overlap: An average token overlap of 0.50 means that, on average, half of the unique tokens in a reference answer also appear in the generated answer, a reasonable degree of lexical alignment with the reference texts.
  • ROUGE Performance: The model shows moderate, consistent performance across the ROUGE metrics, which measure n-gram and subsequence overlap with the reference text:
    • ROUGE-1: 0.3256 – Captures the overlap of unigrams, reflecting the model's ability to recognize individual important words.
    • ROUGE-2: 0.1546 – Measures bigram overlaps, indicating the model's capability to understand and reproduce pairs of consecutive words.
    • ROUGE-L: 0.2228 – Assesses the longest common subsequence, highlighting the model's proficiency in maintaining the flow and structure of the reference text.
    • ROUGE-L SUM: 0.2365 – Provides an aggregate measure of the longest common subsequences, offering a comprehensive view of the model's performance in preserving the overall content structure.
  • Adaptability: Even with a modest, single-domain training set, the model consistently produces coherent and contextually relevant responses to the cancer-related questions in the evaluation set.

Weaknesses

  • Precision: Although the average token overlap stands at 0.50, there is still potential to enhance the model's precision. Achieving higher token overlap would result in more exact, word-for-word alignments with reference answers, thereby increasing the reliability of the generated responses.
  • Ambiguity: In certain instances, the model's outputs deviate from the reference texts, leading to ambiguous or less accurate responses. This highlights the necessity for further fine-tuning to minimize inconsistencies and ensure greater fidelity to the desired outputs.

Future Directions

This project underscores GPT-Neo’s capability to handle complex, domain-specific data while identifying key areas for improvement. To advance the model’s performance and applicability, the following strategies will be pursued:

  • Expand Training Data: Incorporate larger and more diverse datasets to enhance precision and semantic understanding.
  • Optimize Fine-Tuning: Focus on improving ROUGE-2 and ROUGE-L scores to better capture complex linguistic structures.
  • Conduct Real-World Testing: Deploy the model in practical scenarios to evaluate performance and identify further areas for improvement.
  • Custom Loss Functions: Develop loss functions that penalize factual inaccuracies to improve reliability in sensitive domains.

The full notebook is available on Google Colab. Feel free to check it out.