🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20004
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure, 26 Pending, 2 Unrelated Failures
As of commit 2743bd6 with merge base 12684ef:
NEW FAILURE - The following job has failed:
FLAKY - The following jobs failed but were likely due to flakiness present on trunk:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
This PR needs a
    path: str,
) -> Iterator[tuple[str, torch.Tensor]]:
    """Yield ``(name, result)`` for each tensor in a GGUF file.
    q6k_raw: bool = False,
nit: instead of a bool, maybe accept a set of GGMLQuantizationType enum values to keep in raw form, so that it can be extended to other types like Q5_K.
I'll do this as part of promotion work.
I think I'll probably make the loader use et::gguf_linear based on raw weights, and then have a configurable pass that converts et::gguf_linear to other formats.
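A minimal sketch of the reviewer's suggestion: the loader takes a set of quantization types to keep as raw bytes instead of a single q6k_raw flag. QuantType is a hypothetical stand-in for gguf's GGMLQuantizationType, and iter_tensors and _dequantize_stub are placeholders, not the real GGUF reader.

```python
from enum import Enum, auto
from typing import Iterable, Iterator, List, Tuple, Union


class QuantType(Enum):
    """Hypothetical stand-in for gguf's GGMLQuantizationType."""
    F16 = auto()
    Q4_K = auto()
    Q5_K = auto()
    Q6_K = auto()


Record = Tuple[str, QuantType, bytes]
Value = Union[bytes, List[float]]


def _dequantize_stub(qtype: QuantType, payload: bytes) -> List[float]:
    # Placeholder: a real loader would decode the block layout of ``qtype`` here.
    return [float(b) for b in payload]


def iter_tensors(
    records: Iterable[Record], keep_raw: Iterable[QuantType] = ()
) -> Iterator[Tuple[str, QuantType, Value]]:
    """Yield (name, qtype, value); types listed in ``keep_raw`` stay as raw bytes."""
    keep = frozenset(keep_raw)
    for name, qtype, payload in records:
        if qtype in keep:
            yield name, qtype, payload  # untouched GGUF block bytes for a fused kernel
        else:
            yield name, qtype, _dequantize_stub(qtype, payload)
```

Passing {QuantType.Q6_K} behaves like the old q6k_raw=True, and supporting Q5_K later only means adding it to the set.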
    byte layout so it can be consumed directly by the fused
    ``mlx::gguf_linear`` Metal kernel.
Not necessarily Metal. Also, CUDA, right?
I'll handle this in the promotion work mentioned below.
    return d, scales_16, qvals, _make_q6_k_block(d, scales_16, qvals)


def test_raw_blob_preserves_bytes(self):
    from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k
_raw_q6_k is an internal implementation detail. Test public API contracts (e.g., iter_gguf_tensors) as opposed to relying on internal private functions.
if args.gguf and args.backend == "mlx":
    os.environ["ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS"] = "1"
I think gemma4 is the only place, other than the generic support for ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS, which we can leave?
from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import Q6K_BLOCK_BYTES, QK_K


class GGUFLinear(nn.Module):
Should this be promoted somewhere other than examples/models/*?
I could promote it to extension/llm/export (next to nvfp4) and define ET eager ops for et::gguf_linear and et::gguf_embedding there. I would delete the MLX-specific ones, such as mlx::gguf_linear, and have the MLX backend focus on emitting the new et::gguf_linear ops.
    Wrap raw GGUF ``block_q6_K`` blobs and dispatch to the fused ``mlx::gguf_linear``
    (matmul) and ``mlx::gguf_embedding`` (gather) Metal kernels, instead of the slow
    non-fused dequantize paths that group_size=16 affine quant takes through the MLX
    ``QUANTIZED_LINEAR`` / quantized-embedding patterns.
Does it have to be Metal specific?
I can move them to the et namespace as part of the promotion work. See comment above.
    Compute is keyed on the activation dtype (matching GGUF/llama.cpp): the Metal
    kernels are templated on ``InT``, accumulate in ``float32``, read ``d`` as
    ``half``, and produce output in the activation dtype.
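The numerics this docstring describes (read d as half, accumulate in float32, return the activation dtype) can be written out as a slow NumPy reference. This is a sketch, not the PR's dequantize_q6_k: it assumes llama.cpp's block_q6_K layout (ql[128] low nibbles, qh[64] high 2-bit pairs, 16 int8 scales, fp16 d; 210 bytes per 256 weights) and a little-endian host, and both function names are hypothetical.

```python
import numpy as np

QK_K = 256              # weights per Q6_K super-block
Q6K_BLOCK_BYTES = 210   # ql[128] + qh[64] + scales[16] + d (2 bytes)


def dequantize_q6_k_block(block: np.ndarray) -> np.ndarray:
    """Dequantize one 210-byte uint8 block_q6_K into 256 float32 weights."""
    ql = block[0:128].astype(np.int32)                    # low 4 bits, two quants per byte
    qh = block[128:192].astype(np.int32)                  # high 2 bits, four quants per byte
    sc = block[192:208].view(np.int8).astype(np.float32)  # one int8 scale per 16 weights
    d = np.float32(block[208:210].view(np.float16)[0])    # super-block scale stored as half
    q = np.empty(QK_K, dtype=np.int32)
    for h in range(2):  # two halves of 128 weights each
        lo, hi, o = ql[64 * h : 64 * h + 64], qh[32 * h : 32 * h + 32], 128 * h
        q[o : o + 32] = (lo[:32] & 0xF) | (((hi >> 0) & 3) << 4)
        q[o + 32 : o + 64] = (lo[32:] & 0xF) | (((hi >> 2) & 3) << 4)
        q[o + 64 : o + 96] = (lo[:32] >> 4) | (((hi >> 4) & 3) << 4)
        q[o + 96 : o + 128] = (lo[32:] >> 4) | (((hi >> 6) & 3) << 4)
    # Weight j uses scale j // 16; 6-bit quants are stored with an offset of 32.
    return d * np.repeat(sc, 16) * (q - 32).astype(np.float32)


def gguf_linear_reference(x: np.ndarray, blob: np.ndarray) -> np.ndarray:
    """out = x @ dequant(blob)^T for x (M, K) and blob (N, K // 256 * 210) uint8."""
    n, row_bytes = blob.shape
    blocks = blob.reshape(n, row_bytes // Q6K_BLOCK_BYTES, Q6K_BLOCK_BYTES)
    w = np.stack([np.concatenate([dequantize_q6_k_block(b) for b in row]) for row in blocks])
    acc = x.astype(np.float32) @ w.T   # accumulate in float32 regardless of x.dtype
    return acc.astype(x.dtype)         # hand back the activation dtype
```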
Either here or elsewhere, should we give credit/attribution to the upstream libraries?
Credit is given in code comments ("ported from llama.cpp"), but I'll put credit in one place at the top of the file.
@claude review this code
Claude finished @metascroy's task in 2m 43s
Code Review: [MLX][Gemma4] Q6K Kernels
Overall this is a well-structured, cleanly organized PR. The architecture -- format routers dispatching to per-format implementations, with shared primitives kept lightweight -- is extensible and easy to navigate. The kernel porting from llama.cpp is well-attributed, and the runtime
Architecture & Organization
The new
The rename from
Issues & Suggestions
1. Empty file: This appears to be a leftover/placeholder. If it's not imported anywhere, it should probably be removed to avoid confusion with the
2. Potential integer overflow in matvec grid calculation (
    n_r0 = 2
    nsg = _Q6K_MV_NSG  # 4
    num_row_groups = (N + nsg * n_r0 - 1) // (nsg * n_r0)
    grid_x = num_row_groups * 32 * nsg
For very large N (e.g., vocab 262144 for lm_head),
3. Mat-mat kernel:
    threadgroup half sa[4096];  // NR0 * NK storage (strided by 64)
    threadgroup half sb[4096];  // NR1 * NK storage (strided by 64)
These are 4096 * 2 = 8KB each (16KB total threadgroup memory per threadgroup). Given Apple GPUs typically have 32KB threadgroup memory, this uses 50% -- comfortable but worth documenting if future tile-size tuning is planned.
4. Mat-mat kernel: M computation uses raw
    int M = 1;
    for (uint d = 0; d + 1 < x_ndim; ++d) { M *= (int) x_shape[d]; }
This relies on
5. The choice of
6. Serialization string dedup: relies on FlatBuffers'
    create = getattr(builder, "CreateSharedString", None) or builder.CreateString
Good defensive fallback for older flatbuffers versions. The test in
7.
    auto& slot = map[static_cast<const void*>(s)];
This is correct: identical strings sharing a single FlatBuffer offset will have the same pointer, so dedup works for newly-serialized files. For old files (without dedup), each string instance gets its own key/entry -- correct, just not deduplicated.
8. Reviewer feedback alignment: promotion to
The new
9. Test coverage is thorough
One observation:
10. Minor: The eager body calls
Nits (non-blocking)
Summary
Strong PR. The kernel port from llama.cpp is well-attributed, the architecture is extensible (adding Q4_K or Q8_0 formats requires no changes to the router or schema), and the runtime IfNode for dynamic M is clean. The serialization dedup is a nice size win (62% on Gemma 4 31B metadata) with full backward compat. The test matrix is comprehensive and exercises both kernel correctness and the infrastructure (IfNode branching, string dedup, dynamic shapes). Main action items from this review:
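The two halves of the string dedup discussed in the review (points 6 and 7) can be modeled in a few lines of Python. Everything here is a toy: ToyBuilder, SharedStringBuilder and StringPool are hypothetical names, and a real FlatBuffers builder lays bytes out differently. It only shows the idea: the writer emits each distinct string once and returns the same offset, and the loader keeps one object per offset.

```python
from typing import Dict, List


class ToyBuilder:
    """Hypothetical stand-in for a FlatBuffers builder without CreateSharedString."""

    def __init__(self) -> None:
        self.buf = bytearray()

    def CreateString(self, s: str) -> int:
        offset = len(self.buf)
        data = s.encode("utf-8")
        self.buf += len(data).to_bytes(4, "little") + data  # uint32 length + bytes (simplified)
        return offset


class SharedStringBuilder(ToyBuilder):
    """Adds CreateSharedString: each distinct string is written once."""

    def __init__(self) -> None:
        super().__init__()
        self._offsets: Dict[str, int] = {}

    def CreateSharedString(self, s: str) -> int:
        if s not in self._offsets:
            self._offsets[s] = self.CreateString(s)
        return self._offsets[s]


def write_strings(builder: ToyBuilder, strings: List[str]) -> List[int]:
    # Same fallback as the snippet in the review: prefer shared strings when available.
    create = getattr(builder, "CreateSharedString", None) or builder.CreateString
    return [create(s) for s in strings]


class StringPool:
    """Loader side: decode each offset once and hand out the same object."""

    def __init__(self, buf: bytes) -> None:
        self.buf = buf
        self._cache: Dict[int, str] = {}

    def get(self, offset: int) -> str:
        if offset not in self._cache:
            n = int.from_bytes(self.buf[offset : offset + 4], "little")
            self._cache[offset] = self.buf[offset + 4 : offset + 4 + n].decode("utf-8")
        return self._cache[offset]
```

Keying the loader cache by offset is the toy analogue of keying by the FlatBuffer string pointer: shared strings resolve to one entry, while a file written without sharing still loads, with one entry per copy.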
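A toy interpreter showing how an IfNode can select the decode (mat-vec) or prefill (mat-mat) chain when M is only known at run time. The class, the env dictionary and the step callables are hypothetical; the real node is defined in schema.fbs and executed by exec_if in MLXInterpreter.h. The M product and the mat-vec grid formula follow the snippets quoted in the review above.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Union

Env = Dict[str, Any]
Instr = Callable[[Env], None]


@dataclass
class IfNode:
    """Toy runtime conditional: run then_chain if env[cond_slot] != 0, else else_chain."""
    cond_slot: str
    then_chain: List["Step"]
    else_chain: List["Step"]


Step = Union[IfNode, Instr]


def run(chain: List[Step], env: Env) -> None:
    for step in chain:
        if isinstance(step, IfNode):
            branch = step.then_chain if env[step.cond_slot] != 0 else step.else_chain
            run(branch, env)
        else:
            step(env)


N_R0, NSG = 2, 4  # values quoted in the review (n_r0, _Q6K_MV_NSG)


def compute_m(env: Env) -> None:
    # M is the product of every dim of x except the last (K).
    env["M"] = math.prod(env["x_shape"][:-1])
    env["is_decode"] = int(env["M"] == 1)


def launch_matvec(env: Env) -> None:
    num_row_groups = (env["N"] + NSG * N_R0 - 1) // (NSG * N_R0)
    env["kernel"], env["grid_x"] = "matvec", num_row_groups * 32 * NSG


def launch_matmat(env: Env) -> None:
    env["kernel"] = "matmat"


program: List[Step] = [compute_m, IfNode("is_decode", [launch_matvec], [launch_matmat])]
```

With the review's numbers (N = 262144, nsg = 4, n_r0 = 2) the sketch computes grid_x = 32768 * 32 * 4 = 4194304 for a single-token input.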
@mergennachin as requested, I promoted GGUF handling out of examples. It now lives in a backend-agnostic extension/llm/export/gguf.py, mirroring nvfp4.py. This defines ExportableGGUFTensor (a TorchAOBaseTensor), the single loading representation: it wraps the raw GGUF bytes and desugars linear/embedding to a torchao::dequantize_gguf op so that it survives export and backends can pattern-match/fuse it. ExportableGGUFTensor also has conversion methods to Int4Tensor / IntxUnpackedToInt8Tensor. I also defined ExportableInt4Tensor and used it instead of IntxUnpackedToInt8Tensor in the MLX Int4 packer path. MLX has new pattern matchers for ExportableGGUFTensor and ExportableInt4Tensor. As a result, the prequantized MLX lowering paths (SocialLocal Int4/Int8 and Q4KM GGUF) are more memory efficient and faster. On my computer, .pte lowering time for Gemma4-31B is 3 minutes on both paths.
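Keeping dequantization as an explicit op in the exported graph is what lets a backend recognize it and substitute a fused kernel. A toy version of that rewrite over a hypothetical list-of-nodes IR (neither torch.fx nor the MLX pattern matcher) might look like this; the op names are plain strings chosen to echo torchao::dequantize_gguf and et::gguf_linear.

```python
from collections import defaultdict
from dataclasses import dataclass, field
from typing import List


@dataclass(eq=False)
class Node:
    name: str
    op: str                                    # e.g. "dequantize_gguf", "linear"
    inputs: List["Node"] = field(default_factory=list)


def fuse_dequant_linear(nodes: List[Node]) -> List[Node]:
    """Rewrite linear(x, dequantize_gguf(blob)) into gguf_linear(x, blob)."""
    users = defaultdict(list)
    for n in nodes:
        for inp in n.inputs:
            users[id(inp)].append(n)
    removed = set()
    for n in nodes:
        if n.op != "linear" or len(n.inputs) < 2:
            continue
        w = n.inputs[1]
        # Fuse only when the dequantized weight has no other consumer.
        if w.op == "dequantize_gguf" and len(users[id(w)]) == 1:
            n.op = "gguf_linear"
            n.inputs = [n.inputs[0], w.inputs[0], *n.inputs[2:]]
            removed.add(id(w))
    return [n for n in nodes if id(n) not in removed]
```

Requiring a single consumer keeps the rewrite safe: a dequantized weight that also feeds another op still has to exist as a dense tensor.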
mergennachin left a comment
Great work, @metascroy
Especially the memory footprint during export and standardization across the board.
A few comments (and inline comments)
- Can you update the examples/models/gemma4_31b/model.md and README.md files? There might be some stale comments.
- In general, I was planning to promote everything in examples/models/gemma4_31b/quant/* -- not in this PR, but I'd like to get your thoughts on this direction.
# Int4Tensor → IntxUnpackedToInt8Tensor conversion


def _int4_to_intx_unpacked(w: torch.Tensor) -> torch.Tensor:
Don't you need to update quant/tests/test_pack_mlx.py?
Updated. I also added the test to mlx.yml; it looks like it wasn't running.
def _is_embedding(model, model_key: str) -> bool:
    parent = model.get_submodule(model_key.rsplit(".", 1)[0])
    return isinstance(parent, torch.nn.Embedding)
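The _is_embedding hunk decides a weight's treatment from the module that owns it. A minimal sketch of a router built on that check, including detection of a tied lm_head (the same Parameter object as the embedding), is below. route_weights and the route labels are hypothetical, loosely following the routing the PR description lists for _handle_mlx_q6k; it assumes PyTorch 2.x for named_parameters(remove_duplicate=False).

```python
from typing import Dict

import torch
from torch import nn


def _is_embedding(model: nn.Module, model_key: str) -> bool:
    # Same check as the hunk above: look at the module that owns the weight.
    parent = model.get_submodule(model_key.rsplit(".", 1)[0])
    return isinstance(parent, nn.Embedding)


def route_weights(model: nn.Module) -> Dict[str, str]:
    """Map each '.weight' key to 'embedding', 'linear', 'tied:<key>' or 'fallback'."""
    routes: Dict[str, str] = {}
    first_owner: Dict[int, str] = {}
    # remove_duplicate=False also yields tied parameters under every name.
    for key, param in model.named_parameters(remove_duplicate=False):
        if not key.endswith(".weight"):
            continue
        if id(param) in first_owner:
            routes[key] = "tied:" + first_owner[id(param)]  # reuse the first owner's blob
            continue
        first_owner[id(param)] = key
        parent = model.get_submodule(key.rsplit(".", 1)[0])
        if _is_embedding(model, key):
            routes[key] = "embedding"
        elif isinstance(parent, nn.Linear):
            routes[key] = "linear"
        else:
            routes[key] = "fallback"
    return routes
```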
@implements([aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    return torch.nn.functional.linear(
        input_tensor, weight.dequantize(input_tensor.dtype), bias
    )


@implements([aten.embedding.default])
def _(func, types, args, kwargs):
    weight, indices = args[0], args[1]
    return torch.nn.functional.embedding(indices, weight.dequantize())


@implements([aten.t.default])
def _(func, types, args, kwargs):
    return args[0].dequantize().t()


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
    return args[0]


@implements([aten._to_copy.default])
def _(func, types, args, kwargs):
    return args[0].dequantize(output_dtype=kwargs.get("dtype", args[0].orig_dtype))
Are these duplicated across the nvfp4/int4/gguf.py files?
They are, but this is the typical pattern for tensor subclasses.
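The per-subclass op table that the @implements hunks above populate can be modeled in plain Python. The names here are hypothetical and this is not torchao's implements machinery; it is a toy in which the table is keyed per subclass, so a second subclass would register its own copies of the same dequantize-then-run-dense fallbacks.

```python
from typing import Callable, Dict, List


class QuantTensorBase:
    """Hypothetical base class; each subclass gets its own op table."""

    _op_table: Dict[str, Callable] = {}

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls._op_table = {}  # fresh table per subclass, so entries are never shared

    @classmethod
    def implements(cls, op_names: List[str]) -> Callable:
        def decorator(fn: Callable) -> Callable:
            for name in op_names:
                cls._op_table[name] = fn
            return fn
        return decorator

    def call(self, op_name: str, *args):
        impl = type(self)._op_table.get(op_name)
        if impl is None:
            raise NotImplementedError(f"{type(self).__name__}: no impl for {op_name}")
        return impl(self, *args)


class Int8Sketch(QuantTensorBase):
    """Toy quantized weight: integer codes times one scale."""

    def __init__(self, codes: List[List[int]], scale: float):
        self.codes, self.scale = codes, scale

    def dequantize(self) -> List[List[float]]:
        return [[c * self.scale for c in row] for row in self.codes]


@Int8Sketch.implements(["linear"])
def _(weight: Int8Sketch, x: List[float]) -> List[float]:
    # Dense fallback: dequantize, then compute x @ W^T one output row at a time.
    return [sum(a * b for a, b in zip(x, row)) for row in weight.dequantize()]


@Int8Sketch.implements(["detach", "alias"])
def _(weight: Int8Sketch) -> Int8Sketch:
    return weight  # view-like ops hand back the same wrapper
```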
)


class TestGgufLinearMlx(unittest.TestCase):
Can you add a similar test for CUDA too? It seems the CUDA path for GGUF (inference.py as well as export.py) isn't tested (my bad), and this refactoring may increase the risk of breaking it.
Updated docs. On promotion of other methods in quant, I think some of them are promotable, but others look like local implementations of existing torchao methods. For example,
If by promote you mean promote fixes to torchao, then I'm on board, but generally I want to make sure we're not re-inventing what already exists.
Summary
Adds fused GGUF Q6_K custom Metal kernels to the MLX backend and wires them into the Gemma 4 31B GGUF export path, so Q6_K-quantized linear and embedding weights run directly from llama.cpp's packed block layout instead of taking the slow non-fused dequantize path. Also shrinks the exported .pte (and its in-memory footprint) by de-duplicating repeated kernel source blobs.

New custom kernel ops (backends/mlx/custom_kernel_ops/gguf/)
The gguf/ package is organized as format routers over per-format implementations, so new GGUF formats (e.g. Q4_K) can be added without touching the op definitions:
- gguf/linear.py / gguf/embedding.py: thin format routers. Each owns the op identity (mlx::gguf_linear / mlx::gguf_embedding: custom op, fake, and lowering registration) and dispatches on the format arg. Only "q6k" is supported today; other formats raise NotImplementedError.
- gguf/q6k/common.py: shared Q6_K primitives: constants, the pure-torch dequantize_q6_k reference, and the Metal header (block_q6_K struct + dequant helpers). Lightweight (no builder import), re-exported from gguf/q6k/__init__.py.
- gguf/q6k/linear.py: out = x @ dequant(weight)^T (+bias) against a raw GGUF block_q6_K blob (no repacking). Emits two Metal kernels: a fused mat-vec for decode (M==1, ported from llama.cpp kernel_mul_mv_q6_K_f32_impl) and a tiled simdgroup mat-mat for prefill (M>1). For dynamic/symbolic M, both chains are emitted and selected at runtime via a new IfNode.
- gguf/q6k/embedding.py: gather counterpart that dequantizes Q6_K rows directly.

Runtime / schema
- New IfNode in schema.fbs (runtime conditional selecting one of two instruction chains on an integer condition) plus exec_if dispatch in MLXInterpreter.h.

Serialization: smaller .pte + lower load-time RAM
- generate.py / mlx_graph_serialize.py). The big repeated MetalKernelNode source/header blobs are now written once. On Gemma 4 31B this cut the MLX graph metadata from ~1.23 MiB to ~0.47 MiB (~62%).
- std::shared_ptr<const std::string> keyed by the FlatBuffer string pointer (StringPool in MLXLoader.{h,cpp}.tmpl; MLXInterpreter.h derefs the handle), so a newly-produced .pte also uses less RAM at runtime.
- .pte files load unchanged (just without the dedup).

Gemma 4 31B GGUF loader (examples/models/gemma4_31b/)
- iter_gguf_tensors now yields the tensor's quant type and can emit Q6_K tensors as the raw (N, n_blocks*210) uint8 blob (q6k_raw); added _raw_q6_k helper and made _unpack_q6_k accept an already-materialized tensor.
- mlx_gguf_linear.py carrier modules (GGUFLinear / GGUFEmbedding) and _handle_mlx_q6k routing: Linear weights → gguf_linear, token embedding → gguf_embedding, tied lm_head reuses the embedding blob via gguf_linear, with a quantized-tensor fallback for any other Q6_K module.
- ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS env-var workaround in export.py since the fused path no longer needs it.

Refactor
- backends/mlx/model_ops/ → backends/mlx/custom_kernel_ops/ (with a test/ subpackage) and updated all imports (turboquant_cache.py, qwen3_5_moe/mlx_source_transformations.py).

Test plan
- custom_kernel_ops/gguf/test/test_linear.py, test_embedding.py; backends/mlx/test/test_serialization_dedup.py (asserts identical source/header are written once); examples/models/gemma4_31b/quant/tests/test_gguf.py and examples/models/gemma4_31b/tests/test_mlx_pipeline.py.
- .github/workflows/mlx.yml) discovers op tests recursively (custom_kernel_ops/**/test/test_*.py) so per-format subpackage tests run with no per-op CI edit.

Run locally: