Fine-Tuning GPT-Neo for Medical Q&A
Author: Astik Dahal
How can small language models such as GPT-Neo be tailored for specialized applications like medical question answering? This project grew out of a simple but important goal: adapting lightweight language models to practical, real-world tasks. By fine-tuning GPT-Neo on a domain-specific dataset, I aimed to assess its performance, understand its limitations, and explore ways to improve it further.
Summary
This project fine-tuned GPT-Neo for medical Q&A on the MedQuAD dataset. The 10-epoch model achieved an average token overlap of 0.50 and an average BERTScore F1 of 0.5812. The results point to strong semantic understanding but leave room for improvement in precision.
The Motivation
Small language models like GPT-Neo are well suited to resource-constrained applications, especially in healthcare, where accurate, context-sensitive Q&A systems are crucial. They offer several key advantages:
- Efficiency: Small models are computationally less demanding, enabling deployment in environments with limited resources.
- Customizability: Fine-tuning smaller models with focused datasets makes them ideal for specialized tasks.
- Accessibility: Compared to larger models requiring extensive infrastructure, GPT-Neo is far more accessible to researchers and developers.
The Dataset: Building the Knowledge Base
To train GPT-Neo for this project, I used the cancer subset of the MedQuAD dataset, a collection of medical Q&A pairs drawn from reliable healthcare sources. Parsing the subset's XML files yielded 729 cancer-related Q&A pairs.
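Parsing the XML could look roughly like the sketch below. It assumes the MedQuAD layout in which each file holds a `<Document>` whose `<QAPair>` elements contain `<Question>` and `<Answer>` children; the function names `parse_medquad_xml` and `load_medquad_folder` are hypothetical and not taken from the original notebook. Pairs with an empty answer are skipped.

```python
import xml.etree.ElementTree as ET
from pathlib import Path


def parse_medquad_xml(xml_data):
    """Return (question, answer) pairs from one MedQuAD-style XML document.

    Assumed layout: <Document><QAPairs><QAPair><Question/><Answer/></QAPair>...
    Accepts str or bytes; pairs with a missing or empty answer are skipped.
    """
    root = ET.fromstring(xml_data)
    pairs = []
    for qa in root.iter("QAPair"):
        question = (qa.findtext("Question") or "").strip()
        answer = (qa.findtext("Answer") or "").strip()
        if question and answer:
            pairs.append((question, answer))
    return pairs


def load_medquad_folder(folder):
    """Parse every .xml file in a folder (e.g. the cancer subset) into one list."""
    pairs = []
    for path in sorted(Path(folder).glob("*.xml")):
        # Read bytes so the parser honours each file's own encoding declaration.
        pairs.extend(parse_medquad_xml(path.read_bytes()))
    return pairs
```

Skipping empty answers keeps incomplete entries out of the training set, since a question with no answer gives the model nothing useful to learn.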
Example:
Q: “What are the stages of uterine sarcoma?”
A: “Stage I: Cancer is in the uterus only. Stage II: Cancer has spread within the uterus.”
These structured entries formed the basis for model training and evaluation.
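For a causal language model, each pair has to be turned into a single text sequence. The post does not show the template used, so the sketch below assumes a hypothetical `Question: ... / Answer: ...` format and appends GPT-Neo's `<|endoftext|>` end-of-sequence token (in practice, `tokenizer.eos_token` would be the safer source for it). The helper names are illustrative only.

```python
def format_example(question, answer, eos_token="<|endoftext|>"):
    """Join one Q&A pair into a single training string for a causal LM.

    The "Question:/Answer:" template is a hypothetical choice; the EOS token
    marks where an answer ends so the model learns to stop generating.
    """
    return f"Question: {question.strip()}\nAnswer: {answer.strip()}{eos_token}"


def format_prompt(question):
    """Inference-time prompt: the same template with the answer left open."""
    return f"Question: {question.strip()}\nAnswer:"


print(format_example(
    "What are the stages of uterine sarcoma?",
    "Stage I: Cancer is in the uterus only.",
))
```

Whatever template is chosen, the prompt used at inference should match the training format exactly, otherwise the model sees inputs unlike anything it was fine-tuned on.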
Comparing Training Modes
To evaluate the adaptability of GPT-Neo, I conducted experiments with two different training durations. The model setup included:
- Model architecture: GPT-Neo (125M parameters)
- Training parameters:
  - Learning rate: 1e-4
  - Batch size: 8
  - Optimizer: AdamW
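The number of optimizer steps in each run follows from the dataset size and batch size. The helper below (hypothetical, not from the original notebook) approximates the count, assuming the last partial batch is kept and no gradient accumulation is used unless specified. Training on all 729 pairs for 10 epochs would take 920 steps, whereas the 10-epoch log ends at step 700, which suggests that part of the data was held out to compute the validation loss. The exact split is not stated; a hypothetical 80/20 split (583 training pairs) would give 730 steps.

```python
import math


def total_optimizer_steps(n_train, batch_size, epochs, grad_accum=1, n_devices=1):
    """Approximate the optimizer steps run by a Trainer-style training loop.

    Assumes the last partial batch is kept (no drop_last) and whole epochs.
    """
    batches_per_epoch = math.ceil(n_train / (batch_size * n_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch * epochs


print(total_optimizer_steps(729, 8, 10))  # every pair used for training
print(total_optimizer_steps(583, 8, 10))  # hypothetical 80/20 split
```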
Experiment 1: Low Epoch Training
Trained for 2 epochs:
Step | Training Loss | Validation Loss |
---|---|---|
50 | 1.133000 | 0.759739 |
100 | 0.639200 | 0.636311 |
Experiment 2: High Epoch Training
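```python
import torch
from transformers import Trainer, TrainingArguments

# `model`, `train_dataset`, `eval_dataset`, and `data_collator` are defined
# earlier in the notebook (GPT-Neo 125M and the tokenized MedQuAD splits).

# Define training arguments
training_args = TrainingArguments(
    output_dir="./gpt_neo_medical_qa",
    overwrite_output_dir=True,
    num_train_epochs=10,              # high-epoch run
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    evaluation_strategy="steps",      # renamed to eval_strategy in newer transformers releases
    eval_steps=50,                    # compute validation loss every 50 steps
    save_steps=50,
    logging_steps=50,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_steps=5,
    logging_dir="./logs",
    push_to_hub=False,
    fp16=torch.cuda.is_available(),   # mixed precision only when a GPU is present
    report_to="none",
)

trainer_high_epoch = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)
trainer_high_epoch.train()
```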
Trained for 10 epochs:
Step | Training Loss | Validation Loss |
---|---|---|
50 | 0.347100 | 0.578609 |
100 | 0.286200 | 0.563492 |
150 | 0.261600 | 0.551096 |
200 | 0.183100 | 0.576036 |
250 | 0.159600 | 0.575546 |
300 | 0.149300 | 0.585517 |
350 | 0.117100 | 0.586230 |
400 | 0.106000 | 0.601184 |
450 | 0.101600 | 0.597330 |
500 | 0.087200 | 0.613437 |
550 | 0.079800 | 0.623454 |
600 | 0.075200 | 0.627863 |
650 | 0.069300 | 0.627968 |
700 | 0.063900 | 0.637445 |
Key Insight
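Both approaches delivered reasonable results, and in my tests longer training produced more nuanced and accurate answers, at the cost of more computation. The loss curves add a caveat: in the 10-epoch run, validation loss reached its minimum (0.551) at step 150 and then rose to 0.637 by step 700 while training loss kept falling, a typical sign of overfitting. Keeping the checkpoint with the lowest validation loss may therefore be preferable to keeping the final one.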
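The turning point can be read directly off the logged values. The sketch below copies the 10-epoch validation losses from the table above, finds the best step, and computes the step at which a patience-based early-stopping rule would have halted training. `best_step` and `early_stop_step` are hypothetical helpers, not part of the original notebook.

```python
# Validation losses from the 10-epoch table above, keyed by training step.
val_loss = {
    50: 0.578609, 100: 0.563492, 150: 0.551096, 200: 0.576036,
    250: 0.575546, 300: 0.585517, 350: 0.586230, 400: 0.601184,
    450: 0.597330, 500: 0.613437, 550: 0.623454, 600: 0.627863,
    650: 0.627968, 700: 0.637445,
}


def best_step(losses):
    """Return the step with the lowest validation loss."""
    return min(losses, key=losses.get)


def early_stop_step(losses, patience):
    """Step at which training would halt after `patience` evaluations without
    a new minimum; returns the last step if that never happens."""
    best = float("inf")
    evals_without_improvement = 0
    for step in sorted(losses):
        if losses[step] < best:
            best = losses[step]
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return step
    return max(losses)


print(best_step(val_loss), early_stop_step(val_loss, patience=3))
```

With a patience of 3 evaluations, training would stop at step 300 and the step-150 checkpoint would be kept. In the Hugging Face `Trainer`, a similar effect could come from `load_best_model_at_end=True` and `metric_for_best_model="eval_loss"` in `TrainingArguments` plus `callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]`; neither was part of the original setup.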
Testing and Evaluation
To measure the effectiveness of the fine-tuned model, I used three evaluation metrics, chosen to cover both lexical overlap and semantic similarity so that the model's performance is assessed from multiple perspectives:
- Token Overlap Ratio: Evaluates the proportion of unique tokens in the reference answer found in the model’s prediction.
- ROUGE Scores: Measures n-gram overlaps between the generated answers and reference texts.
- BERTScore: Assesses semantic similarity by comparing contextual embeddings.
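The two lexical metrics can be sketched in a few lines of plain Python. This is a simplified illustration under stated assumptions: tokens come from lowercasing, whitespace splitting, and stripping surrounding punctuation, and all function names are hypothetical. Library implementations such as Google's `rouge-score` package use their own tokenization and optional stemming, so this sketch will not reproduce the table's numbers exactly.

```python
from collections import Counter


def tokenize(text):
    """Lowercase, split on whitespace, strip surrounding punctuation (assumption)."""
    tokens = (tok.strip(".,;:!?()\"'") for tok in text.lower().split())
    return [tok for tok in tokens if tok]


def token_overlap(reference, prediction):
    """Fraction of unique reference tokens that also appear in the prediction."""
    ref, pred = set(tokenize(reference)), set(tokenize(prediction))
    return len(ref & pred) / len(ref) if ref else 0.0


def f1(matches, n_pred, n_ref):
    """Harmonic mean of precision (matches/n_pred) and recall (matches/n_ref)."""
    if matches == 0:
        return 0.0
    precision, recall = matches / n_pred, matches / n_ref
    return 2 * precision * recall / (precision + recall)


def ngrams(tokens, n):
    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]


def rouge_n(reference, prediction, n=1):
    """ROUGE-N F1: n-gram matches, each clipped to its count in the other text."""
    ref = Counter(ngrams(tokenize(reference), n))
    pred = Counter(ngrams(tokenize(prediction), n))
    matches = sum((ref & pred).values())
    return f1(matches, sum(pred.values()), sum(ref.values()))


def lcs_length(a, b):
    """Longest common subsequence length via row-by-row dynamic programming."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            curr.append(prev[j - 1] + 1 if x == y else max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l(reference, prediction):
    """ROUGE-L F1: in-order (not necessarily contiguous) word matches."""
    ref, pred = tokenize(reference), tokenize(prediction)
    return f1(lcs_length(ref, pred), len(pred), len(ref))
```

Note that token overlap, as defined above, only asks whether reference tokens appear in the prediction, so it behaves like a recall measure, while the ROUGE variants here also penalize extra words through precision.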
Results Summary
I tested the model on some of the sample questions from the dataset:
Sample 1
- Question: What research (or clinical trials) is being done for Childhood Brain Stem Glioma?
- Reference Answer: New types of treatment are being tested in clinical trials. Patients may want to think about taking part in a clinical trial.
- Generated Answer: Clinical trials are studies that test new treatments for diseases like Childhood Brain Stem Glioma. Many standard treatments used today ...
- Token Overlap: 0.43
Sample 2
- Question: What are the stages of Uterine Sarcoma?
- Reference Answer: Stage I: Cancer is in the uterus only. Stage II: Cancer has spread within the uterus.
- Generated Answer: Significant changes in the way the uterus looks and acts are being made. The stages are used to detect, diagnose...
- Token Overlap: 0.20
Metrics Summary
The metrics below are for the high-epoch (10-epoch) model:
Metric | Value |
---|---|
Avg Token Overlap | 0.50 |
ROUGE-1 | 0.3256 |
ROUGE-2 | 0.1546 |
ROUGE-L | 0.2228 |
ROUGE-L SUM | 0.2365 |
BERTScore Precision (avg) | 0.5283 |
BERTScore Recall (avg) | 0.6463 |
BERTScore F1 (avg) | 0.5812 |
Reflections and Findings
Strengths
- Semantic Understanding: The model achieves an average BERTScore F1 of 0.5812, indicating that its answers are semantically close to the references even when the wording differs. Recall (0.6463) is higher than precision (0.5283), which suggests the answers cover much of the reference content but also include material the references do not contain.
- Token Overlap: With an average token overlap of 0.50, the model's answers contain, on average, half of the unique tokens in the corresponding reference answer, a substantial degree of alignment in word choice.
- ROUGE Performance: The model shows solid performance across the ROUGE metrics, which measure surface-level overlap with the reference text:
  - ROUGE-1 (0.3256): unigram overlap, reflecting how many individual words from the reference the model reproduces.
  - ROUGE-2 (0.1546): bigram overlap, indicating how often the model reproduces pairs of consecutive words from the reference.
  - ROUGE-L (0.2228): based on the longest common subsequence, which rewards words that appear in the same order as in the reference even when they are not contiguous.
  - ROUGE-L SUM (0.2365): the summary-level variant of ROUGE-L, which splits texts into sentences and combines their longest common subsequences, reflecting how well content order is preserved across sentences.
- Adaptability: The model generally produces coherent, on-topic responses across the cancer types and question types in the dataset, despite being trained on a modest set of 729 pairs.
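The gap between BERTScore precision and recall noted above follows from how the metric matches tokens: each candidate token is paired with its most similar reference token (precision), and each reference token with its most similar candidate token (recall). The sketch below shows only that greedy matching step on toy vectors. Real BERTScore uses contextual embeddings from a pretrained model (the `bert-score` package handles this), and its optional idf weighting and baseline rescaling are omitted here; `greedy_bertscore` is a hypothetical name.

```python
import math


def cosine(u, v):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def greedy_bertscore(cand_vecs, ref_vecs):
    """Precision, recall, and F1 from greedy max-cosine token matching.

    cand_vecs / ref_vecs: one embedding vector per token (toy inputs here).
    """
    sims = [[cosine(c, r) for r in ref_vecs] for c in cand_vecs]
    # Precision: best reference match for each candidate token.
    precision = sum(max(row) for row in sims) / len(cand_vecs)
    # Recall: best candidate match for each reference token (columns of sims).
    recall = sum(max(col) for col in zip(*sims)) / len(ref_vecs)
    total = precision + recall
    f1 = 2 * precision * recall / total if total > 0 else 0.0
    return precision, recall, f1


# Candidate has one token matching the reference plus one unrelated extra token.
print(greedy_bertscore([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0]]))
```

In the toy call, the candidate's extra token has no counterpart in the reference, so precision drops to 0.5 while recall stays at 1.0, the same direction as the gap between the reported precision and recall.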
Weaknesses
- Precision: An average token overlap of 0.50 means that half of the unique reference tokens are still missing from a typical answer, and BERTScore precision (0.5283) is the lowest of the three BERTScore values, so the answers also contain content absent from the reference. Improving both would bring the generated responses closer to the reference answers and make them more reliable.
- Ambiguity: In some cases the model's output drifts away from the reference, producing vague or inaccurate answers; in Sample 2, for instance, the generated answer does not list the stages of uterine sarcoma (token overlap 0.20). Further fine-tuning is needed to reduce such inconsistencies and improve fidelity to the reference answers.
Future Directions
This project shows that GPT-Neo can handle complex, domain-specific data and identifies key areas for improvement. To advance the model's performance and applicability, the following strategies will be pursued:
- Expand Training Data: Incorporate larger and more diverse datasets to enhance precision and semantic understanding.
- Optimize Fine-Tuning: Focus on improving ROUGE-2 and ROUGE-L scores to better capture complex linguistic structures.
- Conduct Real-World Testing: Deploy the model in practical scenarios to evaluate performance and identify further areas for improvement.
- Custom Loss Functions: Develop loss functions that penalize factual inaccuracies to improve reliability in sensitive domains.
The work is available on Google Colab. Feel free to check it out.