[cuda backend] store scale/zero in int4_plain_mm in [N, n_groups] layout#20038

Open
Gasoonjia wants to merge 1 commit into main from g4-opt-coalesced-scale

Conversation

@Gasoonjia

@Gasoonjia Gasoonjia commented Jun 4, 2026

Contributor

This PR updates int4_plain_mm in the cuda backend to read scale/zero in the transposed [N, n_groups] layout instead of [n_groups, N]. This way, every warp can load both scale and zero together from one cache line, instead of from 32 cache lines as before.

gemma4-31b decode perf: ~27 token/s -> 37.36 token/s.
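
To make the cache-line claim concrete, here is a small model of the access pattern, not the kernel itself. It assumes one warp works on a single output row, each lane fetches the scale for its own 32-weight chunk, scale elements are 16 bits wide, and a cache line is 128 bytes. The shapes (N = K = 4096) and the helper name `cache_lines_touched` are hypothetical.

```python
# Illustrative model only: counts the cache lines one warp touches when
# reading per-group scales for a single output row n.
CACHE_LINE_BYTES = 128   # assumption: 128-byte cache line
ELEM_BYTES = 2           # assumption: 16-bit scale/zero elements
WARP_SIZE = 32
WEIGHTS_PER_LANE = 32    # one uint4 load holds 32 int4 weights


def cache_lines_touched(n, N, n_groups, group_size, transposed):
    """Distinct cache lines hit when 32 lanes each fetch the scale for
    their own 32-weight chunk of row n."""
    lines = set()
    for lane in range(WARP_SIZE):
        k = lane * WEIGHTS_PER_LANE          # first weight this lane handles
        g = k // group_size                  # quantization group of that chunk
        if transposed:
            idx = n * n_groups + g           # [N, n_groups]: groups are contiguous
        else:
            idx = g * N + n                  # [n_groups, N]: groups are N apart
        lines.add(idx * ELEM_BYTES // CACHE_LINE_BYTES)
    return len(lines)


# Hypothetical shapes: N = K = 4096, group_size = 32 -> n_groups = 128.
old = cache_lines_touched(n=5, N=4096, n_groups=128, group_size=32, transposed=False)
new = cache_lines_touched(n=5, N=4096, n_groups=128, group_size=32, transposed=True)
print(old, new)  # prints: 32 1
```

Under these assumptions the old layout spreads the 32 reads over 32 lines because consecutive groups are N elements apart, while the transposed layout packs them into 64 contiguous bytes.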

cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani

@pytorch-bot

pytorch-bot Bot commented Jun 4, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20038

Note: Links to docs will display an error until the docs builds have been completed.

❌ 4 New Failures, 1 Pending

As of commit 8e404c7 with merge base a79f3e4 (image):

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed label Jun 4, 2026
@Gasoonjia Gasoonjia changed the title from "G4 opt coalesced scale" to "[cuda backend] store scale/zero in [N, n_groups] to reduce cache read cost." Jun 4, 2026
@Gasoonjia Gasoonjia changed the title from "[cuda backend] store scale/zero in [N, n_groups] to reduce cache read cost." to "[cuda backend] store scale/zero in int4_plain_mm in [N, n_groups] layout" Jun 4, 2026
@Gasoonjia Gasoonjia marked this pull request as ready for review June 4, 2026 17:58
@mergennachin mergennachin requested a review from digantdesai June 4, 2026 18:06
Comment thread backends/cuda/runtime/shims/int4_plain_mm.cuh Outdated
Comment thread backends/cuda/runtime/shims/int4_plain_mm.cuh Outdated
Comment thread backends/cuda/runtime/shims/int4_plain_mm.cuh Outdated
Comment thread backends/cuda/runtime/shims/int4_plain_mm.cuh Outdated
Comment on lines +102 to +103
// Reads scale/zero in the transposed [N, n_groups] layout (transposed AOT at
// export time). With group_size >= 32, one uint4 (32 weights) maps to exactly

@mergennachin mergennachin Jun 5, 2026

Contributor

If this is the new contract, can you add runtime validation here for the new packing format and reject old formats?

Please also add a unit test showing that it successfully rejects the old format.
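
A minimal sketch of what such a check could look like, written in Python for readability even though the real validation would live in the CUDA runtime shim. The function name `check_coalesced_layout` and its parameters are hypothetical and are not taken from the PR.

```python
def check_coalesced_layout(scale_shape, zero_shape, N, K, group_size):
    """Raise ValueError unless scale/zero use the [N, n_groups] layout."""
    if K % group_size != 0:
        raise ValueError(f"K={K} is not a multiple of group_size={group_size}")
    n_groups = K // group_size
    expected = (N, n_groups)
    legacy = (n_groups, N)
    for name, shape in (("scale", scale_shape), ("zero", zero_shape)):
        shape = tuple(shape)
        if shape == expected:
            continue
        # Give a more specific message when the shape matches the old layout.
        hint = " (looks like the old [n_groups, N] layout)" if shape == legacy else ""
        raise ValueError(f"{name} has shape {shape}, expected {expected}{hint}")


# New layout passes silently; the old layout is rejected.
check_coalesced_layout((4096, 128), (4096, 128), N=4096, K=4096, group_size=32)
try:
    check_coalesced_layout((128, 4096), (128, 4096), N=4096, K=4096, group_size=32)
    rejected_old = False
except ValueError:
    rejected_old = True
```

Note that a shape check alone cannot tell the two layouts apart when N equals n_groups, so it catches most stale weights but is not a complete guarantee.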

Comment on lines +43 to +52
w = Int4Tensor(
    qdata=w.qdata,
    scale=w.scale.t().contiguous(),
    zero_point=w.zero_point.t().contiguous(),
    block_size=w.block_size,
    shape=w.shape,
    act_pre_scale=w.act_pre_scale,
    activation_dtype=w.activation_dtype,
)
module.weight = nn.Parameter(w, requires_grad=False)

@mergennachin mergennachin Jun 5, 2026

Contributor

The new pack-time/AOT layout is much better than the runtime transpose cache, but I don’t think we should represent it as a plain Int4Tensor anymore. Int4Tensor has a fixed strong contract already.

For example, if an Int4Tensor in the new format is created but is somehow not dispatched to int4_plain_mm and instead goes through the regular torchao kernel, this will be an issue. Likewise, int4_dispatch.py globally overrides F.linear for Int4Tensor, so any native torchao Int4Tensor that bypasses pack_cuda.py can be interpreted as newly packed.

Instead of changing the fixed contract of Int4Tensor, I'd rather create a new tensor subclass in ET TorchAOBaseTensor and change int4_dispatch accordingly

wdyt @metascroy @digantdesai
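
The concern above can be illustrated with a toy model in which the layout is carried by the type rather than inferred from the data. This is only a sketch in plain Python: `PlainInt4`, `CoalescedInt4`, and `pick_kernel` are hypothetical stand-ins, not torchao or ExecuTorch classes, and nested lists stand in for tensors.

```python
class PlainInt4:
    """Stand-in for a tensor whose scale is stored as [n_groups, N]."""

    def __init__(self, scale):
        self.scale = scale


class CoalescedInt4:
    """Stand-in for a separate type whose scale is stored as [N, n_groups]."""

    def __init__(self, scale):
        self.scale = scale

    @classmethod
    def from_plain(cls, w):
        # The transpose happens exactly once, at pack time.
        return cls([list(col) for col in zip(*w.scale)])


def pick_kernel(w):
    # Dispatch on the exact type so the layout is never guessed from shape.
    if type(w) is CoalescedInt4:
        return "coalesced"
    if type(w) is PlainInt4:
        return "plain"
    raise TypeError(f"unsupported weight type: {type(w).__name__}")


plain = PlainInt4([[1, 2, 3], [4, 5, 6]])   # n_groups = 2, N = 3
packed = CoalescedInt4.from_plain(plain)    # N = 3, n_groups = 2
```

Because the two classes are unrelated, a weight that skipped packing keeps its original type and can never be routed to the kernel that expects the transposed layout.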

@mergennachin mergennachin requested a review from metascroy June 5, 2026 15:02
@Gasoonjia Gasoonjia force-pushed the g4-int8-decode-op branch from 1a527d2 to 20d021d Compare June 5, 2026 20:00
@linux-foundation-easycla

linux-foundation-easycla Bot commented Jun 8, 2026

CLA Signed
The committers listed above are authorized under a signed CLA.

@Gasoonjia Gasoonjia force-pushed the g4-opt-coalesced-scale branch from c9bc3fb to df47b27 Compare June 8, 2026 07:44
@Gasoonjia

Contributor Author

Thanks @mergennachin for your comment. I have introduced a new int4 class, CudaCoalescedInt4Tensor, living in the cuda backend and guarded by tests for mis-dispatch; the PR also updates int4_dispatch.py and pack_cuda.py to support the new class. We can later move the tensor into executorch.extension.llm if mlx and other backends need it in the future.

@Gasoonjia Gasoonjia force-pushed the g4-opt-coalesced-scale branch 2 times, most recently from d3632f0 to 6ac4974 Compare June 8, 2026 08:42
@Gasoonjia

Contributor Author

Also added a runtime check for the layout format.

Base automatically changed from g4-int8-decode-op to main June 9, 2026 04:43
…decode

Coalesce int4 W4A8 decode-matvec scale/zero loads by baking the
[N, n_groups] layout into the weight constant at pack time. Introduces
CudaCoalescedInt4Tensor (an ExecuTorch-internal subclass) that owns the
[n_groups, N] -> [N, n_groups] transpose, registers the int4_plain_mm
dispatch on it by type, and adds the coalesced dp4a matvec kernel that
reads scale/zero row-for-row with qdata (single coalesced load vs 32
stride-N cache lines). ~29.2 -> 37.4 tok/s on gemma group_size=32.

Rebased onto main; INT8 dp4a decode op and the floor_div pass from this
branch landed separately and now live in quantize_op_dispatch/.
@Gasoonjia Gasoonjia force-pushed the g4-opt-coalesced-scale branch from 6ac4974 to 8e404c7 Compare June 9, 2026 06:41
@github-actions

github-actions Bot commented Jun 9, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Labels

ciflow/trunk, CLA Signed, module: arm

3 participants