perf(core): adaptive selector for the native paged-attention decode kernel #331

@inureyes

Summary

Add an adaptive selector that only dispatches the native pooled paged-attention kernel in the regimes it is known to win, and uses the gather-SDPA fallback where it is faster.

Source: idea.md "5. Adaptive Native Paged-Attention Selector" (P1, expected payoff medium in a narrow regime, risk medium-high).

Current state

  • The scheduler requests native paged decode when using the paged backend (src/server/batch/scheduler.rs, around line 4194).
  • paged_decode_attention_pooled() attempts the native kernel when requested or forced by MLXCEL_PAGED_ATTENTION_NATIVE, then falls back to gather-SDPA if the kernel declines (src/lib/mlxcel-core/src/layers.rs:2551; fallback paged_decode_attention_pooled_fallback at line 2464).
  • The native pooled kernel declines once a layer grows beyond one slab, because it currently needs one contiguous pool tensor per side (src/lib/mlxcel-core/src/cache/paged.rs, around line 1598).
  • Per ADR 0001, the native fused path wins at context 4096 with batched decode (B>=4), but loses at B=1 and at context 16384 (docs/adr/0001-paged-attention-gather-vs-fused-kernel.md). The ADR also notes that teaching the kernel per-slab base pointers is possible, but not obviously worth it without a workload where the trade-off flips.
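
For orientation, the attempt-then-fall-back flow described above can be modelled roughly as follows. This is a minimal sketch, not the code in layers.rs: `PathUsed`, `try_native`, and `decode_path` are hypothetical names, and the only decline condition modelled is the multi-slab one.

```rust
/// Which path ended up producing the output (illustrative names only).
#[derive(Debug, PartialEq, Eq)]
pub enum PathUsed {
    Native,
    GatherSdpa,
}

/// Hypothetical stand-in for the native kernel's precondition: it needs one
/// contiguous pool tensor per side, so it declines once a layer spans more
/// than one slab.
fn try_native(slab_count: usize) -> Option<PathUsed> {
    if slab_count <= 1 {
        Some(PathUsed::Native)
    } else {
        None
    }
}

/// Attempt native only when requested by the scheduler or forced by the
/// environment override; if it is not attempted, or it declines, use the
/// gather-SDPA fallback.
pub fn decode_path(requested: bool, forced: bool, slab_count: usize) -> PathUsed {
    if requested || forced {
        if let Some(used) = try_native(slab_count) {
            return used;
        }
    }
    PathUsed::GatherSdpa
}
```

In this model the decision depends only on the request flag and slab count, which is why B=1 and long-context decodes still reach the native kernel today.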

Proposal

  • Add an adaptive selector keyed on batch size, visible context length, backend, and slab count; avoid native dispatch for regimes already known to lose (B=1, long context, multi-slab).
  • Run a multi-slab native-kernel spike ONLY if live traces show a meaningful number of requests in the "B>=4, around 4k context, multi-slab" regime.
  • Cache per-layer/shape dispatch decisions where practical so the selection logic itself adds no hot-path overhead.
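
A rough sketch of what the selector and its decision cache could look like. Everything here is an assumption for discussion: the type names, the `Backend` variants, bucketing context length to the next power of two so the cache stays small, and the two thresholds. The thresholds encode only what ADR 0001 reports (a win at B>=4 and context 4096, losses at B=1 and context 16384); whether shorter contexts or the range between 4096 and 16384 win is left for the bench to decide.

```rust
use std::collections::HashMap;

/// Decode path to dispatch (illustrative, not the mlxcel-core API).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DecodePath {
    NativePooled,
    GatherSdpa,
}

/// Hypothetical backend tag; only `Paged` can use the native kernel.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Backend {
    Paged,
    Dense,
}

/// The four inputs the proposal keys on. Context is stored as a bucket
/// because the exact length changes on every decode step.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct DispatchKey {
    pub batch: usize,
    pub context_bucket: usize,
    pub backend: Backend,
    pub slab_count: usize,
}

impl DispatchKey {
    pub fn new(batch: usize, context_len: usize, backend: Backend, slab_count: usize) -> Self {
        // Assumed bucketing scheme: round up to the next power of two.
        let context_bucket = context_len.next_power_of_two();
        Self { batch, context_bucket, backend, slab_count }
    }
}

// Placeholder thresholds, to be replaced by measured crossover points.
pub const MIN_NATIVE_BATCH: usize = 4;
pub const MAX_NATIVE_CONTEXT: usize = 4096;

/// Pick native only inside the assumed positive regime.
pub fn select_path(key: DispatchKey) -> DecodePath {
    let native_ok = key.backend == Backend::Paged
        && key.slab_count == 1
        && key.batch >= MIN_NATIVE_BATCH
        && key.context_bucket <= MAX_NATIVE_CONTEXT;
    if native_ok {
        DecodePath::NativePooled
    } else {
        DecodePath::GatherSdpa
    }
}

/// Per-(layer, key) memo of dispatch decisions.
#[derive(Default)]
pub struct SelectorCache {
    decisions: HashMap<(usize, DispatchKey), DecodePath>,
}

impl SelectorCache {
    pub fn decide(&mut self, layer: usize, key: DispatchKey) -> DecodePath {
        *self.decisions.entry((layer, key)).or_insert_with(|| select_path(key))
    }

    pub fn len(&self) -> usize {
        self.decisions.len()
    }
}
```

With rules this cheap, a hash lookup may cost more than re-evaluating `select_path`; the cache only pays off if the real selector consults measured tables or per-layer state, which is worth checking before committing to it.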

Validation

  • Re-run examples/paged_attention_kernel_bench.rs on current M5/GB10 targets.
  • Add server traces with the paged backend at B=1/4/8 and contexts 1k/4k/16k.
  • Confirm the selector never regresses long-context single-sequence decode.
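
The trace matrix and the regression check could be pinned down along these lines. This is a sketch under assumptions: "1k/4k/16k" is read as 1024/4096/16384 tokens, latencies are compared in microseconds per decode step, and the tolerance is a free parameter; none of these names exist in the repo.

```rust
/// Batch sizes and context lengths from the validation plan.
/// Reading 1k/4k/16k as powers of two is an assumption.
pub const BATCH_SIZES: [usize; 3] = [1, 4, 8];
pub const CONTEXT_LENS: [usize; 3] = [1024, 4096, 16384];

/// Every (batch, context) pair to trace, batch-major.
pub fn trace_matrix() -> Vec<(usize, usize)> {
    BATCH_SIZES
        .iter()
        .flat_map(|&b| CONTEXT_LENS.iter().map(move |&c| (b, c)))
        .collect()
}

/// True when the candidate latency exceeds the baseline by more than
/// `tolerance` (a fraction, e.g. 0.02 for 2%).
pub fn regressed(baseline_us: f64, candidate_us: f64, tolerance: f64) -> bool {
    candidate_us > baseline_us * (1.0 + tolerance)
}
```

The long-context single-sequence guard then amounts to asserting `!regressed(..)` for the (1, 16384) cell, with the baseline taken from current behavior.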

Acceptance criteria

  • The selector dispatches native only in the known-positive regime (B>=4, context around 4k, single slab) and uses gather-SDPA elsewhere.
  • No regression on long-context single-sequence (B=1) decode vs current behavior.
  • Dispatch decisions cached per layer/shape; selection adds no measurable hot-path overhead.
  • Bench and server-trace numbers recorded at B=1/4/8 and contexts 1k/4k/16k.
  • MLXCEL_PAGED_ATTENTION_NATIVE override still honored; any multi-slab kernel spike is gated behind real trace evidence.
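
One way the override could be layered on top of the adaptive choice. This is a sketch only: the issue does not say which values MLXCEL_PAGED_ATTENTION_NATIVE accepts, so treating "1" as "force native" is an assumption, and `Override`, `parse_override`, and `resolve` are hypothetical names.

```rust
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DecodePath {
    NativePooled,
    GatherSdpa,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Override {
    ForceNative,
    Auto,
}

/// Parse the raw variable value. ASSUMPTION: "1" forces native; anything
/// else (including unset) leaves the adaptive selector in charge.
pub fn parse_override(raw: Option<&str>) -> Override {
    match raw.map(str::trim) {
        Some("1") => Override::ForceNative,
        _ => Override::Auto,
    }
}

/// Read the override from the process environment.
pub fn override_from_env() -> Override {
    parse_override(std::env::var("MLXCEL_PAGED_ATTENTION_NATIVE").ok().as_deref())
}

/// The override wins over the adaptive decision when it is set.
pub fn resolve(ovr: Override, adaptive: DecodePath) -> DecodePath {
    match ovr {
        Override::ForceNative => DecodePath::NativePooled,
        Override::Auto => adaptive,
    }
}
```

Forcing native here only means the kernel is attempted; it can still decline (for example on a multi-slab layer) and fall back to gather-SDPA as it does today.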

Risk and priority

idea.md priority P1; risk medium-high. The native path has a known positive island; a selector can exploit it without making the default path worse elsewhere.

References

  • idea.md "5. Adaptive Native Paged-Attention Selector".
  • Code: src/server/batch/scheduler.rs, src/lib/mlxcel-core/src/layers.rs (paged_decode_attention_pooled), src/lib/mlxcel-core/src/cache/paged.rs.
  • docs/adr/0001-paged-attention-gather-vs-fused-kernel.md, examples/paged_attention_kernel_bench.rs.

Labels

  • area:core (mlxcel-core: MLX FFI, primitives, KV cache, layers)
  • area:inference (generation, sampling, decoding, incl. speculative and DRY)
  • priority:medium (medium priority)
  • type:performance (performance improvements)
