Summary
Add an adaptive selector that dispatches the native pooled paged-attention kernel only in the regimes where it is known to win, and uses the gather-SDPA fallback where that is faster.
Source: idea.md "5. Adaptive Native Paged-Attention Selector" (P1, expected payoff medium in a narrow regime, risk medium-high).
Current state
- The scheduler requests native paged decode when using the paged backend (src/server/batch/scheduler.rs, around line 4194).
- paged_decode_attention_pooled() attempts the native kernel when requested or forced by MLXCEL_PAGED_ATTENTION_NATIVE, then falls back to gather-SDPA if the kernel declines (src/lib/mlxcel-core/src/layers.rs:2551; fallback paged_decode_attention_pooled_fallback at line 2464).
- The native pooled kernel declines once a layer grows beyond one slab, because it currently needs one contiguous pool tensor per side (src/lib/mlxcel-core/src/cache/paged.rs, around line 1598).
- ADR 0001: the native fused path wins at context 4096 and batched decode B>=4, but loses for B=1 and at context 16384 (docs/adr/0001-paged-attention-gather-vs-fused-kernel.md). The ADR also notes that teaching the kernel per-slab base pointers is possible, but not obviously worth it without a workload where the trade-off flips.
Proposal
- Add an adaptive selector keyed on batch size, visible context length, backend, and slab count; avoid native dispatch for regimes already known to lose (B=1, long context, multi-slab).
- Run a multi-slab native-kernel spike ONLY if live traces show a meaningful number of requests in the "B>=4, around 4k context, multi-slab" regime.
- Cache per-layer/shape dispatch decisions where practical so the selection logic itself adds no hot-path overhead.
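The selection rule proposed above can be sketched roughly as follows. This is a minimal illustration, not the project's implementation: AttentionPath, Backend, DispatchKey, Thresholds, thresholds_for, and select_path are hypothetical names, reading "backend" as the compute target is an assumption, and the thresholds (B>=4, context <=4096) are placeholders taken from the two data points the ADR reports. They would need to be replaced by numbers from the benchmark re-run.

```rust
/// Which attention path to dispatch for one decode step (hypothetical type).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AttentionPath {
    NativePooled,
    GatherSdpa,
}

/// Compute target. Assumption: "backend" in the proposal means this;
/// the variant names are illustrative only.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Backend {
    Metal,
    Cuda,
}

/// The four inputs the proposal keys the selector on (hypothetical field names).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DispatchKey {
    pub batch_size: usize,
    pub context_len: usize,
    pub slab_count: usize,
    pub backend: Backend,
}

/// Per-backend cut-offs for the native path.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Thresholds {
    pub min_batch: usize,
    pub max_context: usize,
}

/// Placeholder values: the ADR only reports a win at B>=4 / context 4096
/// and losses at B=1 and context 16384. Everything in between is unmeasured,
/// so both backends share one conservative guess here.
pub fn thresholds_for(_backend: Backend) -> Thresholds {
    Thresholds { min_batch: 4, max_context: 4096 }
}

/// Decide which path to attempt. `force_native` stands in for the
/// MLXCEL_PAGED_ATTENTION_NATIVE override; how that variable is parsed
/// is not modelled here.
pub fn select_path(key: DispatchKey, force_native: bool) -> AttentionPath {
    // The override wins over every heuristic.
    if force_native {
        return AttentionPath::NativePooled;
    }
    // The native pooled kernel declines multi-slab layers, so skip the attempt.
    if key.slab_count > 1 {
        return AttentionPath::GatherSdpa;
    }
    let t = thresholds_for(key.backend);
    if key.batch_size >= t.min_batch && key.context_len <= t.max_context {
        AttentionPath::NativePooled
    } else {
        AttentionPath::GatherSdpa
    }
}
```

In this sketch, forcing native only means the native kernel is attempted; the existing decline-then-fallback behaviour in paged_decode_attention_pooled() would still apply to a multi-slab layer. Because the rule is a handful of integer comparisons, a decision cache may only pay off if the selector later depends on something costlier.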
Validation
- Re-run examples/paged_attention_kernel_bench.rs on current M5/GB10 targets.
- Add server traces with the paged backend at B=1/4/8 and contexts 1k/4k/16k.
- Confirm the selector never regresses long-context single-sequence decode.
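The trace matrix and the no-regression check above could be expressed as a small grid test along these lines. It is a sketch under assumptions: trace_grid, known_losing, and regressions are hypothetical helpers, "1k/4k/16k" is read as 1024/4096/16384 tokens, and the "known losing" cells are only the ones the ADR names (B=1 and context 16384). The selector under test is passed in as a closure so the check does not depend on any particular implementation.

```rust
/// Batch sizes and context lengths from the validation plan.
/// Assumption: 1k/4k/16k means 1024/4096/16384 tokens.
pub const BATCH_SIZES: [usize; 3] = [1, 4, 8];
pub const CONTEXT_LENS: [usize; 3] = [1024, 4096, 16384];

/// Every (batch, context) cell of the trace matrix.
pub fn trace_grid() -> Vec<(usize, usize)> {
    let mut grid = Vec::new();
    for &b in &BATCH_SIZES {
        for &c in &CONTEXT_LENS {
            grid.push((b, c));
        }
    }
    grid
}

/// Cells where the ADR reports the native path losing:
/// single-sequence decode and context 16384.
pub fn known_losing(batch: usize, context: usize) -> bool {
    batch == 1 || context >= 16384
}

/// Returns the cells where a selector picks native despite a known loss.
/// `chooses_native(batch, context)` is the selector under test.
pub fn regressions<F>(chooses_native: F) -> Vec<(usize, usize)>
where
    F: Fn(usize, usize) -> bool,
{
    trace_grid()
        .into_iter()
        .filter(|&(b, c)| known_losing(b, c) && chooses_native(b, c))
        .collect()
}
```

A selector passes when regressions(...) returns an empty list. This only guards the regimes already measured; timing data from the server traces is still needed to decide the unmeasured cells such as B>=4 at 1k context.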
Acceptance criteria
- MLXCEL_PAGED_ATTENTION_NATIVE override still honored; any multi-slab kernel spike is gated behind real trace evidence.
Risk and priority
idea.md priority P1; risk medium-high. The native path has a known positive island; a selector can exploit it without making the default path worse elsewhere.
References
- idea.md "5. Adaptive Native Paged-Attention Selector".
- Code: src/server/batch/scheduler.rs, src/lib/mlxcel-core/src/layers.rs (paged_decode_attention_pooled), src/lib/mlxcel-core/src/cache/paged.rs.
- docs/adr/0001-paged-attention-gather-vs-fused-kernel.md, examples/paged_attention_kernel_bench.rs.