Add moe prefill/ decode with int2/int4/int8 sym /asym and fp8 e4m3 e5m2#1813
Conversation
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/95841e6d-d5d1-4662-8db0-4dd69690bc28 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/95841e6d-d5d1-4662-8db0-4dd69690bc28 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
for more information, see https://pre-commit.ci
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/91221649-2c90-4404-ae86-3321b1581428 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/91221649-2c90-4404-ae86-3321b1581428 Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
|
@copilot resolve the merge conflicts in this pull request |
…ecode-implementation # Conflicts: # auto_round_extension/ark/auto_round_kernel/ark.cpp Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
Merged |
There was a problem hiding this comment.
Pull request overview
This PR adds an XPU-optimized MoE decode-phase GEMV kernel (small M per expert) with multiple weight formats, and wires it through the C++/PyTorch extension layer with corresponding unit tests.
Changes:
- Added a SYCL decode GEMV kernel supporting FP16/BF16, INT8/INT4/INT2 (sym/asym), and FP8 (E4M3/E5M2) weights.
- Exposed the kernel via pybind (
moe_gemm_decode) and added a Python wrapper with argument validation. - Added unit tests covering the new decode paths and key validation error cases.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
| auto_round_extension/ark/test/test_moe.py | Adds decode-path unit tests plus packing/dequant reference helpers for INT2/4/8 and FP8. |
| auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_moe_decode.hpp | Introduces the new SYCL MoE decode GEMV kernel implementations and dispatch. |
| auto_round_extension/ark/auto_round_kernel/wrapper/include/sycl_tla_common.hpp | Declares the new moe_gemm_decode API (but docs currently lag implementation). |
| auto_round_extension/ark/auto_round_kernel/ark.cpp | Includes the new header and binds moe_gemm_decode via pybind. |
| auto_round_extension/ark/auto_round_kernel/init.py | Adds the ARK.moe_gemm_decode Python wrapper and validation logic. |
Comments suppressed due to low confidence (2)
auto_round_extension/ark/auto_round_kernel/init.py:871
- num_tokens_per_expert is converted to int32/contiguous but its device is not validated. If it’s a CPU tensor, the kernel will treat a host pointer as device memory. Please ensure num_tokens_per_expert is on XPU (and matches activations.device), or move it to XPU explicitly before calling into the extension.
weights = weights.contiguous()
if num_tokens_per_expert.dtype != torch.int32:
num_tokens_per_expert = num_tokens_per_expert.to(torch.int32)
if not num_tokens_per_expert.is_contiguous():
auto_round_extension/ark/auto_round_kernel/init.py:896
- group_size is used in modulo/division checks (e.g.,
K % group_size) without validating group_size > 0. Passing group_size=0 will raise a ZeroDivisionError rather than a clear ValueError. Please add an explicit check that group_size is a positive integer before any modulo/division operations.
if scales is None:
raise ValueError("scales is required for FP8 weights")
if scales.dtype != activations.dtype:
raise ValueError("scales dtype must match activations dtype")
if K % group_size != 0:
Agent-Logs-Url: https://github.com/intel/auto-round/sessions/132db2ab-85c0-45b6-81a7-b9baaa533e5e Co-authored-by: a32543254 <53296245+a32543254@users.noreply.github.com>
- test_perf_int8_per_tensor mirrors test_perf_fp8_per_tensor with sym round-nearest-clamp int8 packing and [E] fp32 scales; native(ms) / dpas(ms) columns stay '--' since the Variant A DPAS entry point IS the ARK column for this scheme. - test_accuracy_int8_per_tensor_dpas mirrors the FP8 counterpart under the standard _TOL_INT8 tolerance. - Sync README_MOE_PREFILL_PERF.md + _CN.md with a new INT8 per-tensor section (per AGENTS.md CN-docs rule). - Both tests skip silently when the moe_gemm_prefill_int_dpas pybind symbol is absent.
Header now reads 'INT8 per-expert scale (scales=[E] fp32, ...)' instead of 'INT8 per-expert scale int8 (...)'. Cosmetic fix from code review.
`sycl_tla_moe_prefill_int_dpas.hpp` opens its own `ark::moe_dpas_int` namespace and imports select names from `ark::moe_dpas_fp8`, but `make_moe_tensor` (defined in the FP8 header) was missing from the using-declarations. The three call sites in `MoEGEMM_int` therefore failed to resolve at compile time (unqualified lookup found nothing, and ADL cannot help because the arguments are pointers/ints). Add the missing `using ::ark::moe_dpas_fp8::make_moe_tensor;`.
Removes the host `wait()` between the asum precompute kernel and the
INT8 asym DPAS grouped-GEMM submit in `moe_prefill_int_dpas_per_group_dispatch`.
The asum event is now threaded into the DPAS submit via
`cgh.depends_on(...)` on the same SYCL queue, so:
* device-side ordering is preserved (identical numerical results),
* one host round-trip per moe_gemm_prefill call is eliminated, and
* DPAS launch-prep (template inst, hw query, kernel props, arg
marshalling) runs on the host in parallel with asum's device time.
Sym path is byte-identical: the new event parameter on
`MoEGEMMLauncher_int` defaults to an empty (completed) event, and the
sym launch site keeps its previous argument list.
Refs perf gap in `test_perf_int8[True-dtype0]` where asym dpas trails
sym dpas by ~20% -- this closes the host-stall portion of that gap.
Full mainloop-fusion (Lever 2 as originally described) is deferred:
it requires either DPAS w8a16 A-fragment lane-layout knowledge or an
SLM-atomic accumulator whose contention profile needs to be measured
on XPU hardware to confirm it isn't a regression vs. the current
well-coalesced pre-pass kernel. Not landing that blind.
… 1)" This reverts commit 0363ace.
Roll back the INT8 asym DPAS path (perf regressed vs. dequant fallback on hardware). Add INT4-sym and INT2-sym prefill paths that upcast the packed weights into an int8_t [E, N, K] view inside the existing dequant workspace and dispatch through the same per-group INT8 DPAS mainloop the S8-sym branch uses, reusing the packed scale tensor unmodified.
|
/azp run Unit-Test-CUDA-AutoRound |
|
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
Add packed-word decoders `decode_int4_octet` / `decode_int2_octet` and
switch `launch_dequant_int{4,2}` and their sym→int8 upcast siblings to a
one-work-item-per-word fast path (INT4: 4 bytes = 8 K outputs; INT2:
2 bytes = 8 K outputs), amortising packed-byte loads by 4×/2× and
scale/zero loads by 4×/2× relative to the previous byte-per-item path.
* No numerics change: arithmetic still runs in fp32 and the octet
decoders are thin `#pragma unroll` wrappers over the existing
`decode_int4_pair` / `decode_int2_quad`, so results are bit-identical
to the scalar path (verified exhaustively for the INT2 uint16
domain).
* Fast path is guarded by `K % 8 == 0 && group_size % 8 == 0` (upcast:
only `K % 8 == 0`); the scalar byte-per-item path is retained
unchanged as a fallback so short-K unit tests keep passing.
* Decode/GEMV path (`sycl_tla_moe_decode.hpp`), FP8/INT8 kernels, and
the shared scalar decoders are untouched — decode↔prefill bit-parity
is preserved by construction.
|
/azp run Unit-Test-CUDA-AutoRound |
|
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
… "minimax real" shapes" This reverts commit 3da3d0a.
|
/azp run Unit-Test-CUDA-AutoRound |
|
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
…del_perf - Add ``--all-shapes`` pytest CLI flag in ark/test/conftest.py. - test_moe_prefill_perf.py: default sweep restricted to the 2K rows (4 shapes). ``--all-shapes`` restores the full 12-row matrix; the existing ``--minimax-real-only`` flag still composes. - test_moe_decode_perf.py: default sweep restricted to bs1 (2 shapes). ``--all-shapes`` re-adds the bs32 rows. - Delete test/test_ark/test_moe_model_perf.py.
|
/azp run Unit-Test-CUDA-AutoRound |
|
Azure Pipelines could not run because the pipeline triggers exclude this branch/path. |
Signed-off-by: Dong, Bo1 <bo1.dong@intel.com>
Restore the qwen MoE perf benchmark to
Qwen/Qwen1.5-MoE-A2.7Band remove the DeepSeek-V2-Lite case fromtest/test_ark/test_moe_model_perf.py.Description
"Qwen/Qwen3-30B-A3B"to the sharedqwen_moe_name_or_pathhelper; deepseek-v2-lite entry removed.qwen_moe_name_or_path, drop now-unuseddeepseek_v2_name_or_path.Type of Change
Test
Checklist Before Submitting
/azp run Unit-Test-CUDA-AutoRound.