[common] Grouped gemm update - nvfp4 for blackwell and fp8 blockwise hopper#2971
[common] Grouped gemm update - nvfp4 for blackwell and fp8 blockwise hopper#2971pggPL wants to merge 22 commits into
Conversation
Use existing nvte_set_grouped_tensor_param with kNVTEGroupedWithGEMMSwizzledScales instead of the dedicated set/get functions. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add CUBLAS_NVFP4_GROUPED_GEMM_VERSION and CUBLAS_FP8_BLOCK_GROUPED_GEMM_VERSION macros (13.4+) - Update check_grouped_gemm_requirements to allow SM90 with cuBLAS 13.4+ - Refactor execute_grouped_gemm to use GroupedGemmConfig struct - Add divisibility-by-128 validation for FP8 block scaling in setup kernel and quantizer - Support scalar alpha/beta for Hopper (no per-group alpha/beta) - Expose get_grouped_gemm_setup_workspace_size to PyTorch via pybind - Update PyTorch tests to run grouped GEMM on Hopper with cuBLAS 13.4+ Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
… scaling tests on Hopper Extend nvte_grouped_gemm_with_discrete_inputA to handle NVFP4 (Float4E2M1) inputs: accept kFloat4E2M1 dtype, propagate scale_inv pointers, collect contiguous amax from discrete tensors, and enforce swizzled-scales checks for NVFP4 alongside MXFP8. Also add GTEST_SKIP for FP8 tensor scaling grouped GEMM on Hopper since cuBLAS does not support it there. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
…M tests The setup kernel computes per-tensor scale pointers as data_offset / block_size, which assumes no padding in the scale buffer. This is only correct when first_dim % 128 == 0 and last_dim % 128 == 0 (MXFP8) or last_dim % 64 == 0 (NVFP4). Add explicit assertions in build_grouped_tensor to catch any future test shapes that violate this. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
…d_hopper
Conflicts resolved (3 files):
* tests/pytorch/test_numerics.py
test_grouped_gemm_grouped_tensor: combined skip rules — Hopper (SM90) requires
cuBLAS 13.4+, Blackwell+ (SM100) requires cuBLAS 13.3+. Kept main's
use_bias_scale parametrization.
* transformer_engine/pytorch/cpp_extensions/gemm.py
general_grouped_gemm_for_grouped_tensor: combined HEAD's num_alphabeta logic
(single scalar on Hopper, per-group on Blackwell+) with main's cached
_get_fp32_ones_tensor / _get_fp32_zeros_tensor helpers.
* transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
- validate_grouped_gemm_inputs: kept HEAD's NVFP4 / FP8 block-scaling
consistency checks, wrapped in main's nullptr-guard / continue-on-no-data
pattern.
- GroupedGemmConfig struct retained; added sm_count from main and
propagated config_.sm_count -> gemm_config.sm_count in all three
public APIs.
- kMaxTensorsPerKernel rename to kMaxGroups (= 64) adopted from main.
- execute_grouped_gemm signature uses GroupedGemmConfig (HEAD); body uses
config.sm_count for CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET (from main).
- Dropped HEAD's simple grouped_bias_add_kernel (dead code); kept main's
advanced grouped_bias_add_kernel + find_tensor_for_row helper.
- Replaced inline SM/cuBLAS preambles with check_grouped_gemm_requirements()
calls in nvte_grouped_gemm, nvte_grouped_gemm_with_discrete_inputA, and
nvte_grouped_gemm_with_discrete_out. The helper supports both
Hopper (SM90 + cuBLAS 13.4+) and Blackwell+ (SM100 + cuBLAS 13.3+).
- Kept HEAD's validate_grouped_gemm_inputs(..., use_per_group_alpha_beta)
signature for proper alpha/beta validation across architectures.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…or swizzle tests
cublaslt_grouped_gemm.cu:
- Fix incorrect handling of NVFP4/MXFP8 columnwise data in
build_grouped_gemm_multi_inputA_args by adding a swap_dims flag
consistent with choose_grouped_operand_storage. Use A_sel.trans
(post-flip) for gemm_config.avg_k so K is selected from the
correct dim with discrete A_list.
tests/cpp/test_common.{h,cu}:
- Add enforce_grouped_gemm_alignment parameter (default true) to
build_grouped_tensor; the MXFP8/NVFP4 first/last_dim 128/64
alignment asserts are only relevant for the grouped GEMM setup
kernel, so callers that bypass it (swizzle/unswizzle) opt out.
tests/cpp/operator/test_swizzle.cu:
- Pass enforce_grouped_gemm_alignment=false to build_grouped_tensor
in MXFP8 swizzle/unswizzle/roundtrip tests, which intentionally
exercise non-padded shapes.
tests/cpp/operator/test_grouped_gemm.cu:
- Sync GPU/cuBLAS skip rules across all 3 sub-tests, add
cudaDeviceSynchronize() after nvte_multi_tensor_gemm reference for
defensive sync, and skip NVFP4 + AllDifferent in all 3 sub-tests
due to a known flaky bug in the nvte_multi_tensor_gemm reference.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Made-with: Cursor
…and_hopper Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com> # Conflicts: # transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
for more information, see https://pre-commit.ci
335627f to
3f523e7
Compare
Apply the same fix as upstream PR NVIDIA#2954 (MXFP8 unaligned dims) to the analogous NVFP4 / FP8 block scaling paths in setup_grouped_gemm_kernel. Background: cuBLAS grouped GEMM expects each expert's scale_inv to live at a specific offset in the contiguous grouped buffer. The quantizer allocates each per-expert scale_inv tensor padded to the layout cuBLAS needs (swizzled 128x4 for MX/NV; ceildiv(., 128) x roundup(., 4) for block scaling). The setup kernel was computing these offsets as data_offset / block_size for everything except MXFP8 — silently correct when dims align to 128, but pointing at the middle of the previous expert's scale tile when they do not. In MoE forward this is reachable through variable per-expert token counts. Add three device helpers mirroring compute_grouped_tensor_mxfp8_- scale_inv_offset: - compute_grouped_tensor_nvfp4_scale_inv_offset - compute_grouped_tensor_block_1d_scale_inv_offset - compute_grouped_tensor_block_2d_scale_inv_offset Each sums the same padded per-tensor sizes the quantizer uses at alloc time (Float8BlockQuantizer::get_scale_shape, NVFP4Quantizer::get_scale_- shape). NVFP4 columnwise data is set up via use_columnwise(swap_dims=true), so sel.shape is already pre-transposed for that recipe — the rowwise formula on (first, last) recovers the colwise alloc. For block scaling the formula depends on the canonical orientation, so propagate a new swap_dims field on GroupedOperandSelection and pass effective_rowwise (sel.rowwise || sel.swap_dims) into the kernel. MXFP8 is invariant under this change because swap_dims is always false there and its helper's byte count is invariant under the rowwise flag anyway. Test: add ShapeCase::kUnalignedAllSame with (M, N, K) = (160, 288, 416) — all multiples of 32/16 (per-recipe block size) but none multiples of 128, so each expert's scale tile is padded. Exercise it across MXFP8 / NVFP4 / FP8 block scaling and the three transpose configs that match the existing parameter grid. Relax build_grouped_tensor's defensive %128 / %64 alignment assertions to %32 / %16 (block-size only), which is the actual quantizer requirement now that the offset arithmetic no longer assumes zero padding. Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com> Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
…st cleanup
Production:
- nvte_grouped_gemm_with_discrete_inputA no longer requires per-expert amax
buffers to be contiguous. Add `amax_ptrs[kMaxGroups]` to MultiTensorGroupGemmInputArgs
and read each tensor's amax via indirection in setup_grouped_gemm_kernel
(mirrors the existing scale_inv_ptrs pattern). The launcher enables the
NVFP4 alpha computation when amax is available from either source.
- Consolidate four near-identical
compute_grouped_tensor_{mxfp8,nvfp4,block_1d,block_2d}_scale_inv_offset
into a single template `compute_grouped_scale_inv_offset<PaddedFn>` and
collapse the A/B recipe-switch in setup_grouped_gemm_kernel into a local
`fill_scale_ptr` lambda.
Tests:
- Drop the per-test amax staging workaround in run_grouped_gemm_discrete_in_case
(no longer needed after the contiguity relax).
- Fix amax management in make_nvfp4_operand: copy values into result's own
amax buffers instead of aliasing pointers (prevents double-free).
- Extract the three duplicated cuBLAS-version/compute-capability skip blocks
into a shared `grouped_gemm_skip_reason` helper.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
for more information, see https://pre-commit.ci
Silences -Wunused-variable (NVIDIA#177-D in nvcc). Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com>
fcefde1 to
ce0e4d2
Compare
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
ce0e4d2 to
a4df7bd
Compare
- nvte_grouped_gemm and nvte_grouped_gemm_with_discrete_out now validate per-operand amax for NVFP4 (previously silently dropped the global-scale factor when amax was missing). discrete_inputA path also checks B's amax. - Remove unused ShapeCase::kUnalignedAllSameNVFP4 enum and its comment. - OperandStorageChoice::swap_dims now defaults to false; rowwise returns no longer pass spurious swap_dims=true. - Unify GroupedGemmSetupWorkspace layout: from_buffers(nullptr, n) returns the total byte count, and required_setup_size derives its result from it so the layout cannot drift between the two. - test_common.cu: consolidate the three gather_*_scales lambdas into a single gather_scale_inv(bytes_per_elem, get_shape, get_cpu_ptr) helper. - test_grouped_gemm.cu: extract make_grouped_gemm_ref / make_alpha_beta / compare_grouped_d_to_multi helpers; the three run_* variants drop from ~1029 to 774 lines with no behavior change. Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com>
for more information, see https://pre-commit.ci
|
/te-ci pytorch |
Greptile SummaryThis PR extends the cuBLAS grouped GEMM implementation to support Hopper (SM90) with cuBLAS 13.4+ (previously Blackwell-only), and adds two new quantization recipes — NVFP4 (Float4E2M1 with per-block E4M3 scales) and FP8 1D/2D block scaling. The implementation refactors workspace sizing, operand layout selection, and per-expert scale-offset arithmetic into a unified, recipe-agnostic framework.
Confidence Score: 3/5The core GEMM dispatch and scale-offset arithmetic look correct for the tested recipes, but the three grouped GEMM entry points are not consistently guarded: only The missing FP4+FP16 output guard is an explicit check the author added to one entry point but not the other two that accept the same input recipes. In practice this would trigger a cuBLAS error rather than silent data corruption, but the inconsistency makes the API surface fragile. The
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
Entry["nvte_grouped_gemm / discrete_out / discrete_inputA"] --> HW["check_grouped_gemm_requirements\n(SM>=90 + cuBLAS 13.3+; SM<100 needs cuBLAS 13.4+)"]
HW --> SM{"SM >= 100 (Blackwell+)?"}
SM -- Yes --> PAB["use_per_group_alpha_beta = true\nAlpha/Beta: per-group arrays"]
SM -- No --> SAB["use_per_group_alpha_beta = false\nAlpha/Beta: single scalar"]
PAB & SAB --> SEL["select_grouped_operand(A, B)\nchoose_grouped_operand_storage\n(TN forced for MXFP8, NVFP4, FP8-block, tensor-FP8)"]
SEL --> RECIPE{"Scaling Recipe?"}
RECIPE -- "BF16/FP16" --> PLAIN["No scale handling"]
RECIPE -- "MXFP8 (SM>=100)" --> MX["set_mxfp8_scale_pointers\npadded_mxfp8_scale_inv_bytes offsets"]
RECIPE -- "NVFP4 (SM>=100)" --> FP4["set_nvfp4_scale_pointers\npadded_nvfp4_scale_inv_bytes offsets\nComputed alpha: a x amax_A x amax_B / factor"]
RECIPE -- "FP8 block (SM=90)" --> BS["set_fp8_block_scaling_scale_pointers\npadded_block_1d/2d_scale_inv_floats offsets"]
RECIPE -- "FP8 tensor" --> FP8["set_fp8_scale_pointers\nOne float per tensor"]
PLAIN & MX & FP4 & BS & FP8 --> SETUP["setup_grouped_gemm_kernel (GPU)\nFills A/B/C/D/scale/alpha/beta pointer arrays"]
SETUP --> CUBLASLT["cublasLtMatmul (grouped GEMM)"]
|
| const bool is_fp8 = is_fp8_dtype(rep_dtype); | ||
| const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); | ||
| const bool mxfp8 = transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode); | ||
| const bool nvfp4 = transformer_engine::is_nvfp_scaling(A_list_info.scaling_mode); | ||
| const bool fp8_block = transformer_engine::is_fp8_block_scaling(A_list_info.scaling_mode); |
There was a problem hiding this comment.
Inconsistent
non_tn_fp8_ok override for FP8 block scaling in discrete-A path
select_grouped_operand (used for B) always forces non_tn_fp8_ok = false for FP8 block scaling, but the discrete-A code path retains the device-capability value of nvte_is_non_tn_fp8_gemm_supported(). On hardware where that function returns true, A and B would disagree on whether TN is required, causing a layout mismatch. It is harmless today because FP8 block scaling is restricted to Hopper where the function returns false, but the inconsistency is fragile.
| const bool is_fp8 = is_fp8_dtype(rep_dtype); | |
| const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported(); | |
| const bool mxfp8 = transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode); | |
| const bool nvfp4 = transformer_engine::is_nvfp_scaling(A_list_info.scaling_mode); | |
| const bool fp8_block = transformer_engine::is_fp8_block_scaling(A_list_info.scaling_mode); | |
| const bool is_fp8 = is_fp8_dtype(rep_dtype); | |
| const bool mxfp8 = transformer_engine::is_mxfp_scaling(A_list_info.scaling_mode); | |
| const bool nvfp4 = transformer_engine::is_nvfp_scaling(A_list_info.scaling_mode); | |
| const bool fp8_block = transformer_engine::is_fp8_block_scaling(A_list_info.scaling_mode); | |
| // FP8 block scaling on Hopper requires TN layout (matches select_grouped_operand logic for B). | |
| const bool non_tn_fp8_ok = fp8_block ? false : nvte_is_non_tn_fp8_gemm_supported(); |
| // Compile-time version macro for Hopper grouped GEMM support (mirrors cublaslt_grouped_gemm.cu) | ||
| #define CUBLAS_GROUPED_GEMM_HOPPER_VERSION 130400 |
There was a problem hiding this comment.
Duplicated version constant that can drift out of sync
CUBLAS_GROUPED_GEMM_HOPPER_VERSION is defined both here (as a test-local macro) and in cublaslt_grouped_gemm.cu. If the implementation version is ever updated, this copy may be forgotten. Consider exposing the constant via a shared header or at least using a static_assert to catch drift at compile time.
| // Compile-time version macro for Hopper grouped GEMM support (mirrors cublaslt_grouped_gemm.cu) | |
| #define CUBLAS_GROUPED_GEMM_HOPPER_VERSION 130400 | |
| // Compile-time version macro for Hopper grouped GEMM support (mirrors cublaslt_grouped_gemm.cu). | |
| // Keep in sync with CUBLAS_GROUPED_GEMM_HOPPER_VERSION in cublaslt_grouped_gemm.cu. | |
| #define CUBLAS_GROUPED_GEMM_HOPPER_VERSION 130400 | |
| static_assert(CUBLAS_GROUPED_GEMM_HOPPER_VERSION == 130400, | |
| "Update this copy to match cublaslt_grouped_gemm.cu"); |
| Tensor make_nvfp4_rowwise(const std::string& name, const std::vector<size_t>& shape) { | ||
| Tensor input_bf16(name + "_bf16", shape, DType::kBFloat16); | ||
| fillUniform(&input_bf16); | ||
|
|
||
| Tensor nvfp4(name, shape, DType::kFloat4E2M1, /*rowwise=*/true, /*columnwise=*/false, | ||
| NVTE_NVFP4_1D_SCALING); | ||
|
|
||
| QuantizationConfigWrapper quant_config; | ||
| nvte_quantize_v2(input_bf16.data(), nvfp4.data(), quant_config, 0); | ||
|
|
||
| Tensor nvfp4_sw(name + "_sw", shape, DType::kFloat4E2M1, | ||
| /*rowwise=*/true, /*columnwise=*/false, NVTE_NVFP4_1D_SCALING); | ||
| nvfp4_sw.set_with_gemm_swizzled_scales(true); | ||
| size_t data_bytes = test::bytes(nvfp4.rowwise_shape(), nvfp4.dtype()); | ||
| NVTE_CHECK_CUDA(cudaMemcpy(nvfp4_sw.rowwise_dptr(), nvfp4.rowwise_dptr(), | ||
| data_bytes, cudaMemcpyDeviceToDevice)); | ||
| nvte_swizzle_scaling_factors(nvfp4.data(), nvfp4_sw.data(), 0); | ||
| NVTE_CHECK_CUDA(cudaDeviceSynchronize()); | ||
| return nvfp4_sw; | ||
| } |
There was a problem hiding this comment.
Why can't we just swizzle the scales in the first nvte_quantize_v2 call instead of going through 2 tensors?
| Tensor rowwise = make_nvfp4_rowwise(name + "_row", shape); | ||
|
|
||
| // 2. Columnwise: transpose input, quantize + swizzle as rowwise of transposed shape | ||
| std::vector<size_t> t_shape = {shape[1], shape[0]}; | ||
| Tensor colwise = make_nvfp4_rowwise(name + "_col", t_shape); |
There was a problem hiding this comment.
Both of those tensors are using different inputs (fillUniform called in both invocations of the make_nvfp4_rowwise function).
ptrendx
left a comment
There was a problem hiding this comment.
There are issues with the test code (most notably the rowwise and columnwise taken from different inputs).
| // Creates an NVFP4 operand with both rowwise and columnwise data, swizzled scales. | ||
| // NVFP4 "columnwise" data is the transposed tensor quantized rowwise. | ||
| // We quantize rowwise directly, and for columnwise we quantize the transposed input rowwise. | ||
| Tensor make_nvfp4_operand(const std::string& name, const std::vector<size_t>& shape, |
There was a problem hiding this comment.
In general I don't understand why you can't just quantize the tensor in both directions at the same time, then all of those issues with using different data for both would not be there.
The FP8 blockwise counterpart is doing just that.
| return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}}; | ||
| case ShapeCase::kUnalignedAllSame: | ||
| default: | ||
| // (M, N, K) all multiples of 32 (MXFP8 block) and 16 (NVFP4 block), but NONE |
There was a problem hiding this comment.
The multiple of 16 for NVFP4 is wrong as TMA requires the alignment to 16B, which is 32 elements in case of the NVFP4. So if anything for the "not nice" shapes we should test multiple of 16 for MXFP8 (although I'm not sure if that would be currently passed by the rest of the logic in cublaslt_gemm.cu, I do have a separate PR to relax some of those requirements) and multiple of 32 for NVFP4.
There was a problem hiding this comment.
That said, the values that you have here are actually common to both of the types and both have K being multiple of 32, so this comment is just wrong.
| return "Grouped GEMM requires cuBLAS 13.3+, but compile-time cuBLAS version is " + | ||
| std::to_string(CUBLAS_VERSION) + "."; |
There was a problem hiding this comment.
I don't believe this is correct here. The code in TE itself needs to check both compile and runtime versions of cublas. The test should not care at all about the compilation version (since it doesn't actually use any API from cublas) and instead should check against the runtime version.
| std::vector<size_t>{M, N}, | ||
| DType::kBFloat16)); | ||
| s.D_multi.emplace_back(Tensor("D_multi" + std::to_string(i), | ||
| std::vector<size_t>{M, N}, DType::kBFloat16)); |
There was a problem hiding this comment.
Do we only support BF16 output?
| AlphaBetaTensors ab = make_alpha_beta(num_gemms); | ||
|
|
||
| constexpr size_t cublas_ws_bytes = 32ull * 1024 * 1024; | ||
| const size_t setup_ws_bytes = nvte_get_grouped_gemm_setup_workspace_size(num_gemms); |
There was a problem hiding this comment.
TBH not a fan of the name of this function, but I guess that ship has sailed already.
| if (auto reason = grouped_gemm_skip_reason(params.input_case); !reason.empty()) { | ||
| GTEST_SKIP() << reason; | ||
| } | ||
| #if CUBLAS_VERSION >= 130300 |
There was a problem hiding this comment.
Why do we need this guard? We are not using any cublas API here so compilation is OK and we already skipped the test at this point anyway if the cublas version is too low.
| kFP8Current, | ||
| kBF16, | ||
| kMXFP8, | ||
| kNVFP4, |
There was a problem hiding this comment.
BTW, why do we have another enum here that is basically the same as the scaling mode?
| NVTE_CHECK(last_dims[i] % 32 == 0, | ||
| "MXFP8 grouped GEMM test: last_dim must be divisible by 32, got ", | ||
| last_dims[i]); |
There was a problem hiding this comment.
We should be able to have 16 here, but probably not yet, see my PR #2894.
| ws.d_cols = reinterpret_cast<int *>(setup_ws_ptr + offset); | ||
|
|
||
| // 8 pointer arrays (each 16-byte aligned), then 6 int arrays, then 1 float array. | ||
| align_ptr(); |
There was a problem hiding this comment.
This is wrong if the base is not aligned already (since it only takes the offset into account and that at the beginning is always 0). Fixing this then produces an issue where the returned workspace size is not actually between the get workspace size call and the actual execution, since in the get workspace size call the base is set to nullptr, so always aligned. This call should therefore assume the worst-case alignment requirement when calculating the workspace size.
| NVTE_CHECK(sm >= 100, api_name, " requires Blackwell (SM100) or newer architecture."); | ||
| NVTE_CHECK(cublas_ver >= CUBLAS_GROUPED_GEMM_VERSION, api_name, | ||
| " requires cuBLAS 13.3+, but run-time cuBLAS version is ", cublas_ver); |
There was a problem hiding this comment.
This seems wrong - what about the situation where we compiled against the cublas version that is not enough for any grouped gemm support (and so some stuff is not compiled I think?) but then run it on a system with newer cublas? This would pass, but the functionality would still not be there.
Description
Adds Hopper (SM90) support to cuBLAS grouped GEMM and enables NVFP4 / FP8 block scaling recipes.
Type of change
Checklist