[cuda backend] store scale/zero in int4_plain_mm in [N, n_groups] layout #20038
Gasoonjia wants to merge 1 commit into
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20038
❌ 4 New Failures, 1 Pending as of commit 8e404c7 with merge base a79f3e4.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
// Reads scale/zero in the transposed [N, n_groups] layout (transposed AOT at
// export time). With group_size >= 32, one uint4 (32 weights) maps to exactly
If this is the new contract, can you add runtime validation here for the new packing format and reject old formats?
Please also add a unit test showing that it successfully rejects an old format.
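A minimal sketch of what such a check could look like, written as plain Python on shapes only. The function name `validate_scale_zero_layout` and its parameters are hypothetical and are not taken from the PR; the real check would live next to the kernel and operate on the actual tensors.

```python
def validate_scale_zero_layout(scale_shape, zero_shape, n, k, group_size):
    """Raise ValueError unless scale/zero use the [N, n_groups] layout.

    Hypothetical helper: n is the number of output rows, k the reduction
    dimension, and n_groups = k // group_size.
    """
    if k % group_size != 0:
        raise ValueError(f"K={k} is not divisible by group_size={group_size}")
    n_groups = k // group_size
    expected = (n, n_groups)   # new, transposed layout
    legacy = (n_groups, n)     # old layout that must be rejected
    for name, shape in (("scale", tuple(scale_shape)), ("zero", tuple(zero_shape))):
        if shape == expected:
            continue
        if shape == legacy:
            raise ValueError(
                f"{name} has legacy [n_groups, N] shape {shape}; "
                f"expected [N, n_groups] = {expected}"
            )
        raise ValueError(f"{name} has unexpected shape {shape}; expected {expected}")
    return n_groups
```

A unit test for the rejection path could then pass `(n_groups, N)` shapes and assert that `ValueError` is raised. Note that a shape-only check cannot tell the two layouts apart when `N == n_groups`, which is one argument for carrying the layout in the tensor type instead.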
w = Int4Tensor(
    qdata=w.qdata,
    scale=w.scale.t().contiguous(),
    zero_point=w.zero_point.t().contiguous(),
    block_size=w.block_size,
    shape=w.shape,
    act_pre_scale=w.act_pre_scale,
    activation_dtype=w.activation_dtype,
)
module.weight = nn.Parameter(w, requires_grad=False)
The new pack-time/AOT layout is much better than the runtime transpose cache, but I don't think we should represent it as a plain Int4Tensor anymore. Int4Tensor already has a fixed, strong contract.
For example, if an Int4Tensor in the new format is created but somehow is not dispatched to int4_plain_mm and instead goes through the regular torchao kernel, its scale/zero will be read in the wrong layout. The reverse can also happen: int4_dispatch.py globally overrides F.linear for Int4Tensor, so any native torchao Int4Tensor that bypasses pack_cuda.py can be interpreted as newly packed.
Instead of changing the fixed contract of Int4Tensor, I'd rather create a new tensor subclass in ET based on TorchAOBaseTensor and change int4_dispatch accordingly.
wdyt @metascroy @digantdesai
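To illustrate the concern, here is a toy sketch of dispatching on the tensor type rather than on an implicit layout convention. Everything in it is hypothetical: `PlainInt4` and `CoalescedInt4` are stand-in dataclasses, not the torchao or ExecuTorch classes, and the handlers only return a label for the kernel they would call.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PlainInt4:
    """Stand-in for a tensor whose scale/zero are [n_groups, N]."""
    scale_shape: tuple


@dataclass(frozen=True)
class CoalescedInt4:
    """Stand-in for a tensor whose scale/zero are [N, n_groups]."""
    scale_shape: tuple


def _linear_plain(w):
    # Would call a kernel that expects the [n_groups, N] layout.
    return "generic_int4_kernel"


def _linear_coalesced(w):
    # Would call a kernel that expects the [N, n_groups] layout.
    return "int4_plain_mm"


# The layout is implied by the type, so a tensor can never reach a kernel
# that reads its scale/zero with the wrong strides.
_DISPATCH = {PlainInt4: _linear_plain, CoalescedInt4: _linear_coalesced}


def linear(w):
    try:
        handler = _DISPATCH[type(w)]
    except KeyError:
        raise TypeError(f"no int4 linear registered for {type(w).__name__}") from None
    return handler(w)
```

With separate types, a weight that skips the packing step keeps its original type and its original kernel, instead of being silently reinterpreted.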
1a527d2 to 20d021d
c9bc3fb to df47b27
Thanks @mergennachin for your comment. I have introduced a new int4 class.
d80b1b6 to 99f20f8
d3632f0 to 6ac4974
Also added a runtime check for the layout format.
…decode

Coalesce int4 W4A8 decode-matvec scale/zero loads by baking the [N, n_groups] layout into the weight constant at pack time.

Introduces CudaCoalescedInt4Tensor (an ExecuTorch-internal subclass) that owns the [n_groups, N] -> [N, n_groups] transpose, registers the int4_plain_mm dispatch on it by type, and adds the coalesced dp4a matvec kernel that reads scale/zero row-for-row with qdata (a single coalesced load vs 32 stride-N cache lines).

~29.2 -> 37.4 tok/s on gemma group_size=32.

Rebased onto main; the INT8 dp4a decode op and the floor_div pass from this branch landed separately and now live in quantize_op_dispatch/.
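The [n_groups, N] -> [N, n_groups] transpose mentioned above can be sketched on a flat buffer to make the index mapping explicit. This is an illustration in plain Python with a hypothetical function name, not the code from the commit, which does the transpose on tensors at pack time.

```python
def transpose_scale_to_row_major(flat, n_groups, n):
    """Reorder a flat [n_groups, N] buffer into a flat [N, n_groups] buffer.

    The element for (group g, row r) moves from offset g * n + r
    to offset r * n_groups + g.
    """
    if len(flat) != n_groups * n:
        raise ValueError("buffer length does not match n_groups * n")
    out = [None] * (n * n_groups)
    for g in range(n_groups):
        for r in range(n):
            out[r * n_groups + g] = flat[g * n + r]
    return out


# Two groups, three rows; values encode (group, row) as 10 * group + row.
old = [0, 1, 2, 10, 11, 12]
new = transpose_scale_to_row_major(old, n_groups=2, n=3)
print(new)  # [0, 10, 1, 11, 2, 12]
```

After the reorder, all groups belonging to one row sit next to each other, which matches how qdata for that row is laid out.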
6ac4974 to 8e404c7
This PR updates int4_plain_mm in the CUDA backend to read scale/zero in the transposed [N, n_groups] layout instead of [n_groups, N]. This way every warp can load both scale and zero together in one cache line, instead of the 32 cache lines needed previously.
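The cache-line argument can be checked with a small back-of-the-envelope model. The sketch below rests on assumptions that are not stated in the PR: each of the 32 lanes in a warp reads the scale for a different consecutive group of the same output row, elements are 2 bytes wide, and a cache line is 128 bytes. The function names are hypothetical.

```python
def warp_offsets(layout, row, g0, n, n_groups, lanes=32):
    """Element offsets read by one warp: lane i reads group g0 + i of `row`."""
    if layout == "groups_major":   # old [n_groups, N] layout
        return [(g0 + lane) * n + row for lane in range(lanes)]
    if layout == "row_major":      # new [N, n_groups] layout
        return [row * n_groups + (g0 + lane) for lane in range(lanes)]
    raise ValueError(f"unknown layout: {layout}")


def cache_lines_touched(offsets, elem_bytes=2, line_bytes=128):
    """Count distinct cache lines covered by the given element offsets."""
    return len({(off * elem_bytes) // line_bytes for off in offsets})


# Assumed example sizes: N = 4096 rows, K = 4096, group_size = 32.
N, N_GROUPS = 4096, 4096 // 32
old_lines = cache_lines_touched(warp_offsets("groups_major", 5, 0, N, N_GROUPS))
new_lines = cache_lines_touched(warp_offsets("row_major", 5, 0, N, N_GROUPS))
print(old_lines, new_lines)  # 32 1
```

In the old layout consecutive lanes are N elements apart, so every lane lands on its own cache line; in the new layout the 32 values are contiguous and fit in one line under these assumptions.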
gemma4-31b decode perf: ~27 token/s -> 37.36 token/s.
cc @digantdesai @freddan80 @per @zingo @oscarandersson8218 @mansnils @Sebastian-Larsson @robell @rascani