Создание локальных ML API с LitServe: мультиэндпоинты, батчинг, стриминг и кэширование

Установка и зависимости

Начнём с подготовки лёгкого окружения для локального развёртывания моделей: LitServe, PyTorch и Transformers от Hugging Face. Примеры подходят для локальной машины или Google Colab и показывают, как быстро экспонировать модели как API.

!pip install litserve torch transformers -q
import litserve as ls
import torch
from transformers import pipeline
import time
from typing import List

Определение API-классов

Ниже приведены примеры типичных API: генератор текста, батчированный анализ тональности, стриминговый генератор, мультизадачный эндпоинт и кэшированное API. Каждый класс показывает декодирование запросов, инференс и кодирование ответов.

class TextGeneratorAPI(ls.LitAPI):
   def setup(self, device):
       self.model = pipeline("text-generation", model="distilgpt2", device=0 if device == "cuda" and torch.cuda.is_available() else -1)
       self.device = device
   def decode_request(self, request):
       return request["prompt"]
   def predict(self, prompt):
       result = self.model(prompt, max_length=100, num_return_sequences=1, temperature=0.8, do_sample=True)
       return result[0]['generated_text']
   def encode_response(self, output):
       return {"generated_text": output, "model": "distilgpt2"}
class BatchedSentimentAPI(ls.LitAPI):
   def setup(self, device):
       self.model = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english", device=0 if device == "cuda" and torch.cuda.is_available() else -1)
   def decode_request(self, request):
       return request["text"]
   def batch(self, inputs: List[str]) -> List[str]:
       return inputs
   def predict(self, batch: List[str]):
       results = self.model(batch)
       return results
   def unbatch(self, output):
       return output
   def encode_response(self, output):
       return {"label": output["label"], "score": float(output["score"]), "batched": True}

Стриминг ответов

Чтобы демонстрировать стриминг, можно реализовать API, который отдаёт токены (или имитацию токенов) по мере генерации. Такой подход подходит для реального времени и длительных задач.

class StreamingTextAPI(ls.LitAPI):
   def setup(self, device):
       self.model = pipeline("text-generation", model="distilgpt2", device=0 if device == "cuda" and torch.cuda.is_available() else -1)
   def decode_request(self, request):
       return request["prompt"]
   def predict(self, prompt):
       words = ["Once", "upon", "a", "time", "in", "a", "digital", "world"]
       for word in words:
           time.sleep(0.1)
           yield word + " "
   def encode_response(self, output):
       for token in output:
           yield {"token": token}

Мультизадачный эндпоинт

Один эндпоинт может маршрутизировать запросы к разным пайплайнам в зависимости от параметров входа. Это снижает накладные расходы при развертывании родственных задач.

class MultiTaskAPI(ls.LitAPI):
   def setup(self, device):
       self.sentiment = pipeline("sentiment-analysis", device=-1)
       self.summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device=-1)
       self.device = device
   def decode_request(self, request):
       return {"task": request.get("task", "sentiment"), "text": request["text"]}
   def predict(self, inputs):
       task = inputs["task"]
       text = inputs["text"]
       if task == "sentiment":
           result = self.sentiment(text)[0]
           return {"task": "sentiment", "result": result}
       elif task == "summarize":
           if len(text.split()) 

Внимание: фрагмент MultiTaskAPI показывает идею динамической маршрутизации. Для продакшн-версии нужно завершить ветку summarization и продумать обработку длинных текстов и батчинг.

Кэширование инференса

Добавив кэш, можно экономить ресурсы при повторных запросах с одинаковым входом. Пример ниже хранит результаты в памяти и возвращает статистику попаданий и промахов.

class CachedAPI(ls.LitAPI):
   def setup(self, device):
       self.model = pipeline("sentiment-analysis", device=-1)
       self.cache = {}
       self.hits = 0
       self.misses = 0
   def decode_request(self, request):
       return request["text"]
   def predict(self, text):
       if text in self.cache:
           self.hits += 1
           return self.cache[text], True
       self.misses += 1
       result = self.model(text)[0]
       self.cache[text] = result
       return result, False
   def encode_response(self, output):
       result, from_cache = output
       return {"label": result["label"], "score": float(result["score"]), "from_cache": from_cache, "cache_stats": {"hits": self.hits, "misses": self.misses}}

Локальное тестирование

Перед развёртыванием удобно протестировать API-классы локально, вызывая их методы напрямую. Пример тестового раннера прогоняет все ключевые эндпоинты и выводит результаты.

def test_apis_locally():
   print("=" * 70)
   print("Testing APIs Locally (No Server)")
   print("=" * 70)


   api1 = TextGeneratorAPI(); api1.setup("cpu")
   decoded = api1.decode_request({"prompt": "Artificial intelligence will"})
   result = api1.predict(decoded)
   encoded = api1.encode_response(result)
   print(f"✓ Result: {encoded['generated_text'][:100]}...")


   api2 = BatchedSentimentAPI(); api2.setup("cpu")
   texts = ["I love Python!", "This is terrible.", "Neutral statement."]
   decoded_batch = [api2.decode_request({"text": t}) for t in texts]
   batched = api2.batch(decoded_batch)
   results = api2.predict(batched)
   unbatched = api2.unbatch(results)
   for i, r in enumerate(unbatched):
       encoded = api2.encode_response(r)
       print(f"✓ '{texts[i]}' -> {encoded['label']} ({encoded['score']:.2f})")


   api3 = MultiTaskAPI(); api3.setup("cpu")
   decoded = api3.decode_request({"task": "sentiment", "text": "Amazing tutorial!"})
   result = api3.predict(decoded)
   print(f"✓ Sentiment: {result['result']}")


   api4 = CachedAPI(); api4.setup("cpu")
   test_text = "LitServe is awesome!"
   for i in range(3):
       decoded = api4.decode_request({"text": test_text})
       result = api4.predict(decoded)
       encoded = api4.encode_response(result)
       print(f"✓ Request {i+1}: {encoded['label']} (cached: {encoded['from_cache']})")


   print("=" * 70)
   print(" All tests completed successfully!")
   print("=" * 70)


test_apis_locally()

Эти примеры показывают, как LitServe помогает быстро организовать разные типы эндпоинтов: одиночные запросы, батчи, стриминг, мультизадачность и кэширование. Шаблон остаётся простым — декодировать запрос, выполнить инференс и закодировать ответ — и его легко масштабировать по мере роста числа моделей и маршрутов.