
Optimizing Token Generation with KV Caching

Learn how KV caching accelerates token generation in LLMs.

Problem with Slow Token Generation

You’re deploying an LLM in production. Generating the first few tokens is fast; however, as the sequence grows, each additional token takes progressively longer to generate. If compute isn’t the primary bottleneck, what inefficiency causes this slowdown, and how can it be fixed without redesigning the model?

What is KV Caching and How Does it Accelerate Inference?

KV caching is an optimization technique used in text generation for large language models. Typically, in autoregressive generation, the model computes attention over all tokens each time, but the keys (K) and values (V) calculated earlier remain unchanged.

With KV caching, the model stores these keys and values after the first computation. For each new token, it only recomputes the query (Q) and utilizes the cached K and V, drastically reducing redundant calculations. This results in faster inference, particularly for long sequences, though it requires extra memory for the cache.
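To make this concrete, here is a minimal sketch of single-head attention with a KV cache, using toy dimensions and random weights rather than a real model: at each step only the newest token's query, key, and value are computed, the key and value are appended to the cache, and attention runs over everything cached so far.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 64  # toy embedding size for this single-head sketch
W_q = torch.randn(d_model, d_model) / d_model**0.5
W_k = torch.randn(d_model, d_model) / d_model**0.5
W_v = torch.randn(d_model, d_model) / d_model**0.5

# The cache: keys and values for every token processed so far.
k_cache = torch.empty(0, d_model)
v_cache = torch.empty(0, d_model)

def attend_next_token(x_new, k_cache, v_cache):
    """Process one new token embedding, reusing cached K/V of earlier tokens."""
    q = x_new @ W_q                               # query for the new token only
    k_cache = torch.cat([k_cache, x_new @ W_k])   # append this token's key ...
    v_cache = torch.cat([v_cache, x_new @ W_v])   # ... and value to the cache
    scores = q @ k_cache.T / d_model**0.5         # attend over all cached keys
    weights = F.softmax(scores, dim=-1)
    out = weights @ v_cache                       # weighted sum of cached values
    return out, k_cache, v_cache

# Feed a short "sequence" one token at a time; earlier K/V are never recomputed.
for step in range(5):
    x_new = torch.randn(1, d_model)  # stand-in for the newest token's embedding
    out, k_cache, v_cache = attend_next_token(x_new, k_cache, v_cache)
    print(f"step {step}: cache holds {k_cache.shape[0]} keys/values")

Each iteration does a constant amount of projection work for the new token, while the cache simply grows by one entry, which is exactly the trade-off described above: less computation in exchange for more memory.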

Evaluating the Impact of KV Caching on Inference Speed

To benchmark KV caching, we test the same prompt with caching enabled and disabled, measuring average generation time while keeping model parameters constant. The results demonstrate that reusing cached keys and values significantly reduces computational redundancy, thereby speeding up inference.

import numpy as np
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "gpt2-medium"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

prompt = "Explain KV caching in transformers."

inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Benchmark generation with and without the KV cache, keeping the model,
# prompt, and decoding settings identical across both runs.
for use_cache in (True, False):
    times = []
    for _ in range(5):  # average over several runs to smooth out noise
        start = time.time()
        model.generate(
            **inputs,
            use_cache=use_cache,
            max_new_tokens=1000,
            pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token
        )
        times.append(time.time() - start)

    print(
        f"{'with' if use_cache else 'without'} KV caching: "
        f"{round(np.mean(times), 3)} ± {round(np.std(times), 3)} seconds"
    )

The results indicate that enabling KV caching reduces the generation time for 1000 tokens from over 107 seconds without caching to around 21.7 seconds with caching, nearly a 5× speedup. The gain comes from not recomputing keys and values for the entire prefix at every step: without the cache, the cost of each new token grows quadratically with sequence length, whereas with the cache it grows only linearly.
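The speedup is paid for in memory. As a rough, version-dependent check (separate from the benchmark above), you can reuse the model and inputs defined earlier, run one forward pass with use_cache=True, and sum the sizes of the tensors in the returned past_key_values; depending on your transformers version, past_key_values is either a tuple of (key, value) tensors per layer or a Cache object. The cache grows linearly with the number of tokens it covers.

with torch.no_grad():
    out = model(**inputs, use_cache=True)

past = out.past_key_values
if hasattr(past, "to_legacy_cache"):  # newer Cache API; convert to per-layer tuples
    past = past.to_legacy_cache()

# Sum the bytes of every cached key and value tensor across all layers.
cache_bytes = sum(t.numel() * t.element_size() for layer in past for t in layer)
print(f"KV cache for {inputs['input_ids'].shape[1]} prompt tokens: "
      f"{cache_bytes / 1024**2:.2f} MiB")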

In summary, KV caching is essential for efficient deployment of autoregressive language models, enabling them to generate longer sequences faster without unnecessary computation.
