
Accelerating Medical Symptom Classification with Adala and Google Gemini: A Hands-On Guide

This tutorial builds an active learning pipeline for medical symptom classification, integrating Adala with Google Gemini to annotate symptoms and to visualize how confident the model is in each label.

Setting Up Adala and Dependencies

The tutorial starts by installing the Adala framework directly from its GitHub repository, along with the pandas, matplotlib, and Google Generative AI SDK (google-generativeai) dependencies. A final pip list check confirms that the adala package is installed and visible in the Python environment.

!pip install -q git+https://github.com/HumanSignal/Adala.git
!pip install -q google-generativeai pandas matplotlib
!pip list | grep adala

The Python path and package directories can also be inspected to confirm that adala is imported from the expected location.
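The inspection step is not shown in code. A minimal sketch using only the standard library might look like the following; describe_package is a hypothetical helper name, and it assumes the pip distribution is named adala, matching the import name.

```python
import importlib.util
import sys
from importlib import metadata


def describe_package(name):
    """Return where `name` would be imported from, or None if it is not importable."""
    spec = importlib.util.find_spec(name)
    if spec is None:
        return None
    try:
        # Version from the installed distribution's metadata (pip-installed packages)
        version = metadata.version(name)
    except metadata.PackageNotFoundError:
        version = None  # e.g. standard-library modules have no distribution metadata
    return {"name": name, "origin": spec.origin, "version": version}


# Directories Python searches for imports, in search order
for path in sys.path:
    print(path)

# None means adala is not importable from this interpreter
print(describe_package("adala"))
```

If the printed origin points at a different site-packages directory than the notebook kernel uses, the package was likely installed into another interpreter.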

Integrating Google Gemini as a Custom Annotator

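Google Gemini is set up as a custom annotator through a dedicated GeminiAnnotator class. The class uses the Google Generative AI client to classify medical symptoms into categories such as Cardiovascular, Respiratory, Gastrointestinal, and Neurological. For each sample, Gemini is asked to return JSON with a category, a confidence score, and a short explanation; if the API call or the JSON parsing fails, the sample is labeled "unknown" and the error message is stored in its metadata.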

import json
import os
import re
from types import SimpleNamespace

import google.generativeai as genai

# Assumes the API key is stored in the GOOGLE_API_KEY environment variable
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

CATEGORIES = ["Cardiovascular", "Respiratory", "Gastrointestinal", "Neurological"]

class GeminiAnnotator:
    def __init__(self, model_name="models/gemini-2.0-flash-lite", categories=None):
        # A low temperature keeps the labels close to deterministic
        self.model = genai.GenerativeModel(model_name=model_name,
                                           generation_config={"temperature": 0.1})
        self.categories = categories or CATEGORIES

    def annotate(self, samples):
        results = []
        for sample in samples:
            prompt = f"""Classify this medical symptom into one of these categories:
            {', '.join(self.categories)}.
            Return JSON format: {{"category": "selected_category",
            "confidence": 0.XX, "explanation": "brief_reason"}}

            SYMPTOM: {sample.text}"""

            try:
                response = self.model.generate_content(prompt).text
                # The model may wrap the JSON in prose or code fences; extract the {...} span
                json_match = re.search(r'(\{.*\})', response, re.DOTALL)
                result = json.loads(json_match.group(1) if json_match else response)

                labeled_sample = SimpleNamespace(
                    text=sample.text,
                    labels=result["category"],
                    metadata={
                        "confidence": result["confidence"],
                        "explanation": result["explanation"],
                    },
                )
            except Exception as e:
                # Record API or parsing failures instead of stopping the loop
                labeled_sample = SimpleNamespace(
                    text=sample.text,
                    labels="unknown",
                    metadata={"error": str(e)},
                )
            results.append(labeled_sample)
        return results
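The parsing in annotate assumes Gemini returns a well-formed object containing all three keys. A somewhat more defensive variant is sketched below as a hypothetical helper, parse_annotation (it is not part of Adala or the google-generativeai SDK). It checks the category against the allowed list, falling back to "unknown", and coerces the confidence to a float clamped to [0, 1]. Because it works on a canned reply string, it can be tried without an API call.

```python
import json
import re


def parse_annotation(response_text, categories):
    """Parse a model reply into (category, confidence, explanation).

    Hypothetical helper: tolerates JSON surrounded by prose or Markdown fences,
    maps out-of-list categories to "unknown", and clamps confidence to [0, 1].
    """
    # Take the span from the first "{" to the last "}"; the model may add text around it
    match = re.search(r"\{.*\}", response_text, re.DOTALL)
    if match is None:
        raise ValueError("no JSON object found in model response")
    data = json.loads(match.group(0))

    # Case-insensitive lookup against the allowed label set
    lookup = {c.lower(): c for c in categories}
    category = lookup.get(str(data.get("category", "")).strip().lower(), "unknown")

    # Models sometimes return confidence as a string such as "0.92"
    try:
        confidence = float(data.get("confidence", 0.0))
    except (TypeError, ValueError):
        confidence = 0.0
    confidence = min(max(confidence, 0.0), 1.0)

    return category, confidence, str(data.get("explanation", ""))


reply = ('Here is my answer: {"category": "respiratory", "confidence": "0.92", '
         '"explanation": "Cough and wheezing."} Hope this helps.')
print(parse_annotation(reply, ["Cardiovascular", "Respiratory", "Gastrointestinal", "Neurological"]))
# ('Respiratory', 0.92, 'Cough and wheezing.')
```

Inside annotate, the helper could replace the regex and json.loads lines; keeping the try/except around it still catches network errors from generate_content.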

Preparing and Annotating Sample Data

A list of sample medical symptoms is wrapped in lightweight objects that expose a text attribute, and these are passed to the annotator. The active learning loop runs for at most three iterations. In each one, every not-yet-labeled sample receives a base score of 0.1, plus 0.5 if it mentions chest, heart, or pain; the highest-scoring sample is sent to Gemini and added to the labeled set.

import numpy as np
from types import SimpleNamespace

sample_data = [
    "Chest pain radiating to left arm during exercise",
    "Persistent dry cough with occasional wheezing",
    "Severe headache with sensitivity to light",
    "Stomach cramps and nausea after eating",
    "Numbness in fingers of right hand",
    "Shortness of breath when climbing stairs"
]

# Wrap each string in a lightweight object exposing a .text attribute
text_samples = [SimpleNamespace(text=text) for text in sample_data]

annotator = GeminiAnnotator(categories=CATEGORIES)
labeled_samples = []

print("\nRunning Active Learning Loop:")
for i in range(3):
    print(f"\n--- Iteration {i+1} ---")

    # Skip samples already labeled (tracked by identity via the _sample back-reference)
    labeled_ids = {id(item._sample) for item in labeled_samples}
    remaining = [s for s in text_samples if id(s) not in labeled_ids]
    if not remaining:
        break

    # Priority score: 0.1 baseline, +0.5 if the symptom mentions chest, heart, or pain
    scores = np.full(len(remaining), 0.1)
    for j, sample in enumerate(remaining):
        if any(term in sample.text.lower() for term in ["chest", "heart", "pain"]):
            scores[j] += 0.5

    # Annotate the highest-scoring sample; np.argmax returns the first index on ties
    selected_idx = int(np.argmax(scores))
    selected = [remaining[selected_idx]]

    newly_labeled = annotator.annotate(selected)
    for sample in newly_labeled:
        sample._sample = selected[0]  # back-reference used to exclude it in later iterations
    labeled_samples.extend(newly_labeled)

    latest = labeled_samples[-1]
    print(f"Text: {latest.text}")
    print(f"Category: {latest.labels}")
    print(f"Confidence: {latest.metadata.get('confidence', 0)}")
    print(f"Explanation: {latest.metadata.get('explanation', '')[:100]}...")
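The selection rule in this loop can be checked without calling Gemini. The sketch below re-implements only the scoring and pick-the-maximum step, using hypothetical helpers priority_score and selection_order over the same six symptoms. It makes one consequence of the keyword heuristic visible: only the first symptom contains "chest", "heart", or "pain", so the second and third iterations simply fall back to list order, because ties resolve to the first remaining sample.

```python
PRIORITY_TERMS = ["chest", "heart", "pain"]


def priority_score(text):
    """Base score 0.1, plus 0.5 if any priority term appears (same rule as the loop)."""
    score = 0.1
    if any(term in text.lower() for term in PRIORITY_TERMS):
        score += 0.5
    return score


def selection_order(texts, iterations=3):
    """Return the texts in the order the loop would pick them, one per iteration."""
    remaining = list(texts)
    picked = []
    for _ in range(iterations):
        if not remaining:
            break
        # max() returns the first of several equal maxima, like np.argmax
        best = max(remaining, key=priority_score)
        picked.append(best)
        remaining.remove(best)
    return picked


symptoms = [
    "Chest pain radiating to left arm during exercise",
    "Persistent dry cough with occasional wheezing",
    "Severe headache with sensitivity to light",
    "Stomach cramps and nausea after eating",
    "Numbness in fingers of right hand",
    "Shortness of breath when climbing stairs",
]

for text in selection_order(symptoms):
    print(f"{priority_score(text):.1f}  {text}")
```

A keyword boost like this is a fixed priority rule rather than a model-driven query strategy; substituting a score derived from the annotator's own confidence would change which symptoms are picked after the first.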

Visualizing Classification Confidence

The confidence score of each labeled sample is extracted and plotted as a Matplotlib bar chart, with one bar per sample and each bar labeled by its assigned category.

import matplotlib.pyplot as plt

categories = [s.labels for s in labeled_samples]
# Confidence may come back as a string (e.g. "0.92") or be missing after an error
confidence = [float(s.metadata.get("confidence", 0) or 0) for s in labeled_samples]

plt.figure(figsize=(10, 5))
plt.bar(range(len(categories)), confidence, color='skyblue')
plt.xticks(range(len(categories)), categories, rotation=45)
plt.ylim(0, 1)
plt.ylabel('Confidence')
plt.title('Classification Confidence by Category')
plt.tight_layout()
plt.show()
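The chart draws one bar per labeled sample, so a category assigned twice shows up as two separate bars. If a true per-category view is wanted, one option is to average the confidence within each category before plotting. The helper below, mean_confidence_by_category, is a hypothetical plain-Python sketch; samples whose confidence is missing or non-numeric (for example, failed API calls) are skipped.

```python
from collections import defaultdict


def mean_confidence_by_category(labels, confidences):
    """Average confidence per assigned category, skipping non-numeric values."""
    totals = defaultdict(float)
    counts = defaultdict(int)
    for label, conf in zip(labels, confidences):
        try:
            value = float(conf)
        except (TypeError, ValueError):
            continue  # e.g. a missing or malformed confidence from a failed call
        totals[label] += value
        counts[label] += 1
    # Sort by category name so the bar order is stable between runs
    return {label: totals[label] / counts[label] for label in sorted(counts)}


means = mean_confidence_by_category(
    ["Cardiovascular", "Respiratory", "Cardiovascular", "unknown"],
    [0.9, 0.8, 0.7, None],
)
print(means)
```

The resulting dictionary could be plotted with plt.bar(list(means), list(means.values())) in place of the per-sample bars.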

Summary

Combining Adala's modular active learning framework with Google Gemini's generative classification capabilities yields an extensible medical symptom annotation pipeline. The approach demonstrates keyword-based priority sampling and per-sample confidence visualization, offering a practical foundation for more advanced annotation workflows.
