Агентный RAG в виде дерева решений: умный роутинг, самопроверка и итеративное уточнение
'Пошаговое руководство по созданию агентного RAG в виде дерева решений: маршрутизация запросов, поиск контекста, генерация ответов и их самопроверка с итеративным уточнением.'
Что делает эта система
В этом материале подробно показано, как создать агентный RAG (Retrieval-Augmented Generation), который направляет запросы к подходящим источникам знаний, извлекает релевантный контекст через FAISS, генерирует ответы с Flan-T5, выполняет самопроверку и итеративно уточняет ответы. Система объединяет локальные модели и библиотеки (SentenceTransformers, FAISS, Transformers) в конвейер, напоминающий дерево решений.
Установка зависимостей
Сначала устанавливаем и проверяем зависимости, чтобы обеспечить корректную локальную работу всех компонентов. Ниже — код установки и импорта используемых модулей.
print(" Setting up dependencies...")
import subprocess
import sys
def install_packages():
packages = ['sentence-transformers', 'transformers', 'torch', 'faiss-cpu', 'numpy', 'accelerate']
for package in packages:
print(f"Installing {package}...")
subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])
try:
import faiss
except ImportError:
install_packages()
print("✓ All dependencies installed! Importing modules...\n")
import torch
import numpy as np
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import faiss
from typing import List, Dict, Tuple
import warnings
warnings.filterwarnings('ignore')
print("✓ All modules loaded successfully!\n")Это гарантирует наличие Transformers, FAISS, SentenceTransformers, NumPy и PyTorch и отключает лишние предупреждения.
VectorStore и встраивание (embeddings)
Компонент VectorStore создает эмбеддинги документов через SentenceTransformer и индексирует их в FAISS для быстрого поиска по сходству. Каждый документ сохраняет текст и источник для последующего обоснования ответов.
class VectorStore:
def __init__(self, embedding_model='all-MiniLM-L6-v2'):
print(f"Loading embedding model: {embedding_model}...")
self.embedder = SentenceTransformer(embedding_model)
self.documents = []
self.index = None
def add_documents(self, docs: List[str], sources: List[str]):
self.documents = [{"text": doc, "source": src} for doc, src in zip(docs, sources)]
embeddings = self.embedder.encode(docs, show_progress_bar=False)
dimension = embeddings.shape[1]
self.index = faiss.IndexFlatL2(dimension)
self.index.add(embeddings.astype('float32'))
print(f"✓ Indexed {len(docs)} documents\n")
def search(self, query: str, k: int = 3) -> List[Dict]:
query_vec = self.embedder.encode([query]).astype('float32')
distances, indices = self.index.search(query_vec, k)
return [self.documents[i] for i in indices[0]]Эмбеддинги переводят текст в векторы, FAISS выполняет быстрый nearest-neighbor поиск, а поле 'source' помогает привязывать генерацию к источнику.
Маршрутизация запросов
QueryRouter определяет тип запроса (технический, фактологический, сравнительный, процедурный) с помощью ключевых слов, что позволяет адаптировать стратегию извлечения и генерации.
class QueryRouter:
def __init__(self):
self.categories = {
'technical': ['how', 'implement', 'code', 'function', 'algorithm', 'debug'],
'factual': ['what', 'who', 'when', 'where', 'define', 'explain'],
'comparative': ['compare', 'difference', 'versus', 'vs', 'better', 'which'],
'procedural': ['steps', 'process', 'guide', 'tutorial', 'how to']
}
def route(self, query: str) -> str:
query_lower = query.lower()
scores = {}
for category, keywords in self.categories.items():
score = sum(1 for kw in keywords if kw in query_lower)
scores[category] = score
best_category = max(scores, key=scores.get)
return best_category if scores[best_category] > 0 else 'factual'Этот простой метод легко расширяется дополнительными ключевыми словами или заменяется ML-классификатором.
Генерация ответов и самопроверка
AnswerGenerator использует Flan-T5 для генерации ответов на основе извлеченного контекста. Затем выполняется самопроверка: длина ответа, привязка к контексту и соответствие запросу. При плохой оценке система может запустить итерацию уточнения.
class AnswerGenerator:
def __init__(self, model_name='google/flan-t5-base'):
print(f"Loading generation model: {model_name}...")
self.generator = pipeline('text2text-generation', model=model_name, device=0 if torch.cuda.is_available() else -1, max_length=256)
device_type = "GPU" if torch.cuda.is_available() else "CPU"
print(f"✓ Generator ready (using {device_type})\n")
def generate(self, query: str, context: List[Dict], query_type: str) -> str:
context_text = "\n\n".join([f"[{doc['source']}]: {doc['text']}" for doc in context])
Context:
{context_text}
Question: {query}
Answer:"""
answer = self.generator(prompt, max_length=200, do_sample=False)[0]['generated_text']
return answer.strip()
def self_check(self, query: str, answer: str, context: List[Dict]) -> Tuple[bool, str]:
if len(answer) < 10:
return False, "Answer too short - needs more detail"
context_keywords = set()
for doc in context:
context_keywords.update(doc['text'].lower().split()[:20])
answer_words = set(answer.lower().split())
overlap = len(context_keywords.intersection(answer_words))
if overlap < 2:
return False, "Answer not grounded in context - needs more evidence"
query_keywords = set(query.lower().split())
if len(query_keywords.intersection(answer_words)) < 1:
return False, "Answer doesn't address the query - rephrase needed"
return True, "Answer quality acceptable"Логику самопроверки можно улучшить с помощью моделей для проверки фактичности или entailment-проверок.
Оркестрация: AgenticRAG
AgenticRAG объединяет все компоненты: роутер, векторное хранилище и генератор. На основе типа запроса он определяет размер контекста, генерирует ответ, проверяет его и при необходимости выполняет повторные итерации с уточнением запроса.
class AgenticRAG:
def __init__(self):
self.vector_store = VectorStore()
self.router = QueryRouter()
self.generator = AnswerGenerator()
self.max_iterations = 2
def add_knowledge(self, documents: List[str], sources: List[str]):
self.vector_store.add_documents(documents, sources)
def query(self, question: str, verbose: bool = True) -> Dict:
if verbose:
print(f"\n{'='*60}")
print(f" Query: {question}")
print(f"{'='*60}")
query_type = self.router.route(question)
if verbose:
print(f" Route: {query_type.upper()} query detected")
k_docs = {'technical': 2, 'comparative': 4, 'procedural': 3}.get(query_type, 3)
iteration = 0
answer_accepted = False
while iteration < self.max_iterations and not answer_accepted:
iteration += 1
if verbose:
print(f"\n Iteration {iteration}")
context = self.vector_store.search(question, k=k_docs)
if verbose:
print(f" Retrieved {len(context)} documents from sources:")
for doc in context:
print(f" - {doc['source']}")
answer = self.generator.generate(question, context, query_type)
if verbose:
print(f" Generated answer: {answer[:100]}...")
answer_accepted, feedback = self.generator.self_check(question, answer, context)
if verbose:
status = "✓ ACCEPTED" if answer_accepted else "✗ REJECTED"
print(f" Self-check: {status}")
print(f" Feedback: {feedback}")
if not answer_accepted and iteration < self.max_iterations:
question = f"{question} (provide more specific details)"
k_docs += 1
return {'answer': answer, 'query_type': query_type, 'iterations': iteration, 'accepted': answer_accepted, 'sources': [doc['source'] for doc in context]}Демонстрация
Ниже — пример main(), который загружает небольшую базу знаний и отправляет тестовые запросы в агент. Это полезно для локального отладки и визуализации этапов работы.
def main():
print("\n" + "="*60)
print(" AGENTIC RAG WITH ROUTING & SELF-CHECK")
print("="*60 + "\n")
documents = [
"RAG (Retrieval-Augmented Generation) combines information retrieval with text generation. It retrieves relevant documents and uses them as context for generating accurate answers."
]
sources = ["Python Documentation", "ML Textbook", "Neural Networks Guide", "Deep Learning Paper", "Transformer Architecture", "RAG Research Paper"]
rag = AgenticRAG()
rag.add_knowledge(documents, sources)
test_queries = ["What is Python?", "How does machine learning work?", "Compare neural networks and deep learning"]
for query in test_queries:
result = rag.query(query, verbose=True)
print(f"\n{'='*60}")
print(f" FINAL RESULT:")
print(f" Answer: {result['answer']}")
print(f" Query Type: {result['query_type']}")
print(f" Iterations: {result['iterations']}")
print(f" Accepted: {result['accepted']}")
print(f"{'='*60}\n")
if __name__ == "__main__":
main()Запустите пример локально, пробуйте расширять базу знаний, менять модели и усиливать модуль самопроверки для более высокой надежности и фактической корректности.
Switch Language
Read this article in English