Problem / Background
Follow-up to #326 (merged PR #341), which shipped the generalized fused QKV+RMSNorm+RoPE decode path for Qwen3 and Qwen3-MoE as opt-in (default off, enabled with MLXCEL_FUSED_QK_NORM=1). On M1 Ultra the fused path measured roughly 1 to 3.4% slower than the graph path (fused vs graph: qwen3-0.6b 275 vs 284 tok/s, qwen3-8b 82.3 vs 83.2 tok/s). The kernel reduces Rust/C++ FFI crossings rather than the number of MLX ops, and FFI crossings are already cheap on Apple Silicon, so the fusion is net-negative there. Decode output is byte-identical to the graph path. CUDA has a different op-dispatch and FFI cost profile, so the fused path may win there.
Proposed Solution
- Re-bench fused QK-norm on CUDA: qwen3 and qwen3-MoE decode, fused vs graph via the MLXCEL_FUSED_QK_NORM toggle.
- Reconfirm correctness on CUDA: decode output must be byte-identical to the graph path, as already validated on Metal.
- Decide whether to enable it by default on CUDA while keeping it opt-in on Metal.
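The per-backend decision can be framed as a simple comparison of decode throughput. The sketch below is a hypothetical helper, not existing mlxcel code: the DecodeBench struct, fused_delta_pct, and the rule in should_default_on (fused must be at least min_gain_pct faster on every benchmarked model) are illustrative assumptions. It is fed the M1 Ultra figures from the background section; the CUDA results would replace them.

```rust
/// One decode benchmark result in tokens per second (hypothetical struct).
struct DecodeBench {
    model: &'static str,
    fused_tok_s: f64,
    graph_tok_s: f64,
}

impl DecodeBench {
    /// Relative change of the fused path vs the graph path, in percent.
    /// Negative means the fused path is slower.
    fn fused_delta_pct(&self) -> f64 {
        (self.fused_tok_s - self.graph_tok_s) / self.graph_tok_s * 100.0
    }
}

/// Hypothetical decision rule: default the fused path on only if it beats the
/// graph path by at least `min_gain_pct` on every benchmarked model.
fn should_default_on(results: &[DecodeBench], min_gain_pct: f64) -> bool {
    !results.is_empty() && results.iter().all(|r| r.fused_delta_pct() >= min_gain_pct)
}

fn main() {
    // M1 Ultra figures (fused vs graph); CUDA measurements would go here instead.
    let metal = [
        DecodeBench { model: "qwen3-0.6b", fused_tok_s: 275.0, graph_tok_s: 284.0 },
        DecodeBench { model: "qwen3-8b", fused_tok_s: 82.3, graph_tok_s: 83.2 },
    ];
    for r in &metal {
        println!("{}: {:+.2}%", r.model, r.fused_delta_pct());
    }
    println!("default on: {}", should_default_on(&metal, 0.0));
}
```

Requiring a win on every model, rather than on average, keeps a regression on one model size from being hidden by a gain on another; the threshold is a judgment call for the CUDA results.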
Acceptance Criteria
Technical Considerations
The shared primitive already ships and is reused for the deferred QK-norm families; this issue covers only the per-backend default, mirroring the opt-in MLXCEL_FUSED_MOE_RELU2 pattern. The MLXCEL_FUSED_QK_NORM path is active only when l == 1 (decode) and the weights are quantized. See the MLXCEL_FUSED_QK_NORM entry in docs/environment-variables.md, which documents the M1 Ultra regression and the rationale for leaving CUDA pending.
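A minimal sketch of how the gating could look once a per-backend default exists. Everything here is hypothetical: the Backend enum and function names are illustrative, l is taken to be the number of query tokens in the current step, the CUDA default of true is the outcome this issue would adopt only if the CUDA benchmarks favor the fused path, and treating "0" as an explicit disable is an assumption (today only MLXCEL_FUSED_QK_NORM=1 is documented).

```rust
use std::env;

/// Compute backends relevant to this decision (hypothetical enum).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Backend {
    Metal,
    Cuda,
}

/// Hypothetical per-backend default: CUDA on (pending benchmarks), Metal off.
fn default_fused_qk_norm(backend: Backend) -> bool {
    match backend {
        Backend::Cuda => true,
        Backend::Metal => false,
    }
}

/// An explicit env value wins; otherwise fall back to the backend default.
/// Assumes "1" enables and "0" disables; any other value uses the default.
fn fused_qk_norm_enabled(env_value: Option<&str>, backend: Backend) -> bool {
    match env_value {
        Some("1") => true,
        Some("0") => false,
        _ => default_fused_qk_norm(backend),
    }
}

/// The fused path applies only to single-token decode (l == 1) with quantized weights.
fn use_fused_qk_norm(l: usize, quantized: bool, env_value: Option<&str>, backend: Backend) -> bool {
    l == 1 && quantized && fused_qk_norm_enabled(env_value, backend)
}

fn main() {
    let env_value = env::var("MLXCEL_FUSED_QK_NORM").ok();
    for backend in [Backend::Metal, Backend::Cuda] {
        println!(
            "{:?}: decode={} prefill={}",
            backend,
            use_fused_qk_norm(1, true, env_value.as_deref(), backend),
            use_fused_qk_norm(16, true, env_value.as_deref(), backend),
        );
    }
}
```

Passing the env value in as a parameter, instead of reading it inside the predicate, keeps the decision testable per backend without mutating the process environment.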