🚀 The feature, motivation and pitch
1. The Problem
1.1 Pain Point: LRU Has No Memory of How Hot a Block Was
vLLM's GPU KV cache (BlockPool) uses a single-queue LRU implemented in FreeKVCacheBlockQueue. When a request finishes, its blocks are appended to the tail of the free queue. When new blocks are needed, eviction drains from the head.
This design has one fundamental blind spot: once ref_cnt drops to zero, all blocks are equal candidates for eviction, regardless of how many times they were previously reused. A system-prompt block touched by 10,000 requests and a
one-time-use block are indistinguishable at eviction time.
1.2 Pain Point: Burst Requests Cause Scan Pollution
In production, LLM serving commonly experiences bursty traffic: a wave of requests with unique prompts arrives simultaneously, each generating a large number of KV blocks that immediately become eviction candidates. These low-value
blocks flood the tail of the free queue, pushing high-value cached prefix blocks toward the eviction head.
This is the classic scan pollution problem, well-known in OS page-cache literature. Linux solved it decades ago with the Active/Inactive list design.
vLLM has not yet applied equivalent protection to the GPU KV cache.
1.3 Why This Matters More Over Time
| Deployment Trend | Effect on LRU Weakness |
| --- | --- |
| Longer system prompts (agents, RAG context) | More blocks per hot prefix → eviction loss per cache miss grows |
| Higher concurrency | More simultaneous request completions → larger burst pollution events |
| Disaggregated Prefill/Decode | Prefix blocks transferred from P → D must survive long enough to be reused; premature eviction nullifies the transfer cost |
| Multi-tenant serving | Different tenants share the block pool; one tenant's burst can evict another's hot prefix |
2. Root Cause Analysis
2.1 Current Architecture
```
BlockPool
└── free_block_queue: FreeKVCacheBlockQueue   # one doubly-linked list

    head (evict first) ←——————————————— tail (evict last)
    [ old cold blocks ]  ...  [ recently freed blocks ]
```
The "reverse-order free" heuristic in single_type_kv_cache_manager.py:
```python
# Free blocks in reverse order so that tail blocks are evicted first
ordered_blocks = reversed(req_blocks)
self.block_pool.free_blocks(ordered_blocks)
```
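The intent of this heuristic can be shown with a minimal stand-in for the free queue (a plain `deque` here, not vLLM's doubly-linked FreeKVCacheBlockQueue):

```python
from collections import deque

# Minimal stand-in for the free queue: freed blocks are appended to
# the tail; eviction pops from the head.
free_queue: deque[int] = deque()

def free_blocks(req_blocks: list[int]) -> None:
    # Free in reverse order so a request's tail blocks sit closer to
    # the eviction head than its prefix blocks: shorter prefixes are
    # more likely to be re-hit, so they should be evicted last.
    free_queue.extend(reversed(req_blocks))

free_blocks([0, 1, 2, 3])  # block ids in prompt order
# Eviction order from the head is now 3, 2, 1, 0.
```

Within a single request this ordering is sensible; the problem described next is that it says nothing about ordering *across* requests.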
2.2 Where the Logic Breaks Down
Consider a 2,048-block GPU cache with a 64-block system prompt that has been reused by many requests:
Timeline:

```
t=0  System prompt (64 blocks) used by 500 requests.
     ref_cnt > 0 → blocks NOT in free queue → safe.

t=1  All 500 requests finish.
     ref_cnt drops to 0 → 64 blocks appended to queue TAIL.

t=2  32 unique long-prompt requests arrive.
     Each uses 62 blocks → 1,984 blocks freed → appended to TAIL.

     Queue state after t=2:
     head [...stale blocks...][system-prompt 64 blocks][1,984 one-time blocks] tail
          ↑ evicted first                              ↑ evicted last

t=3  Next batch needs 100 new blocks.
     Eviction drains from head → system-prompt blocks EVICTED
     before the 1,984 single-use blocks.
```
The 64 most valuable blocks are lost before 1,984 worthless blocks.
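The timeline can be reproduced in a few lines against the same single-queue model (a plain `deque` standing in for FreeKVCacheBlockQueue):

```python
from collections import deque

free_queue: deque[str] = deque()

# t=1: the 500 requests finish; the 64 shared system-prompt blocks
# hit ref_cnt == 0 and are appended to the tail.
free_queue.extend(f"sys-{i}" for i in range(64))

# t=2: 32 unique requests free 62 blocks each -> 1,984 one-time
# blocks appended after them.
free_queue.extend(f"tmp-{r}-{i}" for r in range(32) for i in range(62))

# t=3: the next batch needs 100 blocks; eviction pops from the head.
evicted = [free_queue.popleft() for _ in range(100)]

# All 64 hot system-prompt blocks are evicted before any one-time block.
```

The first 64 evictions are exactly the system-prompt blocks; the 1,984 single-use blocks all outlive them.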
The Inconsistency: CPU Offload Already Has ARC
vLLM already solved this problem for the CPU offload layer:
```python
# vllm/v1/kv_offload/cpu/manager.py
_CACHE_POLICIES = {
    "lru": LRUCachePolicy,
    "arc": ARCCachePolicy,  # full T1/T2/B1/B2 ARC — already implemented
}

class CPUOffloadingManager:
    def __init__(self, num_blocks, cache_policy: Literal["lru", "arc"] = "lru"):
        policy_cls = _CACHE_POLICIES[cache_policy]
        self._policy = policy_cls(cache_capacity=num_blocks)
```
ARCCachePolicy (vllm/v1/kv_offload/cpu/policies/arc.py) provides:
- T1: recently accessed blocks (accessed once)
- T2: frequently accessed blocks (accessed multiple times)
- B1/B2: ghost lists for adaptive tuning
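The T1/T2 split can be illustrated with a stripped-down sketch (this is not the actual ARCCachePolicy code; ghost lists B1/B2 and the adaptive T1 target size are omitted):

```python
from collections import OrderedDict

class ArcLiteSketch:
    """Illustrative T1/T2 split only: blocks seen once live in T1,
    blocks seen twice or more are promoted to T2, and eviction takes
    from T1 before T2."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.t1: OrderedDict[int, None] = OrderedDict()  # seen once
        self.t2: OrderedDict[int, None] = OrderedDict()  # seen 2+ times

    def access(self, block_id: int) -> None:
        if block_id in self.t1:
            del self.t1[block_id]
            self.t2[block_id] = None       # promote on second access
        elif block_id in self.t2:
            self.t2.move_to_end(block_id)  # refresh recency within T2
        else:
            if len(self.t1) + len(self.t2) >= self.capacity:
                self.evict()
            self.t1[block_id] = None       # first access goes to T1

    def evict(self) -> int:
        # One-shot blocks in T1 are sacrificed before frequent blocks in T2.
        src = self.t1 if self.t1 else self.t2
        block_id, _ = src.popitem(last=False)
        return block_id

arc = ArcLiteSketch(capacity=4)
arc.access(1)
arc.access(1)          # second access promotes block 1 to T2
for b in (2, 3, 4, 5):
    arc.access(b)      # a one-shot scan overflows T1
```

After the scan, block 1 survives in T2 while scan block 2 is the first eviction victim.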
The GPU BlockPool has no equivalent. This RFC proposes closing that gap.
Quantified Impact
Cache Miss Cost
A prefix cache miss forces the engine to recompute tokens from scratch. For a model with hidden size 4096, 32 KV heads, 128 head dim, and block size 16:
Cost to recompute one missed block ≈ 16 tokens × full prefill attention cost
Typical TTFT (Time To First Token) regression from one evicted system-prompt block in a sequence of 2,048 tokens:
| Model Size | Block Recompute Time | Relative TTFT Increase |
| --- | --- | --- |
| 7B | ~3 ms | ~2–5% |
| 13B | ~6 ms | ~4–8% |
| 70B | ~25 ms | ~8–15% |
For a 64-block system prompt, multiply by 64.
Hit Rate Simulation (Analytical Model)
Assume:

- Pool: N = 2,048 blocks
- Hot prefix: H = 64 blocks, request arrival rate λ_h = 100 req/s
- Unique long prompts: L = 62 blocks/req, arrival rate λ_l = 10 req/s (burst)
Under pure LRU, the hot prefix survives only while fewer than N - H = 1,984 unique-prompt blocks arrive between two consecutive hot-prefix-using requests.
At λ_l = 10 req/s, a burst of 32 requests fills 1,984 blocks
in ~3.2 seconds → hot prefix evicted in ~3.2 s of inactivity.
Under Two-Queue:

- The hot prefix lives in the "hot" queue.
- Eviction drains the "cold" queue first (the 1,984 one-time blocks).
- The hot queue is touched only after all 1,984 cold blocks are exhausted.
- Effective survival time: ∞ as long as cold blocks keep arriving.
Expected hit-rate improvement in this workload: ~15–40% for system-prompt blocks, depending on burst intensity.
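The burst figures above follow directly from the stated assumptions; as a quick sanity check:

```python
# Parameters from the analytical model above.
N, H = 2048, 64            # pool size and hot-prefix size, in blocks
burst_requests = 32
blocks_per_request = 62    # L
arrival_rate = 10          # λ_l, req/s

burst_blocks = burst_requests * blocks_per_request  # blocks freed by the burst
fill_time = burst_requests / arrival_rate           # seconds to fill them

# The burst exactly fills the N - H = 1,984 non-hot blocks in 3.2 s,
# so under pure LRU the hot prefix is evicted after ~3.2 s of inactivity.
```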
Memory Overhead of the Fix
| Policy | Extra Data Structures | Memory Overhead |
| --- | --- | --- |
| LRU (current) | None | 0 |
| Two-Queue (proposed) | 1 extra FreeKVCacheBlockQueue + set[int] | ~8 × N bytes; for N = 2,048 → 16 KB |
| ARC (future) | Ghost lists B1/B2, target_t1_size | ~48 × N bytes; for N = 2,048 → 96 KB |
At 200,000 blocks (typical A100 80GB with Llama-3 70B), the Two-Queue overhead is ~1.6 MB — negligible against 80 GB GPU memory.
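The overhead figures are simple per-block bookkeeping estimates; the bytes-per-block constants below are the table's assumptions, not measured values:

```python
TWO_QUEUE_BYTES_PER_BLOCK = 8    # roughly one set[int] entry per block (assumed)
ARC_BYTES_PER_BLOCK = 48         # ghost-list entries plus pointers (assumed)

def overhead_bytes(num_blocks: int, bytes_per_block: int) -> int:
    # Total bookkeeping memory for a pool of num_blocks blocks.
    return num_blocks * bytes_per_block

# N = 2,048:   Two-Queue -> 16 KB, ARC -> 96 KB
# N = 200,000: Two-Queue -> 1.6 MB
```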
Proposed Solution
Design Principles
- Zero behavior change by default. The existing LRU is wrapped, not replaced. Users who do not set gpu_eviction_policy see identical behavior.
- O(1) per operation. Scheduling is on the critical path; no linear scans.
- Incremental. Two-Queue is a stepping stone toward full ARC, using the same abstraction.
- Consistent with CPU layer. Mirror the pluggable CachePolicy design already used by CPUOffloadingManager.
Architecture Overview
```
Before                        After
──────────────────────        ──────────────────────────────────────
BlockPool                     BlockPool
└── free_block_queue:         └── _policy: GPUCachePolicy
    FreeKVCacheBlockQueue         ├── LRUGPUCachePolicy (default)
                                  │   └── FreeKVCacheBlockQueue
                                  └── TwoQueueGPUCachePolicy
                                      ├── _cold: FreeKVCacheBlockQueue
                                      ├── _hot: FreeKVCacheBlockQueue
                                      └── _hot_set: set[int]
```
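A minimal sketch of the intended TwoQueueGPUCachePolicy behavior, using the field names from the diagram above (an OrderedDict stands in for FreeKVCacheBlockQueue, and the promotion rule shown is one plausible choice, not a settled design):

```python
from collections import OrderedDict

class TwoQueuePolicySketch:
    """Cold queue for first-time frees, hot queue for blocks that have
    been reused from the cache; eviction drains cold before hot."""

    def __init__(self) -> None:
        self._cold: OrderedDict[int, None] = OrderedDict()
        self._hot: OrderedDict[int, None] = OrderedDict()
        self._hot_set: set[int] = set()  # block ids that ever saw a cache hit

    def free(self, block_id: int) -> None:
        # O(1): route by reuse history.
        queue = self._hot if block_id in self._hot_set else self._cold
        queue[block_id] = None

    def cache_hit(self, block_id: int) -> None:
        # A freed block is reused: pull it out of the queues and
        # remember it as hot for its next free.
        self._cold.pop(block_id, None)
        self._hot.pop(block_id, None)
        self._hot_set.add(block_id)

    def evict(self) -> int:
        # Cold blocks go first; hot blocks only once cold is empty.
        if self._cold:
            block_id, _ = self._cold.popitem(last=False)
        else:
            block_id, _ = self._hot.popitem(last=False)
            self._hot_set.discard(block_id)
        return block_id

# Replay of the t=0..t=3 timeline from section 2.2:
policy = TwoQueuePolicySketch()
for b in range(64):              # system-prompt blocks freed once ...
    policy.free(b)
for b in range(64):              # ... then re-hit by later requests
    policy.cache_hit(b)
for b in range(64):              # freed again -> routed to the hot queue
    policy.free(b)
for b in range(64, 64 + 1984):   # 1,984 one-time blocks -> cold queue
    policy.free(b)
evicted = [policy.evict() for _ in range(100)]
```

With this policy, the 100-block allocation at t=3 is served entirely from the one-time cold blocks, and all 64 system-prompt blocks survive in the hot queue.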
Alternatives
No response
Additional context
No response
Before submitting a new issue...