
KV Cache in LLMs: How Sebastian Raschka Implements It

A step-by-step tutorial and readable Python code showing register_buffer, use_cache forwarding, cache reset and token-position tracking.

The Brieftide

TL;DR

  • Raschka's tutorial pairs a step-by-step explanation with readable Python code covering register_buffer, use_cache forwarding, cache resets and token-position tracking.
  • A short text example, "Time", which grows into "Time flies" and then "Time flies fast", shows the redundancy in autoregressive decoding.
  • A KV cache removes that redundancy by storing previously computed key and value tensors so they can be reused instead of recomputed.

Sebastian Raschka, PhD, published a technical tutorial on Jun 17, 2025, that explains KV caches and ships two self-contained Python scripts, gpt_ch04.py and gpt_with_kv_cache.py, which implement an LLM without and with a KV cache, respectively.

The article defines a KV cache as a mechanism that stores intermediate key and value computations for reuse during inference, notes that it is inapplicable during training, and highlights the trade-offs: substantial generation speed-ups at the cost of more code complexity and higher memory requirements.

How the tutorial frames the problem and the simple example

Raschka uses a short text example, "Time", which grows into "Time flies" and then "Time flies fast", to show the redundancy in autoregressive decoding. The writeup walks through the attention math that produces a key and a value vector for each token and shows how, without a cache, the model recomputes the same k and v vectors for earlier tokens at every generation step. The KV cache resolves that inefficiency by storing previously computed key and value tensors so they can be retrieved instead of recomputed.
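The redundancy in the "Time flies fast" example can be made concrete by listing which tokens get key/value projections at each step. This is a minimal sketch that abstracts away the attention math entirely; kv_projections is a hypothetical helper, and each word stands in for one token.

```python
def kv_projections(tokens, use_cache):
    """For each decoding step, return the tokens whose key/value vectors are computed."""
    cache = []
    steps = []
    for t in range(1, len(tokens) + 1):
        context = tokens[:t]  # the sequence the model sees at this step
        if use_cache:
            new = context[len(cache):]  # only tokens that are not cached yet
            cache.extend(new)
        else:
            new = context  # the whole prefix is projected again
        steps.append(new)
    return steps

tokens = ["Time", "flies", "fast"]
no_cache = kv_projections(tokens, use_cache=False)
with_cache = kv_projections(tokens, use_cache=True)
print(no_cache)    # [['Time'], ['Time', 'flies'], ['Time', 'flies', 'fast']]
print(with_cache)  # [['Time'], ['flies'], ['fast']]
print(sum(map(len, no_cache)), sum(map(len, with_cache)))  # 6 3
```

In general, n steps without a cache project 1 + 2 + ... + n = n(n + 1)/2 tokens, while the cached version projects each token once, n in total.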

The article includes two Python files for readers to inspect: gpt_ch04.py, which contains the baseline implementation, and gpt_with_kv_cache.py, which contains the KV cache changes. Raschka marks the new cache-relevant lines with # NEW to make the differences easy to follow.

The implementation steps shown in code

Raschka highlights five concrete changes required to add a KV cache to a readable from-scratch model implementation:

  1. Registering cache buffers

Inside the MultiHeadAttention constructor he adds two buffers named cache_k and cache_v by calling:

self.register_buffer("cache_k", None)
self.register_buffer("cache_v", None)

These buffers will hold concatenated keys and values across incremental decoding steps.
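The behavior of a buffer registered as None can be checked in isolation. The sketch below uses a hypothetical CacheBuffers module that holds only the two buffers; the register_buffer, state_dict and named_buffers calls are standard PyTorch.

```python
import torch
import torch.nn as nn

class CacheBuffers(nn.Module):
    """Hypothetical module holding only the two cache buffers."""

    def __init__(self):
        super().__init__()
        # A buffer registered as None is a placeholder: state_dict() and device
        # moves such as .to() skip it until a tensor is assigned.
        self.register_buffer("cache_k", None)
        self.register_buffer("cache_v", None)

m = CacheBuffers()
print(m.cache_k, list(m.state_dict()))  # None []
m.cache_k = torch.zeros(1, 4, 8)  # assigning a tensor fills the registered slot
print([name for name, _ in m.named_buffers()])  # ['cache_k']
```

One design note: with the default persistent=True, a filled cache is also saved by state_dict(). Passing persistent=False to register_buffer keeps cached activations out of checkpoints while still letting them move with the module.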

  2. Forward pass with a use_cache flag

The MultiHeadAttention.forward signature gains a use_cache argument; the example method opens with:

def forward(self, x, use_cache=False):
    b, num_tokens, d_in = x.shape
    # project only the tokens passed in this call; with use_cache these are the new ones
    keys_new = self.W_key(x)
    values_new = self.W_value(x)
    queries = self.W_query(x)

When use_cache is true, the code either initializes the cache with the first keys and values or appends the newly computed ones with torch.cat, then sets keys, values = self.cache_k, self.cache_v. When use_cache is false, the block uses the freshly computed keys_new and values_new directly.
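The caching branch can be sketched inside a minimal causal multi-head attention layer. This is an illustrative sketch, not the code from gpt_with_kv_cache.py: the class name CachedAttention, the _split_heads helper, caching after the heads are split (so torch.cat runs along dim=2), the absence of dropout, and the offset causal mask are all assumptions chosen so that one code path handles both a multi-token prompt and single-token steps.

```python
import torch
import torch.nn as nn

class CachedAttention(nn.Module):
    """Causal multi-head attention with an optional KV cache (sketch, no dropout)."""

    def __init__(self, d_in, d_out, num_heads):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.num_heads, self.head_dim = num_heads, d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        self.out_proj = nn.Linear(d_out, d_out)
        self.register_buffer("cache_k", None)
        self.register_buffer("cache_v", None)

    def _split_heads(self, t):
        # (b, n, d_out) -> (b, num_heads, n, head_dim)
        b, n, _ = t.shape
        return t.view(b, n, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, use_cache=False):
        b, num_tokens, _ = x.shape
        keys_new = self._split_heads(self.W_key(x))
        values_new = self._split_heads(self.W_value(x))
        queries = self._split_heads(self.W_query(x))

        if use_cache:
            if self.cache_k is None:  # first call: the cache starts with these tokens
                self.cache_k, self.cache_v = keys_new, values_new
            else:  # later calls: append along the sequence dimension
                self.cache_k = torch.cat([self.cache_k, keys_new], dim=2)
                self.cache_v = torch.cat([self.cache_v, values_new], dim=2)
            keys, values = self.cache_k, self.cache_v
        else:
            keys, values = keys_new, values_new

        # The queries occupy the last num_tokens positions of the key sequence,
        # so the causal mask is offset by the number of previously cached tokens.
        num_keys = keys.shape[2]
        q_pos = torch.arange(num_keys - num_tokens, num_keys, device=x.device).unsqueeze(1)
        k_pos = torch.arange(num_keys, device=x.device).unsqueeze(0)
        future = k_pos > q_pos  # (num_tokens, num_keys): True where a key lies ahead

        scores = queries @ keys.transpose(2, 3) / self.head_dim ** 0.5
        weights = torch.softmax(scores.masked_fill(future, float("-inf")), dim=-1)
        context = (weights @ values).transpose(1, 2).reshape(b, num_tokens, -1)
        return self.out_proj(context)

    def reset_cache(self):
        self.cache_k, self.cache_v = None, None

# Usage: one-token-at-a-time decoding with the cache matches a single uncached pass.
torch.manual_seed(0)
attn = CachedAttention(d_in=8, d_out=8, num_heads=2)
x = torch.randn(1, 3, 8)
with torch.no_grad():  # the cache is an inference-time mechanism
    full = attn(x)
    steps = torch.cat([attn(x[:, i:i + 1], use_cache=True) for i in range(3)], dim=1)
print(torch.allclose(full, steps, atol=1e-5))  # True
```

Because the cached keys and values for earlier tokens are exactly what an uncached pass would recompute, the outputs agree up to floating-point rounding; only the work per step changes.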

  3. Clearing the cache

A reset_cache method ensures stale context does not leak across generation calls. The example method is:

def reset_cache(self):
    self.cache_k, self.cache_v = None, None
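What "stale context" means in practice shows up with a stripped-down layer that only caches keys. KeyCache is a hypothetical stand-in, not part of the tutorial; it returns the full cached key tensor, i.e. every key a new query would attend to.

```python
import torch
import torch.nn as nn

class KeyCache(nn.Module):
    """Hypothetical stand-in that projects and caches keys, nothing else."""

    def __init__(self, d):
        super().__init__()
        self.W_key = nn.Linear(d, d, bias=False)
        self.register_buffer("cache_k", None)

    def forward(self, x, use_cache=False):
        keys_new = self.W_key(x)
        if not use_cache:
            return keys_new
        if self.cache_k is None:
            self.cache_k = keys_new
        else:
            self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
        return self.cache_k  # every key the next query would attend to

    def reset_cache(self):
        self.cache_k = None

torch.manual_seed(0)
layer = KeyCache(d=4)
prompt_a, prompt_b = torch.randn(1, 3, 4), torch.randn(1, 2, 4)
with torch.no_grad():
    layer(prompt_a, use_cache=True)
    stale = layer(prompt_b, use_cache=True)  # prompt A's keys leak into prompt B's context
    layer.reset_cache()
    clean = layer(prompt_b, use_cache=True)  # only prompt B's keys remain
print(stale.shape[1], clean.shape[1])  # 5 2
```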

  4. Propagating use_cache and tracking positions

At the model level, Raschka adds a token counter, self.current_pos = 0. In GPTModel.forward, the tutorial constructs the position ids differently when use_cache is true, offsetting them by the counter:

pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
self.current_pos += seq_len

When the cache is not in use, the position ids simply start at 0. The code then builds pos_embeds = self.pos_emb(pos_ids).unsqueeze(0) and adds them to the token embeddings. The transformer blocks are invoked in an explicit loop so that use_cache can be passed down:

for blk in self.trf_blocks:
    x = blk(x, use_cache=use_cache)

TransformerBlock.forward is likewise changed to accept use_cache and to call self.att(x, use_cache=use_cache).
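The position-id logic can be isolated to see why the counter is needed: with a cache, a single new token must receive the next absolute position, not position 0. PositionIds is a hypothetical helper that reproduces only the arange/current_pos lines quoted above plus the pos_emb lookup.

```python
import torch
import torch.nn as nn

class PositionIds(nn.Module):
    """Hypothetical helper holding only the position-id logic."""

    def __init__(self, context_length, emb_dim):
        super().__init__()
        self.pos_emb = nn.Embedding(context_length, emb_dim)
        self.current_pos = 0  # tokens already processed while caching

    def forward(self, in_idx, use_cache=False):
        seq_len = in_idx.shape[1]
        if use_cache:
            # continue numbering where the previous cached call stopped
            pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len,
                                   device=in_idx.device, dtype=torch.long)
            self.current_pos += seq_len
        else:
            # without a cache the whole sequence is present, so count from 0
            pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
        return pos_ids, self.pos_emb(pos_ids).unsqueeze(0)

pos = PositionIds(context_length=16, emb_dim=4)
prompt_ids, _ = pos(torch.tensor([[11, 12, 13]]), use_cache=True)  # prefill the prompt
next_ids, next_embeds = pos(torch.tensor([[14]]), use_cache=True)  # one new token
print(prompt_ids.tolist(), next_ids.tolist(), tuple(next_embeds.shape))  # [0, 1, 2] [3] (1, 1, 4)
```

Without the counter, the lone token 14 would be embedded at position 0 and the model would see a scrambled sequence order.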

  5. Model-level reset convenience

Raschka adds a GPTModel.reset_kv_cache method that iterates over all blocks, calls blk.att.reset_cache() on each, and sets self.current_pos = 0, so the entire cached state can be cleared between independent generation sessions.
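A reduced sketch of how use_cache travels through the blocks and how a model-level reset clears every layer at once. ToyAttention, ToyBlock and ToyGPT are hypothetical stand-ins (the attention layer only records keys and returns its input unchanged); the names trf_blocks, att, current_pos, reset_cache and reset_kv_cache follow the ones quoted above. The blocks live in an nn.ModuleList with an explicit loop because nn.Sequential's forward accepts a single input and cannot forward a use_cache keyword.

```python
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    """Hypothetical stand-in for MultiHeadAttention: it only caches projected keys."""

    def __init__(self, d):
        super().__init__()
        self.W_key = nn.Linear(d, d, bias=False)
        self.register_buffer("cache_k", None)

    def forward(self, x, use_cache=False):
        keys_new = self.W_key(x)
        if use_cache:
            self.cache_k = keys_new if self.cache_k is None else torch.cat([self.cache_k, keys_new], dim=1)
        return x  # identity output keeps the sketch focused on cache bookkeeping

    def reset_cache(self):
        self.cache_k = None

class ToyBlock(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.att = ToyAttention(d)

    def forward(self, x, use_cache=False):
        return self.att(x, use_cache=use_cache)  # pass the flag down to attention

class ToyGPT(nn.Module):
    def __init__(self, d, n_layers):
        super().__init__()
        self.trf_blocks = nn.ModuleList([ToyBlock(d) for _ in range(n_layers)])
        self.current_pos = 0

    def forward(self, x, use_cache=False):
        if use_cache:
            self.current_pos += x.shape[1]
        for blk in self.trf_blocks:  # explicit loop so use_cache reaches every block
            x = blk(x, use_cache=use_cache)
        return x

    def reset_kv_cache(self):
        for blk in self.trf_blocks:
            blk.att.reset_cache()
        self.current_pos = 0

model = ToyGPT(d=4, n_layers=3)
with torch.no_grad():
    model(torch.randn(1, 5, 4), use_cache=True)  # prompt
    model(torch.randn(1, 1, 4), use_cache=True)  # one generated token
print(model.current_pos, tuple(model.trf_blocks[0].att.cache_k.shape))  # 6 (1, 6, 4)
model.reset_kv_cache()
print(model.current_pos, all(b.att.cache_k is None for b in model.trf_blocks))  # 0 True
```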

The article concludes by showing how these pieces are used in a simple text generation function, and it points readers to the GitHub scripts for a side-by-side diff.
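A greedy generation loop that uses the cache typically processes the prompt once ("prefill") and then feeds only the newest token at each step. The sketch below is not the tutorial's function: generate_greedy and ToyCausalLM are hypothetical, the toy model replaces real attention with a uniform causal average over cached value vectors to stay short, and context-length truncation is omitted.

```python
import torch
import torch.nn as nn

class ToyCausalLM(nn.Module):
    """Hypothetical LM: each position averages the value vectors of all tokens so far."""

    def __init__(self, vocab_size, d):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d)
        self.W_value = nn.Linear(d, d, bias=False)
        self.out_head = nn.Linear(d, vocab_size, bias=False)
        self.register_buffer("cache_v", None)

    def forward(self, idx, use_cache=False):
        values_new = self.W_value(self.tok_emb(idx))  # (b, n_new, d)
        if use_cache:
            self.cache_v = values_new if self.cache_v is None else torch.cat([self.cache_v, values_new], dim=1)
            values = self.cache_v
        else:
            values = values_new
        n_new, n_all = idx.shape[1], values.shape[1]
        counts = torch.arange(1, n_all + 1, device=idx.device).view(1, -1, 1)
        hidden = (values.cumsum(dim=1) / counts)[:, n_all - n_new:]  # keep only new positions
        return self.out_head(hidden)

    def reset_kv_cache(self):
        self.cache_v = None

@torch.no_grad()
def generate_greedy(model, idx, max_new_tokens, use_cache=True):
    model.eval()
    if use_cache:
        model.reset_kv_cache()  # never reuse state from a previous prompt
        logits = model(idx, use_cache=True)  # prefill: the whole prompt, once
        for _ in range(max_new_tokens):
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_idx], dim=1)
            logits = model(next_idx, use_cache=True)  # only the new token
    else:
        for _ in range(max_new_tokens):
            logits = model(idx)  # the full sequence is reprocessed every step
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_idx], dim=1)
    return idx

torch.manual_seed(123)
model = ToyCausalLM(vocab_size=20, d=8)
prompt = torch.tensor([[1, 2, 3]])
cached = generate_greedy(model, prompt, max_new_tokens=5, use_cache=True)
uncached = generate_greedy(model, prompt, max_new_tokens=5, use_cache=False)
print(torch.equal(cached, uncached))
```

The two calls should return the same token ids; the cached path simply avoids recomputing the value vectors of tokens it has already seen.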

Why it matters

KV caches are among the most important techniques for compute-efficient LLM inference in production because they avoid recomputing identical key and value vectors for previously seen tokens. The tutorial demonstrates that the implementation is conceptually small: add per-block buffers, propagate a use_cache flag, concatenate new k and v tensors, and track token positions. The clear, from-scratch code also makes the trade-offs explicit (faster autoregressive decoding versus higher memory use and added implementation complexity) and reminds readers that KV caches are for inference only.
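The memory side of that trade-off can be estimated with simple arithmetic: a full cache holds one key and one value tensor per layer, each with seq_len * d_out entries per sequence. The sketch assumes standard multi-head attention (no grouped-query or multi-query sharing of keys and values) and a GPT-2-small-like configuration of 12 layers, d_out = 768 and a 1,024-token context in float32; the numbers are illustrative, not measurements from the tutorial.

```python
def kv_cache_bytes(n_layers, d_out, seq_len, batch_size=1, bytes_per_elem=4):
    """Bytes for a full KV cache: one key and one value tensor of shape
    (batch_size, seq_len, d_out) per layer, with d_out = num_heads * head_dim."""
    return 2 * n_layers * batch_size * seq_len * d_out * bytes_per_elem

size = kv_cache_bytes(n_layers=12, d_out=768, seq_len=1024)
print(size, size / 2**20)  # 75497472 72.0
```

The cache grows linearly with context length and batch size; storing it in a 16-bit format (bytes_per_elem=2) halves it.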

What to watch

Check gpt_with_kv_cache.py and its # NEW markers to verify how the buffers are registered and concatenated, and confirm that GPTModel.reset_kv_cache clears every block's cache and resets current_pos to 0. The next practical milestone is integrating the pattern into larger, production-grade transformer stacks and measuring end-to-end latency and peak memory use with long contexts.


Written by The Brieftide · Source: Ahead of AI
