Accelerating Medical Symptom Classification with Adala and Google Gemini: A Hands-On Guide
This tutorial demonstrates building an active learning pipeline for medical symptom classification by integrating Adala with Google Gemini for efficient annotation and confidence visualization.
Setting Up Adala and Dependencies
The tutorial starts by installing the Adala framework directly from its GitHub repository, along with the required dependencies: pandas, Matplotlib, and the Google Generative AI SDK (`google-generativeai`). A few verification steps then confirm that the package is installed and importable in the current Python environment.
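```
# Notebook (Colab/Jupyter) cells: the leading "!" runs a shell command
!pip install -q git+https://github.com/HumanSignal/Adala.git
!pip install -q google-generativeai pandas matplotlib
!pip list | grep adala
```
The Python path and package directories are then inspected to confirm that Adala was installed where the interpreter can find it.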
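The inspection code itself is not shown in the tutorial. A minimal sketch of such a check, using only the standard library (`sys.path` and `importlib.util.find_spec`), might look like this; the helper name `locate_package` is hypothetical.

```python
import importlib.util
import sys


def locate_package(name):
    """Return where a top-level package would be imported from, or None if absent."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    # Packages expose their directories via submodule_search_locations;
    # plain modules only have an origin file.
    if spec.submodule_search_locations:
        return list(spec.submodule_search_locations)[0]
    return spec.origin


if __name__ == "__main__":
    print("Python executable:", sys.executable)
    for entry in sys.path:
        print("  search path:", entry)
    print("adala located at:", locate_package("adala"))
```

If `locate_package("adala")` returns `None`, the notebook kernel is likely using a different interpreter than the one `pip` installed into.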
Integrating Google Gemini as a Custom Annotator
Google Gemini is set up as a custom annotator through a dedicated class GeminiAnnotator. This class uses the Google Generative AI client to classify medical symptoms into categories such as Cardiovascular, Respiratory, Gastrointestinal, and Neurological.
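The least obvious step in the class is turning Gemini's free-text reply into a dictionary. Pulled out of the class and away from the API, the same regex-then-`json.loads` approach can be exercised offline. This is a sketch: `parse_gemini_reply` is a hypothetical helper, and the fallback defaults (`"unknown"`, `0.0`, `""`) are assumptions, not values Gemini guarantees.

```python
import json
import re


def parse_gemini_reply(reply_text):
    """Extract the outermost {...} span from a model reply and decode it as JSON.

    Models often wrap JSON in prose or Markdown fences, so a greedy DOTALL
    regex grabs everything from the first '{' to the last '}'.
    Raises ValueError (json.JSONDecodeError) if no valid JSON is found.
    """
    match = re.search(r"\{.*\}", reply_text, re.DOTALL)
    payload = match.group(0) if match else reply_text
    result = json.loads(payload)
    return {
        "category": result.get("category", "unknown"),
        "confidence": float(result.get("confidence", 0.0)),
        "explanation": result.get("explanation", ""),
    }
```

Using `.get` with defaults instead of indexing means a reply that omits a field still yields a usable record rather than an exception.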
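```python
import json
import os
import re
from types import SimpleNamespace

import google.generativeai as genai

# Assumes the API key is stored in the GEMINI_API_KEY environment variable.
genai.configure(api_key=os.environ["GEMINI_API_KEY"])

CATEGORIES = ["Cardiovascular", "Respiratory", "Gastrointestinal", "Neurological"]


class GeminiAnnotator:
    def __init__(self, model_name="models/gemini-2.0-flash-lite", categories=None):
        # A low temperature keeps the labels close to deterministic.
        self.model = genai.GenerativeModel(
            model_name=model_name,
            generation_config={"temperature": 0.1},
        )
        self.categories = categories or CATEGORIES

    def annotate(self, samples):
        """Label each sample; failures are recorded as "unknown" with the error."""
        results = []
        for sample in samples:
            prompt = (
                "Classify this medical symptom into one of these categories: "
                f"{', '.join(self.categories)}.\n"
                'Return JSON format: {"category": "selected_category", '
                '"confidence": 0.XX, "explanation": "brief_reason"}\n'
                f"SYMPTOM: {sample.text}"
            )
            try:
                response = self.model.generate_content(prompt).text
                # The model may wrap the JSON in prose or Markdown fences,
                # so extract the outermost {...} span before parsing.
                json_match = re.search(r"(\{.*\})", response, re.DOTALL)
                result = json.loads(json_match.group(1) if json_match else response)
                labeled_sample = SimpleNamespace(
                    text=sample.text,
                    labels=result["category"],
                    metadata={
                        "confidence": result["confidence"],
                        "explanation": result["explanation"],
                    },
                )
            except Exception as e:  # API errors, malformed JSON, missing keys
                labeled_sample = SimpleNamespace(
                    text=sample.text,
                    labels="unknown",
                    metadata={"error": str(e)},
                )
            results.append(labeled_sample)
        return results
```
Preparing and Annotating Sample Data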
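A list of sample medical symptoms is wrapped in lightweight objects exposing a `text` attribute and passed to the annotator. The active learning loop then runs for up to three iterations. In each one, every unlabeled symptom receives a base score of 0.1, plus 0.5 if its text mentions "chest", "heart", or "pain"; the highest-scoring symptom is sent to Gemini, so cardiac and pain-related complaints are labeled first.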
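The selection rule is a hand-written keyword heuristic rather than a model-based uncertainty estimate, so it can be checked offline without calling Gemini. The following sketch reproduces it in plain Python; the function names are hypothetical, and the weights (0.1 base, 0.5 boost) and keywords mirror the loop.

```python
# Keywords and weights mirror the loop: 0.1 baseline, +0.5 on a keyword hit.
PRIORITY_TERMS = ("chest", "heart", "pain")


def priority_scores(texts, base=0.1, boost=0.5, terms=PRIORITY_TERMS):
    """Score each text: base value, plus boost if any priority keyword occurs."""
    return [
        base + (boost if any(term in text.lower() for term in terms) else 0.0)
        for text in texts
    ]


def pick_next(texts):
    """Index of the highest score; max() keeps the first of tied maxima."""
    scores = priority_scores(texts)
    return max(range(len(scores)), key=scores.__getitem__)


def selection_order(texts, iterations=3):
    """Simulate which texts the loop would send for labeling, in order."""
    remaining = list(texts)
    order = []
    for _ in range(iterations):
        if not remaining:
            break
        order.append(remaining.pop(pick_next(remaining)))
    return order
```

Applied to the six symptoms in the tutorial, only "Chest pain radiating to left arm during exercise" gets the boost. The other five tie at 0.1, so ties fall back to list order ("Persistent dry cough…", then "Severe headache…"), which matches `np.argmax` returning the first occurrence of the maximum in the loop.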
```python
from types import SimpleNamespace

import numpy as np

sample_data = [
    "Chest pain radiating to left arm during exercise",
    "Persistent dry cough with occasional wheezing",
    "Severe headache with sensitivity to light",
    "Stomach cramps and nausea after eating",
    "Numbness in fingers of right hand",
    "Shortness of breath when climbing stairs",
]
# Lightweight sample objects exposing only a .text attribute
text_samples = [SimpleNamespace(text=text) for text in sample_data]

annotator = GeminiAnnotator(categories=CATEGORIES)
labeled_samples = []

print("\nRunning Active Learning Loop:")
for i in range(3):
    print(f"\n--- Iteration {i+1} ---")
    # Skip symptoms that have already been labeled
    labeled_texts = {s.text for s in labeled_samples}
    remaining = [s for s in text_samples if s.text not in labeled_texts]
    if not remaining:
        break

    # Priority score: 0.1 baseline, +0.5 for cardiac/pain keywords
    scores = np.full(len(remaining), 0.1)
    for j, sample in enumerate(remaining):
        if any(term in sample.text.lower() for term in ["chest", "heart", "pain"]):
            scores[j] += 0.5

    # Annotate the single highest-priority sample (ties -> first in list)
    selected = [remaining[int(np.argmax(scores))]]
    labeled_samples.extend(annotator.annotate(selected))

    latest = labeled_samples[-1]
    print(f"Text: {latest.text}")
    print(f"Category: {latest.labels}")
    print(f"Confidence: {latest.metadata.get('confidence', 0)}")
    print(f"Explanation: {latest.metadata.get('explanation', '')[:100]}...")
```
Visualizing Classification Confidence
The confidence score of each labeled sample is extracted and plotted as a Matplotlib bar chart, with each bar labeled by the category Gemini assigned.
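The chart draws one bar per labeled sample, so when several samples share a category the "by category" view is only approximate. If per-category figures are wanted, a small aggregation step could feed the chart instead. This is a sketch with a hypothetical helper name; it skips samples whose annotation failed, since those carry an `error` key instead of `confidence`.

```python
from collections import defaultdict
from types import SimpleNamespace


def mean_confidence_by_category(samples):
    """Average 'confidence' metadata per predicted label, skipping failed annotations."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    for s in samples:
        conf = s.metadata.get("confidence")
        if conf is None:  # failed annotations store an "error" key instead
            continue
        totals[s.labels] += float(conf)
        counts[s.labels] += 1
    return {label: totals[label] / counts[label] for label in totals}


if __name__ == "__main__":
    # Hypothetical samples shaped like the annotator's output
    demo = [
        SimpleNamespace(labels="Cardiovascular", metadata={"confidence": 0.9}),
        SimpleNamespace(labels="Cardiovascular", metadata={"confidence": 0.7}),
        SimpleNamespace(labels="unknown", metadata={"error": "timeout"}),
    ]
    print(mean_confidence_by_category(demo))
```

The resulting dictionary `means` could then be plotted with `plt.bar(list(means), list(means.values()))` in place of the per-sample lists.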
```python
import matplotlib.pyplot as plt

categories = [s.labels for s in labeled_samples]
# Failed annotations have no "confidence" key and are plotted as 0
confidence = [float(s.metadata.get("confidence", 0)) for s in labeled_samples]

plt.figure(figsize=(10, 5))
plt.bar(range(len(categories)), confidence, color="skyblue")
plt.xticks(range(len(categories)), categories, rotation=45)
plt.ylabel("Confidence")
plt.ylim(0, 1)
plt.title("Classification Confidence by Category")
plt.tight_layout()
plt.show()
```
Summary
Combining Adala's modular active learning framework with Google Gemini's generative classification capabilities enables an efficient, extensible medical symptom annotation pipeline. This approach supports priority-based sampling and confidence visualization, offering a practical foundation for more advanced annotation workflows.