[MLX][Gemma4] Introduce Q6K kernels #20004

Merged
metascroy merged 18 commits into main from mlx-q6k
Jun 8, 2026

Conversation

@metascroy

@metascroy metascroy commented Jun 4, 2026

Contributor

Summary

Adds fused GGUF Q6_K custom Metal kernels to the MLX backend and wires them into the Gemma 4 31B GGUF export path, so Q6_K-quantized linear and embedding weights run directly from llama.cpp's packed block layout instead of taking the slow non-fused dequantize path. Also shrinks the exported .pte (and its in-memory footprint) by de-duplicating repeated kernel source blobs.

New custom kernel ops (backends/mlx/custom_kernel_ops/gguf/)

The gguf/ package is organized as format routers over per-format implementations, so new GGUF formats (e.g. Q4_K) can be added without touching the op definitions:

  • gguf/linear.py / gguf/embedding.py: thin format routers — each owns the op identity (mlx::gguf_linear / mlx::gguf_embedding: custom op, fake, and lowering registration) and dispatches on the format arg. Only "q6k" is supported today; other formats raise NotImplementedError.
  • gguf/q6k/common.py: shared Q6_K primitives — constants, the pure-torch dequantize_q6_k reference, and the Metal header (block_q6_K struct + dequant helpers). Lightweight (no builder import), re-exported from gguf/q6k/__init__.py.
  • gguf/q6k/linear.py: out = x @ dequant(weight)^T (+bias) against a raw GGUF block_q6_K blob (no repacking). Emits two Metal kernels — a fused mat-vec for decode (M==1, ported from llama.cpp kernel_mul_mv_q6_K_f32_impl) and a tiled simdgroup mat-mat for prefill (M>1). For dynamic/symbolic M, both chains are emitted and selected at runtime via a new IfNode.
  • gguf/q6k/embedding.py: gather counterpart that dequantizes Q6_K rows directly.
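
To make the block layout concrete, here is a minimal pure-Python sketch of dequantizing a single Q6_K block. It is not the PR's dequantize_q6_k (which is not shown on this page); it assumes the field order ql[128], qh[64], scales[16], d (fp16) and the bit unpacking as I understand them from llama.cpp's reference dequantizer. The function name dequantize_q6_k_block is made up for this illustration.

```python
import struct

QK_K = 256             # weights per Q6_K super-block
Q6K_BLOCK_BYTES = 210  # 128 (ql) + 64 (qh) + 16 (scales) + 2 (d)


def dequantize_q6_k_block(block: bytes) -> list[float]:
    """Dequantize one 210-byte Q6_K block into 256 floats (illustrative only)."""
    if len(block) != Q6K_BLOCK_BYTES:
        raise ValueError(f"expected {Q6K_BLOCK_BYTES} bytes, got {len(block)}")
    ql = block[0:128]                                # low 4 bits of each weight
    qh = block[128:192]                              # high 2 bits of each weight
    scales = struct.unpack("<16b", block[192:208])   # one int8 scale per 16 weights
    (d,) = struct.unpack("<e", block[208:210])       # fp16 super-block scale

    out = [0.0] * QK_K
    for half in range(2):                            # two groups of 128 weights
        y, lo, hi, sc = 128 * half, 64 * half, 32 * half, 8 * half
        for l in range(32):
            i = l // 16
            h = qh[hi + l]
            # Each 6-bit value is a 4-bit nibble plus 2 high bits, biased by 32.
            q1 = ((ql[lo + l] & 0xF) | (((h >> 0) & 3) << 4)) - 32
            q2 = ((ql[lo + l + 32] & 0xF) | (((h >> 2) & 3) << 4)) - 32
            q3 = ((ql[lo + l] >> 4) | (((h >> 4) & 3) << 4)) - 32
            q4 = ((ql[lo + l + 32] >> 4) | (((h >> 6) & 3) << 4)) - 32
            out[y + l] = d * scales[sc + i] * q1
            out[y + l + 32] = d * scales[sc + i + 2] * q2
            out[y + l + 64] = d * scales[sc + i + 4] * q3
            out[y + l + 96] = d * scales[sc + i + 6] * q4
    return out
```

A fused kernel performs the same unpacking per block while accumulating the dot product, so the dequantized weights never have to be written out to memory.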

Runtime / schema

New IfNode in schema.fbs (runtime conditional selecting one of two instruction chains on an integer condition) plus exec_if dispatch in MLXInterpreter.h.

Serialization: smaller .pte + lower load-time RAM

  • Serializer de-duplicates identical strings into a single FlatBuffer offset (shared-string emission in the generated serializers / generate.py / mlx_graph_serialize.py). The big repeated MetalKernelNode source/header blobs are now written once. On Gemma 4 31B this cut the MLX graph metadata from ~1.23 MiB to ~0.47 MiB (~62%).
  • Loader interns those shared blobs into one std::shared_ptr<const std::string> keyed by the FlatBuffer string pointer (StringPool in MLXLoader.{h,cpp}.tmpl; MLXInterpreter.h derefs the handle), so a newly-produced .pte also uses less RAM at runtime.
  • Fully backward-compatible: no schema/format change. Old .pte files load unchanged (just without the dedup).
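
The write-side dedup can be pictured with a toy serializer. This is only a sketch of the idea, not the generated serializer or the FlatBuffers API: StringTable, create_string and create_shared_string are hypothetical names, and the length-prefixed layout is a simplification.

```python
class StringTable:
    """Toy string section: length-prefixed UTF-8 strings in one byte buffer."""

    def __init__(self) -> None:
        self.buf = bytearray()
        self._offsets: dict[str, int] = {}

    def create_string(self, s: str) -> int:
        """Always append a new copy and return its offset."""
        data = s.encode("utf-8")
        offset = len(self.buf)
        self.buf += len(data).to_bytes(4, "little") + data
        return offset

    def create_shared_string(self, s: str) -> int:
        """Append only on first sight; identical strings reuse one offset."""
        if s not in self._offsets:
            self._offsets[s] = self.create_string(s)
        return self._offsets[s]


# Three kernel nodes that carry the same (placeholder) Metal source text.
kernel_source = "kernel void q6k_mv(/* ... */) { /* placeholder body */ }"

shared = StringTable()
shared_offsets = [shared.create_shared_string(kernel_source) for _ in range(3)]

naive = StringTable()
naive_offsets = [naive.create_string(kernel_source) for _ in range(3)]
```

With sharing, all three nodes point at one offset and the blob is stored once; without it the buffer holds three copies. Readers are unaffected either way because they only follow offsets, which is consistent with the claim that no schema change is needed.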

Gemma 4 31B GGUF loader (examples/models/gemma4_31b/)

  • iter_gguf_tensors now yields the tensor's quant type and can emit Q6_K tensors as the raw (N, n_blocks*210) uint8 blob (q6k_raw); added _raw_q6_k helper and made _unpack_q6_k accept an already-materialized tensor.
  • New mlx_gguf_linear.py carrier modules (GGUFLinear/GGUFEmbedding) and _handle_mlx_q6k routing: Linear weights → gguf_linear, token embedding → gguf_embedding, tied lm_head reuses the embedding blob via gguf_linear, with a quantized-tensor fallback for any other Q6_K module.
  • Removed the ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS env-var workaround in export.py since the fused path no longer needs it.
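
As a quick check on the raw blob shape, the arithmetic behind (N, n_blocks*210) can be sketched as follows. The helper names are hypothetical; the only assumptions are the 256-weight super-block and 210-byte block size used above.

```python
QK_K = 256             # weights per Q6_K super-block
Q6K_BLOCK_BYTES = 210  # bytes per packed block


def q6k_raw_shape(n_rows: int, k: int) -> tuple[int, int]:
    """Shape of the uint8 blob holding an (n_rows, k) Q6_K weight."""
    if k % QK_K != 0:
        raise ValueError(f"K={k} must be a multiple of {QK_K}")
    n_blocks = k // QK_K
    return (n_rows, n_blocks * Q6K_BLOCK_BYTES)


def q6k_bits_per_weight() -> float:
    """Storage cost per weight, including the scales and the fp16 d."""
    return Q6K_BLOCK_BYTES * 8 / QK_K
```

For example, q6k_raw_shape(5376, 21504) gives 84 blocks per row, so each row occupies 84 * 210 = 17640 bytes, and the format costs 6.5625 bits per weight.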

Refactor

  • Renamed backends/mlx/model_ops/backends/mlx/custom_kernel_ops/ (with a test/ subpackage) and updated all imports (turboquant_cache.py, qwen3_5_moe/mlx_source_transformations.py).

Test plan

  • New/updated unit tests: custom_kernel_ops/gguf/test/test_linear.py, test_embedding.py; backends/mlx/test/test_serialization_dedup.py (asserts identical source/header are written once); examples/models/gemma4_31b/quant/tests/test_gguf.py and examples/models/gemma4_31b/tests/test_mlx_pipeline.py.
  • CI (.github/workflows/mlx.yml) discovers op tests recursively (custom_kernel_ops/**/test/test_*.py) so per-format subpackage tests run with no per-op CI edit.

Run locally:

# Build the op runner once (per CI):
cmake --preset mlx-release -DEXECUTORCH_BUILD_TESTS=ON -DEXECUTORCH_MLX_ENABLE_SANITIZERS=OFF
cmake --build cmake-out --target op_test_runner -j

# GPU op tests (export + run on device):
python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run -v
python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding run -v

# Pure-Python checks:
python -m pytest backends/mlx/test/test_serialization_dedup.py \
  examples/models/gemma4_31b/quant/tests/test_gguf.py \
  examples/models/gemma4_31b/tests/test_mlx_pipeline.py -v

@pytorch-bot

pytorch-bot Bot commented Jun 4, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20004

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 26 Pending, 2 Unrelated Failures

As of commit 2743bd6 with merge base 12684ef (image):

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 4, 2026
@github-actions

github-actions Bot commented Jun 4, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@mergennachin mergennachin self-requested a review June 5, 2026 19:50

@mergennachin mergennachin left a comment

Contributor

Thanks, see inline

path: str,
) -> Iterator[tuple[str, torch.Tensor]]:
"""Yield ``(name, result)`` for each tensor in a GGUF file.
q6k_raw: bool = False,

Contributor

np: instead of a bool, maybe just have a set of enum GGMLQuantizationType values that will be kept in raw so that it can be extended to other types like Q5_K

Contributor Author

I'll do this as part of promotion work.

I think I'll probably make the loader use et::gguf_linear based on raw weights, and then have a configurable pass that converts et::gguf_linear to other formats.
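
For illustration only, the reviewer's suggestion of a set of quant types in place of the q6k_raw bool might look like the sketch below. QuantType is a hypothetical stand-in for the gguf package's GGMLQuantizationType, iter_tensors is not the real iter_gguf_tensors, and the "dequantized" branch just tags the payload instead of unpacking it.

```python
from enum import Enum, auto
from typing import AbstractSet, Iterable, Iterator, Tuple


class QuantType(Enum):
    """Hypothetical stand-in for a GGML quantization-type enum."""
    Q4_K = auto()
    Q5_K = auto()
    Q6_K = auto()


Tensor = Tuple[str, QuantType, bytes]


def iter_tensors(
    tensors: Iterable[Tensor],
    keep_raw: AbstractSet[QuantType] = frozenset(),
) -> Iterator[Tuple[str, QuantType, str, bytes]]:
    """Yield (name, quant type, representation, payload) for each tensor."""
    for name, qtype, payload in tensors:
        if qtype in keep_raw:
            # Pass the packed bytes through untouched for a fused kernel.
            yield name, qtype, "raw", payload
        else:
            # A real loader would unpack/dequantize here.
            yield name, qtype, "dequantized", payload


demo = [("embed.weight", QuantType.Q6_K, b"\x00"), ("ffn.weight", QuantType.Q4_K, b"\x01")]
result = list(iter_tensors(demo, keep_raw={QuantType.Q6_K}))
```

Supporting another raw format then means adding a member to the set at the call site instead of growing one boolean per format.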

Comment on lines +174 to +175
byte layout so it can be consumed directly by the fused
``mlx::gguf_linear`` Metal kernel.

Contributor

Not necessarily Metal. Also, CUDA, right?

Contributor Author

I'll do promotion as mentioned below

return d, scales_16, qvals, _make_q6_k_block(d, scales_16, qvals)

def test_raw_blob_preserves_bytes(self):
from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k

Contributor

_raw_q6_k is an internal implementation detail. Test public API contracts (e.g., iter_gguf_tensors) as opposed to relying on internal private functions.

Comment on lines -474 to -475
if args.gguf and args.backend == "mlx":
    os.environ["ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS"] = "1"

Contributor

Contributor Author

I think gemma4 is the only place, other than the generic support for ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS, which we can leave?

from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import Q6K_BLOCK_BYTES, QK_K


class GGUFLinear(nn.Module):

Contributor

should this be promoted somewhere else other than examples/models/*?

Contributor Author

I could promote to extension/llm/export (next to nvfp4), and define ET eager ops for et::gguf_linear and et::gguf_embedding there. I would delete the mlx specific ones mlx::gguf_linear, and have MLX backend focus on emitting the new et::gguf_linear ops.

Comment on lines +9 to +12
Wrap raw GGUF ``block_q6_K`` blobs and dispatch to the fused ``mlx::gguf_linear``
(matmul) and ``mlx::gguf_embedding`` (gather) Metal kernels, instead of the slow
non-fused dequantize paths that group_size=16 affine quant takes through the MLX
``QUANTIZED_LINEAR`` / quantized-embedding patterns.

Contributor

Does it have to be Metal specific?

Contributor Author

I can make them et namespace as part of promotion work. See comment above.

Compute is keyed on the activation dtype (matching GGUF/llama.cpp): the Metal
kernels are templated on ``InT``, accumulate in ``float32``, read ``d`` as
``half``, and produce output in the activation dtype.

Contributor

Either here or give credit/attribution to libraries?

Contributor Author

Credit is given in code comments ("ported from llama.cpp"), but I'll put credit in one place at the top of the file.

@metascroy metascroy requested a review from larryliu0820 as a code owner June 7, 2026 18:59
@metascroy

Contributor Author

@claude review this code

@claude

claude Bot commented Jun 7, 2026

Claude finished @metascroy's task in 2m 43s


Code Review: [MLX][Gemma4] Q6K Kernels

  • Gather context and read changed files
  • Review new kernel ops (gguf/q6k/)
  • Review serialization changes (dedup, schema)
  • Review runtime changes (IfNode, interpreter)
  • Review Gemma 4 31B integration
  • Review tests
  • Post final review

Overall this is a well-structured, cleanly organized PR. The architecture -- format routers dispatching to per-format implementations, with shared primitives kept lightweight -- is extensible and easy to navigate. The kernel porting from llama.cpp is well-attributed, and the runtime IfNode for dynamic M is a nice solution. Here's a detailed review:


Architecture & Organization

The new custom_kernel_ops/gguf/ package layout is well-thought-out:

  • patterns.py as a single registration point that dispatches by ggml_type
  • Per-format subpackages (q6k/, q4k/) with common.py, linear.py, embedding.py
  • Lazy imports of builder-heavy modules so common.py stays lightweight

The rename from model_ops/ to custom_kernel_ops/ is a good semantic improvement.


Issues & Suggestions

1. Empty file: backends/mlx/custom_kernel_ops/gguf/q4k.py (0 bytes)

This appears to be a leftover/placeholder. If it's not imported anywhere, it should probably be removed to avoid confusion with the q4k/ directory that lives alongside it.

2. Potential integer overflow in matvec grid calculation (q6k/linear.py:321-322)

n_r0 = 2
nsg = _Q6K_MV_NSG  # 4
num_row_groups = (N + nsg * n_r0 - 1) // (nsg * n_r0)
grid_x = num_row_groups * 32 * nsg

For very large N (e.g., vocab 262144 for lm_head), grid_x = (262144 / 8) * 128 = 4,194,304. This is fine for Metal's grid limits (~1B), but worth noting the test only goes to N=16384. The production vocab (262144) will produce grid_x = 4.2M which is well within limits -- just confirming this was considered.
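
The grid arithmetic quoted above is easy to reproduce. The sketch below restates the snippet as a function; the function name and the simd_width parameter are illustrative, with the SIMD width of 32 taken from the literal in the quoted code.

```python
def q6k_matvec_grid_x(n: int, nsg: int = 4, n_r0: int = 2, simd_width: int = 32) -> int:
    """Threads along x for the mat-vec kernel, following the quoted formula."""
    rows_per_group = nsg * n_r0  # output rows handled by one threadgroup
    num_row_groups = (n + rows_per_group - 1) // rows_per_group  # ceil division
    return num_row_groups * simd_width * nsg


vocab_grid = q6k_matvec_grid_x(262144)  # lm_head-sized N
test_grid = q6k_matvec_grid_x(16384)    # largest N mentioned for the tests
```

This yields 4,194,304 threads for N=262144 and 262,144 for N=16384, matching the figure in the comment.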

3. Mat-mat kernel: threadgroup half sa[4096] fixed sizing (q6k/linear.py:166-167)

threadgroup half sa[4096];  // NR0 * NK storage (strided by 64)
threadgroup half sb[4096];  // NR1 * NK storage (strided by 64)

These are 4096 * 2 = 8KB each (16KB total threadgroup memory per threadgroup). Given Apple GPUs typically have 32KB threadgroup memory, this uses 50% -- comfortable but worth documenting if future tile-size tuning is planned.

4. Mat-mat kernel: M computation uses raw x_shape pointer (q6k/linear.py:176-177)

int M = 1;
for (uint d = 0; d + 1 < x_ndim; ++d) { M *= (int) x_shape[d]; }

This relies on x_ndim and x_shape being passed by the Metal kernel dispatch infrastructure. Since the mat-mat kernel is only emitted for static M > 1 or the dynamic path, this should always be populated -- but there's no bounds check on x_ndim in the kernel itself. If x_ndim were 0 (shouldn't happen), the loop is safe (never executes, M stays 1). Fine.

5. IfNode semantics: cond = M - 1 (q6k/linear.py:499-506)

The choice of cond = M - 1 is clean (nonzero => matmul, zero => matvec). The runtime exec_if correctly evaluates cond != 0. Just double-checking: if M is somehow 0 (empty batch), cond = -1, which is nonzero, so it takes the matmul path. The matmul kernel would then compute nr1 = (0 - r1) < NR1 which would clamp to (0 - 0) = 0, so the write loop doesn't execute. This is technically safe, though M=0 is likely unreachable given export constraints.
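
The branch selection described here can be modelled in a few lines. This is a toy interpreter, not MLXInterpreter.h: the IfNode dataclass, its field names, and select_kernel are hypothetical, and each "instruction" is just a Python callable.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List

Instruction = Callable[[Dict[str, int]], None]


@dataclass
class IfNode:
    """Hypothetical stand-in for the schema's IfNode."""
    cond_slot: str                  # name of the integer condition value
    then_chain: List[Instruction]   # runs when the condition is nonzero
    else_chain: List[Instruction]   # runs when the condition is zero


def exec_if(node: IfNode, values: Dict[str, int]) -> None:
    """Run exactly one of the two instruction chains."""
    chain = node.then_chain if values[node.cond_slot] != 0 else node.else_chain
    for instr in chain:
        instr(values)


def select_kernel(m: int) -> str:
    """Return which kernel chain a linear op with M rows would take."""
    trace: List[str] = []
    node = IfNode(
        cond_slot="cond",
        then_chain=[lambda v: trace.append("matmul")],
        else_chain=[lambda v: trace.append("matvec")],
    )
    exec_if(node, {"cond": m - 1})  # cond = M - 1, as in the review comment
    return trace[0]
```

select_kernel(1) returns "matvec" and select_kernel(8) returns "matmul"; select_kernel(0) also returns "matmul", which is the empty-batch case discussed above.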

6. Serialization string dedup: relies on FlatBuffers' CreateSharedString (generate.py:636)

create = getattr(builder, "CreateSharedString", None) or builder.CreateString

Good defensive fallback for older flatbuffers versions. The test in test_serialization_dedup.py correctly validates the dedup behavior at the byte level.

7. StringPool in the loader uses the FlatBuffer string pointer as key (MLXLoader.h.tmpl:340-350)

auto& slot = map[static_cast<const void*>(s)];

This is correct: identical strings sharing a single FlatBuffer offset will have the same pointer, so dedup works for newly-serialized files. For old files (without dedup), each string instance gets its own key/entry -- correct, just not deduplicated.
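
A rough Python analogue of the load-side behaviour is sketched below. It is an illustration, not the StringPool template: the buffer offset plays the role of the FlatBuffer string pointer, and pack, StringPool.get and the length-prefixed layout are assumptions made for this example.

```python
from typing import Dict


def pack(s: str) -> bytes:
    """Length-prefixed UTF-8 encoding used by this toy buffer."""
    data = s.encode("utf-8")
    return len(data).to_bytes(4, "little") + data


class StringPool:
    """Materialize each string location once and hand out the same object."""

    def __init__(self, buf: bytes) -> None:
        self._buf = buf
        self._by_offset: Dict[int, str] = {}

    def get(self, offset: int) -> str:
        cached = self._by_offset.get(offset)
        if cached is None:
            n = int.from_bytes(self._buf[offset:offset + 4], "little")
            cached = self._buf[offset + 4:offset + 4 + n].decode("utf-8")
            self._by_offset[offset] = cached
        return cached

    def size(self) -> int:
        return len(self._by_offset)


source = "kernel void q6k_mm(/* placeholder */) {}"

# New-style file: one copy, every node refers to offset 0.
new_pool = StringPool(pack(source))
a, b = new_pool.get(0), new_pool.get(0)

# Old-style file: two copies at different offsets.
old_buf = pack(source) + pack(source)
old_pool = StringPool(old_buf)
c, d = old_pool.get(0), old_pool.get(len(pack(source)))
```

In the new-style case both lookups return the same object and the pool holds one entry. In the old-style case the strings compare equal but occupy two entries, so the file loads correctly and simply gets no sharing.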

8. Reviewer feedback alignment: promotion to extension/llm/export/

The new extension/llm/export/gguf.py and extension/llm/export/int4.py address mergennachin's review feedback about promoting GGUF handling out of examples/. The ExportableGGUFTensor subclass and torchao::dequantize_gguf custom op are now backend-agnostic (good). The reviewer's suggestion about using a set of enum values for raw quant types (instead of a bool) is reflected in the _GGML_ID_BY_TYPE / _TYPE_BY_GGML_ID design.

9. Test coverage is thorough

  • test_linear.py covers: decode (M=1), prefill (M=8/64/128), dynamic M (IfNode both branches), ragged N/M, real Gemma shapes, no-bias variants, Q4_K native path, bf16/fp16/fp32
  • test_embedding.py: Q6_K gather kernel
  • test_serialization_dedup.py: pins the string-sharing behavior
  • extension/llm/export/test/test_gguf.py: validates ExportableGGUFTensor dequant against the gguf package reference

One observation: test_linear.py production shapes (N=5376, K=21504) are substantial -- CI should be fine on the macos-14-xlarge runner but these may slow the GPU test loop. Consider whether timeout adjustments are needed.

10. Minor: dequantize_gguf eager body calls into gguf package (extension/llm/export/gguf.py:85-93)

The eager body calls gguf.dequantize via numpy, which is CPU-only and incurs a detach().cpu().numpy() -> from_numpy -> .to(device) round-trip. This is fine for tracing/testing (it never runs on-device). Importing gguf lazily inside the function body, rather than at module import time, also avoids import errors when gguf isn't installed, which is good.


Nits (non-blocking)

  • q6k/linear.py:9-41: The module docstring is thorough and helpful, but given the repo's "minimal comments" style (CLAUDE.md), you might consider trimming it. (Personal preference -- the attribution section is important to keep.)
  • The empty backends/mlx/custom_kernel_ops/gguf/q4k.py file (0 bytes) alongside the q4k/ directory is confusing -- likely a git artifact or placeholder. Remove if unused.
  • patterns.py:67: output_dtype = torch.bfloat16 as the default when no explicit dtype arg is present -- this matches the GGUF convention but is worth a brief inline note.

Summary

Strong PR. The kernel port from llama.cpp is well-attributed, the architecture is extensible (adding Q4_K or Q8_0 formats requires no changes to the router or schema), and the runtime IfNode for dynamic M is clean. The serialization dedup is a nice size win (62% on Gemma 4 31B metadata) with full backward compat. The test matrix is comprehensive and exercises both kernel correctness and the infrastructure (IfNode branching, string dedup, dynamic shapes).

Main action items from this review:

  1. Remove the empty q4k.py file if unused
  2. Consider the reviewer feedback about promoting ops to et:: namespace (mentioned in existing review thread)

@metascroy

Contributor Author

@mergennachin as requested, I promoted GGUF handling out of examples. It is now in a backend-agnostic extension/llm/export/gguf.py, mirroring nvfp4.py.

This defines ExportableGGUFTensor (a TorchAOBaseTensor), the single loading representation that wraps the raw GGUF bytes and desugars linear/embedding to a torchao::dequantize_gguf so it survives export and backends can pattern-match/fuse it. ExportableGGUFTensor also has conversion methods to Int4Tensor / IntxUnpackedToInt8Tensor.

I also defined ExportableInt4Tensor and used this instead of IntxUnpackedToInt8Tensor in the MLX Int4 packer path.

MLX has new pattern matchers for ExportableGGUFTensor and ExportableInt4Tensor.

Now prequantized MLX lowering paths (SocialLocal Int4/Int8 and Q4KM GGUF) are more memory efficient and faster. On my computer, pte lowering time for Gemma4-31B on both paths is 3 minutes.

@mergennachin mergennachin left a comment

Contributor

Great work, @metascroy

Especially the memory footprint during export and standardization across the board.

A few comments (and inline comments)

  • Can you update examples/models/gemma4_31b/model.md and README.md files? There might be some stale comments

  • In general, I was planning to promote everything in examples/models/gemma4_31b/quant/* -- not in this PR, but I'd like to get your thoughts on this direction

# Int4Tensor → IntxUnpackedToInt8Tensor conversion


def _int4_to_intx_unpacked(w: torch.Tensor) -> torch.Tensor:

Contributor

Don't you need to update quant/tests/test_pack_mlx.py?

Contributor Author

Updated. Also added test to mlx.yml, looks like it wasn't running

Comment on lines +90 to +92
def _is_embedding(model, model_key: str) -> bool:
    parent = model.get_submodule(model_key.rsplit(".", 1)[0])
    return isinstance(parent, torch.nn.Embedding)

Contributor

Is this called?

Contributor Author

Removed

Comment on lines +115 to +142
@implements([aten.linear.default])
def _(func, types, args, kwargs):
    input_tensor, weight = args[0], args[1]
    bias = args[2] if len(args) > 2 else None
    return torch.nn.functional.linear(
        input_tensor, weight.dequantize(input_tensor.dtype), bias
    )


@implements([aten.embedding.default])
def _(func, types, args, kwargs):
    weight, indices = args[0], args[1]
    return torch.nn.functional.embedding(indices, weight.dequantize())


@implements([aten.t.default])
def _(func, types, args, kwargs):
    return args[0].dequantize().t()


@implements([aten.detach.default, aten.alias.default])
def _(func, types, args, kwargs):
    return args[0]


@implements([aten._to_copy.default])
def _(func, types, args, kwargs):
    return args[0].dequantize(output_dtype=kwargs.get("dtype", args[0].orig_dtype))

Contributor

Are these duplicates in nvfp4/int4/gguf.py files?

Contributor Author

They are, but this is the typical pattern for tensor subclasses.

)


class TestGgufLinearMlx(unittest.TestCase):

Contributor

Can you add a similar test for CUDA too? Seems like CUDA path (inference.py as well as export.py) for GGUF isn't tested (my bad) and with this refactoring may increase the risk of breaking.

@metascroy

Contributor Author

> Great work, @metascroy
>
> Especially the memory footprint during export and standardization across the board.
>
> A few comments (and inline comments)
>
>   • Can you update examples/models/gemma4_31b/model.md and README.md files? There might be some stale comments
>   • In general, I was planning to promote everything in examples/models/gemma4_31b/quant/* -- not in this PR, but I'd like to get your thoughts on this direction

Updated docs.

On promotion of other methods in quant, I think some of them are promotable, but others look like local implementations of existing torchao methods. For example,

  • dequantize_weight can be accomplished by just calling t.dequantize() on the subclass; we don't need to redefine it.

  • _to_intx_tensor, _to_int4_tensor, and the quantization methods _quantize_hqq_symmetric look like re-implementations of similar methods in torchao, perhaps to address the bf16/fp32 discrepancy. But I think the fix, if any, belongs in torchao. (I'm pretty sure it's not fp32/bf16, but eps (can be user passed today) that causes the issue. Affine/HQQ torchao algorithms already use fp32 internally.)

  • quantize_model looks like a re-implementation of quantize_ using QuantRecipe instead of the torchao configs.

If by promote, you mean promote fixes to torchao, then I'm on board, but generally I want to make sure we're not re-inventing what already exists.

@metascroy metascroy merged commit a9d5674 into main Jun 8, 2026
264 of 271 checks passed
@metascroy metascroy deleted the mlx-q6k branch June 8, 2026 21:33

Labels

ciflow/cuda CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

2 participants