Build Multi-Endpoint ML APIs Locally with LitServe: Batching, Streaming, Caching & More
Setup and dependencies
We start by preparing a lightweight environment for serving models locally with LitServe, PyTorch, and Hugging Face Transformers. The examples below run on a local machine or in Google Colab and show how to expose models as APIs with minimal boilerplate.
!pip install litserve torch transformers -q
import litserve as ls
import torch
from transformers import pipeline
import time
from typing import List
Defining API classes
The following API classes illustrate common serving patterns: a text generator, a batched sentiment endpoint, a streaming generator, a multi-task endpoint, and a cached API. Each class shows how to decode requests, run inference, and encode responses.
class TextGeneratorAPI(ls.LitAPI):
    def setup(self, device):
        # Load distilgpt2 once per worker; use the GPU only when requested and available.
        self.model = pipeline(
            "text-generation",
            model="distilgpt2",
            device=0 if device == "cuda" and torch.cuda.is_available() else -1,
        )
        self.device = device

    def decode_request(self, request):
        return request["prompt"]

    def predict(self, prompt):
        result = self.model(prompt, max_length=100, num_return_sequences=1, temperature=0.8, do_sample=True)
        return result[0]["generated_text"]

    def encode_response(self, output):
        return {"generated_text": output, "model": "distilgpt2"}
class BatchedSentimentAPI(ls.LitAPI):
    def setup(self, device):
        self.model = pipeline(
            "sentiment-analysis",
            model="distilbert-base-uncased-finetuned-sst-2-english",
            device=0 if device == "cuda" and torch.cuda.is_available() else -1,
        )

    def decode_request(self, request):
        return request["text"]

    def batch(self, inputs: List[str]) -> List[str]:
        # Group individual requests into one list for a single pipeline call.
        return inputs

    def predict(self, batch: List[str]):
        return self.model(batch)

    def unbatch(self, output):
        # Split the batched results back into per-request outputs.
        return output

    def encode_response(self, output):
        return {"label": output["label"], "score": float(output["score"]), "batched": True}
Streaming generation
To illustrate streaming responses, the next API yields tokens (here, simulated ones) as they are produced. This pattern maps well to real-time text generation and other long-running tasks.
class StreamingTextAPI(ls.LitAPI):
    def setup(self, device):
        self.model = pipeline(
            "text-generation",
            model="distilgpt2",
            device=0 if device == "cuda" and torch.cuda.is_available() else -1,
        )

    def decode_request(self, request):
        return request["prompt"]

    def predict(self, prompt):
        # Stream simulated tokens so the example stays fast and deterministic;
        # a real implementation would yield tokens from the loaded model instead.
        words = ["Once", "upon", "a", "time", "in", "a", "digital", "world"]
        for word in words:
            time.sleep(0.1)
            yield word + " "

    def encode_response(self, output):
        for token in output:
            yield {"token": token}
Multi-task endpoint
A single endpoint can route requests to different model pipelines based on input parameters. This reduces deployment overhead when related tasks share infrastructure.
class MultiTaskAPI(ls.LitAPI):
    def setup(self, device):
        # Two pipelines behind one endpoint; both run on CPU here.
        self.sentiment = pipeline("sentiment-analysis", device=-1)
        self.summarizer = pipeline("summarization", model="sshleifer/distilbart-cnn-6-6", device=-1)
        self.device = device

    def decode_request(self, request):
        return {"task": request.get("task", "sentiment"), "text": request["text"]}

    def predict(self, inputs):
        task = inputs["task"]
        text = inputs["text"]
        if task == "sentiment":
            result = self.sentiment(text)[0]
            return {"task": "sentiment", "result": result}
        elif task == "summarize":
            # Minimal completion of the summarization branch: skip the summarizer
            # for very short inputs (the 20-word threshold and the length limits
            # below are illustrative, not tuned).
            if len(text.split()) < 20:
                return {"task": "summarize", "result": {"summary_text": text}}
            result = self.summarizer(text, max_length=60, min_length=10, do_sample=False)[0]
            return {"task": "summarize", "result": result}
        return {"task": task, "error": "unknown task"}
Note: MultiTaskAPI demonstrates dynamic routing between pipelines. The summarization branch above is a minimal, illustrative completion: very short inputs are returned unchanged, and longer ones go through the distilbart summarizer. In a production-ready version you would tune the thresholds, generation lengths, and batching for the summarizer.
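Once MultiTaskAPI is served (see the serving sketches above), clients pick the task through the JSON payload. A hypothetical pair of calls, assuming the default /predict route on port 8000:

# Hypothetical client calls; the "task" field drives the routing in predict().
import requests

url = "http://127.0.0.1:8000/predict"
print(requests.post(url, json={"task": "sentiment", "text": "Amazing tutorial!"}).json())
print(requests.post(url, json={"task": "summarize", "text": "LitServe wraps models as local APIs. " * 15}).json())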
Caching repeated inferences
Adding a cache layer to an API can save compute when the same input is requested repeatedly. The example below keeps a simple in-memory cache and exposes hit/miss stats.
class CachedAPI(ls.LitAPI):
    def setup(self, device):
        self.model = pipeline("sentiment-analysis", device=-1)
        self.cache = {}
        self.hits = 0
        self.misses = 0

    def decode_request(self, request):
        return request["text"]

    def predict(self, text):
        if text in self.cache:
            self.hits += 1
            return self.cache[text], True
        self.misses += 1
        result = self.model(text)[0]
        self.cache[text] = result
        return result, False

    def encode_response(self, output):
        result, from_cache = output
        return {
            "label": result["label"],
            "score": float(result["score"]),
            "from_cache": from_cache,
            "cache_stats": {"hits": self.hits, "misses": self.misses},
        }
Local testing without a server
Before running a full server, it’s useful to test API classes locally by calling their methods directly. The helper below runs through the main endpoints and prints results.
def test_apis_locally():
    print("=" * 70)
    print("Testing APIs Locally (No Server)")
    print("=" * 70)

    api1 = TextGeneratorAPI(); api1.setup("cpu")
    decoded = api1.decode_request({"prompt": "Artificial intelligence will"})
    result = api1.predict(decoded)
    encoded = api1.encode_response(result)
    print(f"✓ Result: {encoded['generated_text'][:100]}...")

    api2 = BatchedSentimentAPI(); api2.setup("cpu")
    texts = ["I love Python!", "This is terrible.", "Neutral statement."]
    decoded_batch = [api2.decode_request({"text": t}) for t in texts]
    batched = api2.batch(decoded_batch)
    results = api2.predict(batched)
    unbatched = api2.unbatch(results)
    for i, r in enumerate(unbatched):
        encoded = api2.encode_response(r)
        print(f"✓ '{texts[i]}' -> {encoded['label']} ({encoded['score']:.2f})")

    api3 = MultiTaskAPI(); api3.setup("cpu")
    decoded = api3.decode_request({"task": "sentiment", "text": "Amazing tutorial!"})
    result = api3.predict(decoded)
    print(f"✓ Sentiment: {result['result']}")

    api4 = CachedAPI(); api4.setup("cpu")
    test_text = "LitServe is awesome!"
    for i in range(3):
        decoded = api4.decode_request({"text": test_text})
        result = api4.predict(decoded)
        encoded = api4.encode_response(result)
        print(f"✓ Request {i+1}: {encoded['label']} (cached: {encoded['from_cache']})")

    print("=" * 70)
    print("All tests completed successfully!")
    print("=" * 70)

test_apis_locally()
These examples show how LitServe helps you structure multiple production-ready endpoints quickly: single-request, batched, streaming, multi-task, and cached. The pattern stays the same throughout: decode the request, run inference, then encode the response. It scales naturally as you add more models and routes.