diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 38914f7612b..bd6c8f3ed06 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -13,6 +13,7 @@ on: - backends/mlx/** - extension/llm/export/** - extension/audio/** + - examples/models/gemma4_31b/** - examples/models/parakeet/** - examples/models/voxtral_realtime/** - examples/models/qwen3_5_moe/** @@ -77,6 +78,8 @@ jobs: backends/mlx/test/test_passes.py \ backends/mlx/test/test_pattern_utils.py \ backends/mlx/test/test_partitioner.py \ + backends/mlx/test/test_serialization_dedup.py \ + examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \ examples/models/gemma4_31b/tests/test_mlx_pipeline.py \ -v echo "::endgroup::" @@ -89,20 +92,16 @@ jobs: ./cmake-out/backends/mlx/test/multi_thread_test_runner echo "::endgroup::" - echo "::group::Run gated_delta_rule op tests" - ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v - echo "::endgroup::" - - echo "::group::Run tq_norm op tests" - ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_norm run -v - echo "::endgroup::" - - echo "::group::Run tq4_compress op tests" - ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v - echo "::endgroup::" - - echo "::group::Run tq_dequant op tests" - ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v + echo "::group::Run custom_kernel_ops op tests" + # Run every custom_kernel_ops/**/test/test_*.py via its OpTestCase `run` + # CLI. Recurses into per-format subpackages (e.g. gguf/test), so adding a + # new op test file requires no change here. + set -e + for t in $(find backends/mlx/custom_kernel_ops -path '*/test/test_*.py' | sort); do + mod="executorch.$(echo "${t%.py}" | tr '/' '.')" + echo "--- ${mod} ---" + ${CONDA_RUN} python -m "${mod}" run -v + done echo "::endgroup::" test-mlx-qwen35-moe: diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py index be199f75340..2f94a808adc 100644 --- a/backends/mlx/builder/op_helpers.py +++ b/backends/mlx/builder/op_helpers.py @@ -329,6 +329,79 @@ def emit_quantized_biases( return biases +def emit_quantized_gather( + P: MLXProgramBuilder, + out: Slot, + indices_slot: Slot, + qdata_slot: Slot, + scales_slot: Slot, + biases_slot: Optional[Slot], + *, + group_size: int, + bits: int, + mode: str, + out_dtype: torch.dtype, +) -> None: + """Gather quantized rows by index and dequantize them into ``out``. + + Emits ``TakeNode`` for qdata and scales (and biases when present), then a + ``DequantizeNode``. + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + DequantizeNode, + IntOrVidOrTid, + TakeNode, + ) + + ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(indices_slot)) + + _, wq_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(qdata_slot), + index=ids_index, + out=P.slot_to_tid(wq_sel), + axis=0, + ) + ) + + _, sc_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(scales_slot), + index=ids_index, + out=P.slot_to_tid(sc_sel), + axis=0, + ) + ) + + biases_tid = None + if biases_slot is not None: + _, b_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(biases_slot), + index=ids_index, + out=P.slot_to_tid(b_sel), + axis=0, + ) + ) + biases_tid = P.slot_to_tid(b_sel) + + P.emit( + DequantizeNode( + w=P.slot_to_tid(wq_sel), + scales=P.slot_to_tid(sc_sel), + out=P.slot_to_tid(out), + biases=biases_tid, + group_size=group_size, + bits=bits, + mode=mode, + dtype=torch_dtype_to_scalar_type(out_dtype), + ) + ) + + def to_mlx_qparams( qdata: torch.Tensor, scale: torch.Tensor, @@ -421,6 +494,34 @@ def parse_dequant_nvfp4_node( return qdata, scale, per_tensor_scale, output_dtype +def parse_dequant_int4_node( + node: Node, +) -> Optional[Tuple[Node, Node, Node, int, Optional[torch.dtype]]]: + """Parse a torchao.dequantize_int4_tensor node. + + Returns (qdata, scale, zero_point, group_size, output_dtype) or None if not a + dequantize_int4_tensor node or the custom op is not registered. + """ + target = get_aten_target(node.target) + try: + import executorch.extension.llm.export.int4 # noqa: F401 + except ImportError: + return None + + if target is not torch.ops.torchao.dequantize_int4_tensor.default: + return None + + qdata, scale, zero_point, group_size = node.args[0:4] + + output_dtype = None + if len(node.args) > 4: + output_dtype = node.args[4] + elif "output_dtype" in node.kwargs: + output_dtype = node.kwargs["output_dtype"] + + return qdata, scale, zero_point, group_size, output_dtype + + def parse_dequant_node( node: Node, ) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]: diff --git a/backends/mlx/model_ops/__init__.py b/backends/mlx/custom_kernel_ops/__init__.py similarity index 100% rename from backends/mlx/model_ops/__init__.py rename to backends/mlx/custom_kernel_ops/__init__.py diff --git a/backends/mlx/model_ops/gated_delta_rule.py b/backends/mlx/custom_kernel_ops/gated_delta_rule.py similarity index 100% rename from backends/mlx/model_ops/gated_delta_rule.py rename to backends/mlx/custom_kernel_ops/gated_delta_rule.py diff --git a/backends/mlx/custom_kernel_ops/gguf/__init__.py b/backends/mlx/custom_kernel_ops/gguf/__init__.py new file mode 100644 index 00000000000..1b6c1c5373c --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/__init__.py @@ -0,0 +1,18 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF-quantized weight lowering for the MLX backend. + +Import :mod:`.patterns` for its side effect to enable lowering of +``torchao::dequantize_gguf -> linear/embedding`` to the Q6_K / Q4_K kernels:: + + import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 + +This ``__init__`` is side-effect free, so importing ``.q6k`` for the pure-torch +dequant does not pull in the MLX builder/registry. +""" diff --git a/backends/mlx/custom_kernel_ops/gguf/patterns.py b/backends/mlx/custom_kernel_ops/gguf/patterns.py new file mode 100644 index 00000000000..7d3a5bc307c --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/patterns.py @@ -0,0 +1,167 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""MLX pattern handlers for GGUF-quantized weights. + +``ExportableGGUFTensor`` (extension/llm/export/gguf.py) lowers a quantized +linear/embedding to:: + + linear(x, torchao::dequantize_gguf(weight, ggml_type, out_dtype), bias) + embedding(torchao::dequantize_gguf(weight, ggml_type, out_dtype), indices) + +These handlers match that ``dequantize_gguf -> linear/embedding`` subgraph and +lower it without materializing the dequantized weight: + +* **Q6_K** -> fused custom Metal kernels in :mod:`.q6k`. +* **Q4_K** -> MLX's native 4-bit affine ops via :mod:`.q4k` (GGUF blocks + repacked into MLX qparams at export time). + +Both cover linear and embedding. + +Other quant types are left unmatched (the caller is expected to convert them to a +torchao ``Int4Tensor`` / ``IntxUnpackedToInt8Tensor`` first). + +Importing this module registers the patterns as a side effect. +""" + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +from executorch.backends.mlx.builder.op_helpers import get_aten_target +from executorch.backends.mlx.builder.op_registry import PatternHandler, REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.pattern_utils import has_single_user, match_target +from torch.export.exported_program import ExportedProgram +from torch.fx.node import Node + +# Quant types each pattern can lower (Q6_K via custom Metal kernels, Q4_K via +# MLX-native affine ops). +_LINEAR_TYPES = {"q4_k", "q6_k"} +_EMBEDDING_TYPES = {"q4_k", "q6_k"} + + +def parse_dequantize_gguf_node( + node: Node, +) -> Optional[Tuple[Node, str, torch.dtype]]: + """Parse a ``torchao::dequantize_gguf`` node. + + Returns ``(weight_node, ggml_type, output_dtype)`` or ``None`` if ``node`` is + not a ``dequantize_gguf`` node (or the op isn't registered). + """ + try: + import executorch.extension.llm.export.gguf # noqa: F401 registers the op + except ImportError: + return None + + if get_aten_target(node.target) is not torch.ops.torchao.dequantize_gguf.default: + return None + + weight = node.args[0] + ggml_type = node.args[1] + output_dtype = torch.bfloat16 + if len(node.args) > 2: + output_dtype = node.args[2] + elif "output_dtype" in node.kwargs: + output_dtype = node.kwargs["output_dtype"] + return weight, ggml_type, output_dtype + + +@REGISTRY.register_pattern(name="GGUF_QUANTIZED_LINEAR") +class GGUFQuantizedLinearHandler(PatternHandler): + """Lower ``dequantize_gguf + linear`` to a fused quantized matmul. + + Matches ``linear(x, dequantize_gguf(weight, ggml_type, out_dtype), bias)`` + and dispatches on ``ggml_type``: Q6_K -> custom Metal kernels, Q4_K -> MLX + 4-bit ``quantized_matmul``. + """ + + def __init__(self, head, body, weight, ggml_type, output_dtype): + super().__init__(head, body) + self.weight = weight + self.ggml_type = ggml_type + self.output_dtype = output_dtype + + @classmethod + def maybe_create(cls, ep: ExportedProgram, head: Node): + if not match_target(head, torch.ops.aten.linear.default): + return None + if len(head.args) < 2 or not isinstance(head.args[1], Node): + return None + dequant = head.args[1] + if not has_single_user(dequant): + return None + parsed = parse_dequantize_gguf_node(dequant) + if parsed is None: + return None + weight, ggml_type, output_dtype = parsed + if ggml_type not in _LINEAR_TYPES: + return None + return cls(head, [dequant], weight, ggml_type, output_dtype) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + x_node = n.args[0] + bias_node = n.args[2] if len(n.args) > 2 else None + if self.ggml_type == "q6_k": + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.linear import ( + emit_linear, + ) + else: # q4_k + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import ( + emit_linear, + ) + return emit_linear(P, n, x_node, self.weight, bias_node) + + +@REGISTRY.register_pattern(name="GGUF_QUANTIZED_EMBEDDING") +class GGUFQuantizedEmbeddingHandler(PatternHandler): + """Lower ``dequantize_gguf + embedding`` to a quantized gather. + + Matches ``embedding(dequantize_gguf(weight, ggml_type, out_dtype), indices)`` + and dispatches on ``ggml_type``: Q6_K -> custom Metal gather, Q4_K -> MLX + quantized gather. + """ + + def __init__(self, head, body, weight, ggml_type, output_dtype): + super().__init__(head, body) + self.weight = weight + self.ggml_type = ggml_type + self.output_dtype = output_dtype + + @classmethod + def maybe_create(cls, ep: ExportedProgram, head: Node): + if not match_target(head, torch.ops.aten.embedding.default): + return None + if len(head.args) < 2 or not isinstance(head.args[0], Node): + return None + dequant = head.args[0] + if not has_single_user(dequant): + return None + parsed = parse_dequantize_gguf_node(dequant) + if parsed is None: + return None + weight, ggml_type, output_dtype = parsed + if ggml_type not in _EMBEDDING_TYPES: + return None + return cls(head, [dequant], weight, ggml_type, output_dtype) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + indices_node = n.args[1] + if self.ggml_type == "q6_k": + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.embedding import ( + emit_embedding, + ) + else: # q4_k + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import ( + emit_embedding, + ) + return emit_embedding(P, n, self.weight, indices_node, self.output_dtype) diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py b/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py new file mode 100644 index 00000000000..6f89cfe2c82 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py @@ -0,0 +1,14 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF Q4_K format lowering for the MLX backend (native affine 4-bit). + +See :mod:`.linear` / :mod:`.embedding` for the ``emit_*`` lowerings (called by +``custom_kernel_ops.gguf.patterns``); they are not imported here to keep the +package import light. +""" diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/common.py b/backends/mlx/custom_kernel_ops/gguf/q4k/common.py new file mode 100644 index 00000000000..d58a8b71afd --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/common.py @@ -0,0 +1,46 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""Shared Q4_K -> MLX qparam repack for the Q4_K lowering. + +Q4_K maps cleanly onto MLX's affine 4-bit kernels (group_size 32): the GGUF +blocks are unpacked to the torchao ``IntxUnpackedToInt8Tensor`` layout and +repacked into MLX qparams (``S * Q + B``) at export time, so the weight is +stored MLX-ready and decoded by MLX itself. +""" + +from __future__ import annotations + +from typing import Tuple + +from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from torch.fx.node import Node + +_BITS = 4 + + +def _repack_mlx( + P: MLXProgramBuilder, weight_node: Node +) -> Tuple[Slot, Slot, Slot, int]: + """Unpack a raw Q4_K blob and repack into MLX qparam constants. + + Returns ``(packed_slot, scales_slot, biases_slot, group_size)``. + """ + from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + weight_target, raw = P.get_placeholder_target_and_tensor(weight_node) + intx = ExportableGGUFTensor.from_raw(raw, "q4_k").to_intx_unpacked_to_int8_tensor() + group_size = int(intx.block_size[-1]) + packed, biases = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, _BITS) + + packed_slot = P.make_or_get_constant(f"{weight_target}_q4k_packed", packed) + scales_slot = P.make_or_get_constant(f"{weight_target}_q4k_scales", intx.scale) + biases_slot = P.make_or_get_constant(f"{weight_target}_q4k_biases", biases) + return packed_slot, scales_slot, biases_slot, group_size diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py new file mode 100644 index 00000000000..7b5bbcff0e1 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py @@ -0,0 +1,55 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF **Q4_K** embedding lowering via MLX's native 4-bit quantized gather. + +Lowers a ``dequantize_gguf -> embedding`` pattern to a quantized gather: gather +the packed quants / scales / biases by index, then dequantize the gathered rows +(``DequantizeNode``, mode "affine"). The GGUF blob is repacked into MLX qparams +at export time (see :mod:`.common`). +""" + +from __future__ import annotations + +from executorch.backends.mlx.builder.op_helpers import emit_quantized_gather +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import _BITS, _repack_mlx +from torch.fx.node import Node + + +def emit_embedding( + P: MLXProgramBuilder, + head: Node, + weight_node: Node, + indices_node: Node, + output_dtype, +) -> Slot: + """Lower a Q4_K ``dequantize_gguf -> embedding`` pattern to a quantized gather. + + Gathers the packed quants / scales / biases by index, then dequantizes the + gathered rows (MLX affine 4-bit) -- the same shape as MLX's generic quantized + embedding. + """ + w_slot, scales_slot, biases_slot, group_size = _repack_mlx(P, weight_node) + (indices_slot,) = P.slot_map([indices_node]) + + out = P.make_or_get_slot(head) + emit_quantized_gather( + P, + out, + indices_slot, + w_slot, + scales_slot, + biases_slot, + group_size=group_size, + bits=_BITS, + mode="affine", + out_dtype=output_dtype, + ) + return out diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py new file mode 100644 index 00000000000..41d032a2d4a --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py @@ -0,0 +1,82 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF **Q4_K** linear lowering via MLX's native 4-bit quantized matmul. + +Lowers a ``dequantize_gguf -> linear`` pattern to a ``QuantizedMatmulNode`` +(mode "affine", group_size 32); the GGUF blob is repacked into MLX qparams at +export time (see :mod:`.common`). +""" + +from __future__ import annotations + +from typing import Optional + +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import _BITS, _repack_mlx +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddNode, + AsTypeNode, + QuantizedMatmulNode, +) +from torch.fx.node import Node + + +def emit_linear( + P: MLXProgramBuilder, + head: Node, + x_node: Node, + weight_node: Node, + bias_node: Optional[Node], +) -> Slot: + """Lower a Q4_K ``dequantize_gguf -> linear`` pattern to MLX 4-bit matmul. + + ``weight_node`` is the raw GGUF blob constant; ``head`` is the ``aten.linear`` + node. The blob is repacked into MLX qparams at export time, so only the + MLX-format constants are serialized. + """ + w_slot, scales_slot, biases_slot, group_size = _repack_mlx(P, weight_node) + x_slot, bias_slot = P.slot_map([x_node, bias_node]) + + out = P.make_or_get_slot(head) + P.emit( + QuantizedMatmulNode( + x=P.slot_to_tid(x_slot), + w=P.slot_to_tid(w_slot), + scales=P.slot_to_tid(scales_slot), + biases=P.slot_to_tid(biases_slot), + out=P.slot_to_tid(out), + group_size=group_size, + bits=_BITS, + mode="affine", + transpose=True, + ) + ) + + if bias_node is not None: + P.emit( + AddNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(bias_slot), + out=P.slot_to_tid(out), + ) + ) + + out_dtype = head.meta["val"].dtype + if out_dtype != x_node.meta["val"].dtype: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(out_dtype), + ) + ) + + return out diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py b/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py new file mode 100644 index 00000000000..deb39c4d3c0 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py @@ -0,0 +1,21 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF Q6_K format implementation (fused custom Metal kernels). + +Re-exports the lightweight constants/header from :mod:`.common` so they can be +imported without pulling in the MLX builder. The ``emit_*`` lowerings live in +:mod:`.linear` / :mod:`.embedding` (called by ``custom_kernel_ops.gguf.patterns``) +and are not imported here. +""" + +from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( # noqa: F401 + _Q6K_HEADER, + Q6K_BLOCK_BYTES, + QK_K, +) diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/common.py b/backends/mlx/custom_kernel_ops/gguf/q6k/common.py new file mode 100644 index 00000000000..69ddbb0f406 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/common.py @@ -0,0 +1,134 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""Shared GGUF **Q6_K** primitives for the MLX backend. + +This module holds the pieces common to every Q6_K kernel (linear matmul/matvec +and the embedding gather), so format-specific op modules import from here rather +than from each other: + +* ``QK_K`` / ``Q6K_BLOCK_BYTES`` and the per-super-block byte layout constants. +* ``_Q6K_HEADER`` -- the Metal header (the ``block_q6_K`` struct plus the + per-element and vectorized dequant helpers) shared by all Q6_K Metal kernels. + +Q6_K layout + +Q6_K layout (per 256-element super-block, 210 bytes, see llama.cpp +``block_q6_K`` in ``ggml-common.h``):: + + uint8 ql[128] # quants, lower 4 bits + uint8 qh[64] # quants, upper 2 bits + int8 scales[16] # per-16-element sub-block scales (8-bit) + half d # super-block scale + +The dequantized value for a 6-bit code ``q`` (0..63) in sub-block ``s`` is +``d * scales[s] * (q - 32)``. + +Attribution +----------- +The Q6_K block layout and the Metal dequant helpers in ``_Q6K_HEADER`` follow +llama.cpp +(``ggml-common.h`` / ``ggml-metal.metal``: ``block_q6_K``, ``dequantize_q6_K``), +which is MIT-licensed (Copyright (c) 2023-2024 The ggml authors). +""" + +from __future__ import annotations + + +# --------------------------------------------------------------------------- +# Q6_K constants +# --------------------------------------------------------------------------- + +QK_K = 256 +# Per-super-block byte counts. +_Q6K_QL_BYTES = QK_K // 2 # 128 +_Q6K_QH_BYTES = QK_K // 4 # 64 +_Q6K_SCALES = QK_K // 16 # 16 +_Q6K_D_BYTES = 2 # one fp16 +Q6K_BLOCK_BYTES = _Q6K_QL_BYTES + _Q6K_QH_BYTES + _Q6K_SCALES + _Q6K_D_BYTES # 210 + + +# --------------------------------------------------------------------------- +# Shared Metal header +# --------------------------------------------------------------------------- + +# The GGUF block_q6_K struct (matches llama.cpp ggml-common.h; sizeof == 210, no +# padding since max align is 2) plus dequant helpers for both per-element +# (embedding) and vectorized (matmul) use. +_Q6K_HEADER = """ +#include +#include +using namespace metal; + +#define QK_K 256 + +typedef struct { + uint8_t ql[QK_K/2]; // lower 4 bits + uint8_t qh[QK_K/4]; // upper 2 bits + int8_t scales[QK_K/16]; // per-16-element sub-block scales + half d; // super-block scale +} block_q6_K; + +// Dequantize a single element at within-block position p (0..255) of a +// block_q6_K. Used by the embedding kernel. +inline float dequant_q6k_elem(device const block_q6_K * blk, int p) { + const int h = p >> 7; // which 128-element half (0/1) + const int pp = p & 127; // position within half (0..127) + const int g = pp >> 5; // group: 0=q1, 1=q2, 2=q3, 3=q4 + const int l = pp & 31; // 0..31 + device const uint8_t * ql = blk->ql + h * 64; + device const uint8_t * qh = blk->qh + h * 32; + device const int8_t * sc = blk->scales + h * 8; + const int is = l >> 4; // 0/1 + const uint8_t qhb = qh[l]; + int q; + if (g == 0) { q = (ql[l] & 0xF) | ((qhb & 0x03) << 4); } + else if (g == 1) { q = (ql[l + 32] & 0xF) | ((qhb & 0x0C) << 2); } + else if (g == 2) { q = (ql[l] >> 4) | ((qhb & 0x30) << 0); } + else { q = (ql[l + 32] >> 4) | ((qhb & 0xC0) >> 2); } + const float scale = (float) sc[is + 2 * g]; + return (float) blk->d * scale * (float)(q - 32); +} + +// Vectorized Q6_K dequantize: decodes 16 values per call into a 4x4 half +// register. Ported from llama.cpp dequantize_q6_K. `il` ranges 0..15 and +// selects which 16-element slice of the 256-element block to decode. +inline void dequantize_q6_K_16(device const block_q6_K * xb, short il, + thread half4x4 & reg) { + const half d_all = xb->d; + device const uint16_t * ql = (device const uint16_t *)xb->ql; + device const uint16_t * qh = (device const uint16_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + + ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1); + qh = qh + 16*(il/8) + 8*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303); + const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F; + const float coeff = d_all * sc; + const float ml = coeff * 32.f; + const float dl0 = coeff; + const float dl1 = dl0 / 256.f; + const float dl2 = dl0 / (256.f * 256.f); + const float dl3 = dl0 / (256.f * 256.f * 256.f); + const uint8_t shr_h = il>2 ? 2 : 0; + const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4); + const uint8_t shr_l = il>1 ? 4 : 0; + for (int i = 0; i < 4; ++i) { + const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2; + const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1; + const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l); + reg[i][0] = (half)(dl0 * ((half)(q & 0xFF)) - ml); + reg[i][1] = (half)(dl1 * ((float)(q & 0xFF00)) - ml); + reg[i][2] = (half)(dl2 * ((float)(q & 0xFF0000)) - ml); + reg[i][3] = (half)(dl3 * ((float)(q & 0xFF000000)) - ml); + } +} +""" diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py b/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py new file mode 100644 index 00000000000..2e7401bdaf4 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py @@ -0,0 +1,122 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF **Q6_K** embedding lowering for the MLX GGUF pattern handler. + +A custom gather Metal kernel is needed because MLX's affine dequantize has no +group_size=16 kernel, so a Q6_K embedding (group_size 16) can't use the generic +quantized-embedding path. +""" + +from __future__ import annotations + +import torch +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( + _Q6K_HEADER, + Q6K_BLOCK_BYTES, + QK_K, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, +) +from torch.fx.node import Node + + +# --------------------------------------------------------------------------- +# Metal kernel source +# --------------------------------------------------------------------------- + + +# One thread per output element. grid = (K, num_idx, 1): x picks the feature j, +# y picks the gathered row; each thread dequantizes a single Q6_K element. +_Q6K_EMBED_SOURCE = """ + const uint j = thread_position_in_grid.x; // 0..K-1 + const uint r = thread_position_in_grid.y; // gathered row + const int row = (int) indices[r]; + const int nb = K / QK_K; + device const block_q6_K * blk = + ((device const block_q6_K *) weight) + (uint)row * nb + (j / QK_K); + out[r * (uint)K + j] = (OutT) dequant_q6k_elem(blk, j % QK_K); +""" + + +def emit_embedding( + P: MLXProgramBuilder, + head: Node, + weight_node: Node, + indices_node: Node, + output_dtype: torch.dtype, +) -> Slot: + """Lower a Q6_K ``dequantize_gguf`` -> ``embedding`` pattern to a fused gather. + + ``weight_node`` is the raw GGUF blob (the dequantize op's weight input) and + ``head`` is the ``aten.embedding`` node that owns the output slot. + """ + weight_slot, indices_slot = P.slot_map([weight_node, indices_node]) + + weight_meta = weight_node.meta["val"] + if weight_meta.dim() != 2: + raise NotImplementedError( + f"gguf q6k embedding: weight must be 2-D (vocab, row_bytes); got " + f"shape {tuple(weight_meta.shape)}" + ) + row_bytes = weight_meta.shape[1] + if not isinstance(row_bytes, int): + raise NotImplementedError( + "gguf q6k embedding: weight shape must be statically known" + ) + if row_bytes % Q6K_BLOCK_BYTES != 0: + raise ValueError( + f"gguf q6k embedding: weight row bytes {row_bytes} must be a " + f"multiple of {Q6K_BLOCK_BYTES}" + ) + K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K + + out_dtype_int = torch_dtype_to_scalar_type(output_dtype) + + out = P.make_or_get_slot(head) + leading = emit_shape(P, indices_node, indices_slot, end_dim=None) + num_idx_iov = emit_product(P, leading) + out_shape_flat = leading + [IntOrVid.from_literal(K)] + + # threadgroup.x must divide grid.x (= K, a multiple of 256). + tg_x = 256 if K % 256 == 0 else K + + P.emit( + MetalKernelNode( + name="gguf_q6k_embedding", + source=_Q6K_EMBED_SOURCE, + header=_Q6K_HEADER, + inputs=[P.slot_to_tid(weight_slot), P.slot_to_tid(indices_slot)], + outputs=[P.slot_to_tid(out)], + grid=[IntOrVid.from_literal(K), num_idx_iov, IntOrVid.from_literal(1)], + threadgroup=[ + IntOrVid.from_literal(tg_x), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=["weight", "indices"], + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[out_dtype_int], + template_arg_names=["OutT", "K"], + template_arg_kinds=[2, 0], # dtype, int + template_arg_values=[out_dtype_int, K], + ) + ) + + return out diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py new file mode 100644 index 00000000000..99a82053e90 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py @@ -0,0 +1,549 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF **Q6_K** linear implementation. + +Provides the Q6_K linear pieces used by the MLX GGUF pattern handler +(:mod:`..patterns`): + +* :func:`eager_linear` -- pure-torch reference (``x @ dequant(weight)^T``). +* :func:`emit_linear` -- lowers a ``dequantize_gguf -> linear`` pattern to fused + Q6_K Metal kernels. + +Compute is keyed on the activation dtype (matching GGUF/llama.cpp): the Metal +kernels are templated on ``InT``, accumulate in ``float32``, read ``d`` as +``half``, and produce output in the activation dtype. + +Two kernels are emitted depending on the number of activation rows ``M``: + + * ``M == 1`` (decode): a fused mat-vec kernel ported from llama.cpp + ``kernel_mul_mv_q6_K_f32_impl``. + * static ``M > 1`` (prefill): a tiled simdgroup mat-mat kernel that + dequantizes weight tiles into threadgroup memory and reuses them across + the activation rows. + * dynamic/symbolic ``M`` (single program serving both prefill and decode): + both kernels are emitted into separate instruction chains and selected at + runtime via an ``IfNode`` on ``M`` (``M > 1`` -> mat-mat, ``M == 1`` -> + mat-vec). + +Attribution +----------- +The Q6_K Metal kernels and dequant routines here are ported from llama.cpp +(``ggml/src/ggml-metal/ggml-metal.metal`` -- ``kernel_mul_mv_q6_K_f32_impl``, +``kernel_mul_mm``, ``dequantize_q6_K``), which is MIT-licensed +(Copyright (c) 2023-2024 The ggml authors). Inline ``ported from ...`` notes +point at the specific upstream function for each kernel. +""" + +from __future__ import annotations + +from typing import Optional + +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( + _Q6K_HEADER, + Q6K_BLOCK_BYTES, + QK_K, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddIntNode, + FloorDivideIntNode, + IfNode, + IntOrVid, + MetalKernelNode, + MultiplyIntNode, + SubtractIntNode, +) +from torch.fx.node import Node + + +# --------------------------------------------------------------------------- +# Metal kernel sources +# --------------------------------------------------------------------------- + + +# Decode mat-vec kernel, ported from llama.cpp kernel_mul_mv_q6_K_f32_impl. +# Threadgroup = (32 * NSG, 1, 1): NSG simdgroups, each computing N_R0 output +# rows for one activation row (grid.y). Accumulate in float, reduce via simd_sum. +def _q6k_matvec_source(has_bias: bool) -> str: + write = "out[(uint)m * N + r] = (InT)(tot" + write += " + (float)bias[r]);" if has_bias else ");" + return f""" + constexpr short N_R0 = 2; + + const ushort tiisg = thread_index_in_simdgroup; + const ushort sgitg = simdgroup_index_in_threadgroup; + const uint m = thread_position_in_grid.y; + const uint tgx = thread_position_in_grid.x / (32u * NSG); + const int nb = K / QK_K; + const int first_row = (int)(tgx * NSG + sgitg) * N_R0; + + const short tid = tiisg / 2; + const short ix = tiisg % 2; + const short ip = tid / 8; // 0 or 1 (which 128-half) + const short il = tid % 8; + const short l0 = 4 * il; + const short is = 8 * ip + l0 / 16; + + const short y_offset = 128 * ip + l0; + const short q_offset_l = 64 * ip + l0; + const short q_offset_h = 32 * ip + l0; + + device const block_q6_K * xrows = (device const block_q6_K *) weight; + device const InT * yy = x + (uint)m * (uint)K; + + float sumf[N_R0]; + for (short r = 0; r < N_R0; ++r) {{ sumf[r] = 0.f; }} + + float yl[16]; + for (int i = ix; i < nb; i += 2) {{ + device const InT * yb = yy + i * QK_K + y_offset; + for (short l = 0; l < 4; ++l) {{ + yl[4*l + 0] = (float) yb[l + 0]; + yl[4*l + 1] = (float) yb[l + 32]; + yl[4*l + 2] = (float) yb[l + 64]; + yl[4*l + 3] = (float) yb[l + 96]; + }} + + for (short row = 0; row < N_R0; ++row) {{ + const int r = first_row + row; + if (r >= N) {{ break; }} + device const block_q6_K * blk = xrows + (uint)r * nb + i; + device const uint8_t * q1 = blk->ql + q_offset_l; + device const uint8_t * q2 = q1 + 32; + device const uint8_t * qh = blk->qh + q_offset_h; + device const int8_t * sc = blk->scales + is; + const float d = (float) blk->d; + + float4 sums = {{0.f, 0.f, 0.f, 0.f}}; + for (short l = 0; l < 4; ++l) {{ + sums[0] += yl[4*l + 0] * (float)((int8_t)((q1[l] & 0xF) | ((qh[l] & 0x03) << 4)) - 32); + sums[1] += yl[4*l + 1] * (float)((int8_t)((q2[l] & 0xF) | ((qh[l] & 0x0C) << 2)) - 32); + sums[2] += yl[4*l + 2] * (float)((int8_t)((q1[l] >> 4) | ((qh[l] & 0x30) << 0)) - 32); + sums[3] += yl[4*l + 3] * (float)((int8_t)((q2[l] >> 4) | ((qh[l] & 0xC0) >> 2)) - 32); + }} + sumf[row] += d * (sums[0]*sc[0] + sums[1]*sc[2] + sums[2]*sc[4] + sums[3]*sc[6]); + }} + }} + + for (short row = 0; row < N_R0; ++row) {{ + const int r = first_row + row; + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && r < N) {{ + {write} + }} + }} +""" + + +# Prefill mat-mat kernel, ported from llama.cpp kernel_mul_mm (Q6_K variant). +# 64x32 output tiles, 4 simdgroups / 128 threads per threadgroup. +# Uses vectorized dequantize_q6_K_16 to decode 16 weight values per thread +# into threadgroup memory, then runs simdgroup_multiply_accumulate on 8x8 +# tiles. NL=16 for Q6_K (QK_K / 16 = 16 dequant steps per super-block). +# C[m, n] = sum_k x[m, k] * dequant(weight)[n, k] (+ bias[n]). +def _q6k_matmul_source(has_bias: bool) -> str: + bias_add = "+ (float) bias[r0 + i]" if has_bias else "" + return f""" + constexpr short NR0 = 64; // weight/output rows per tile (N dim) + constexpr short NR1 = 32; // activation rows per tile (M dim) + constexpr short NK = 32; // K-chunk per iteration + constexpr short NL = 16; // Q6_K: QK_K / 16 + constexpr short NL0 = NK / 16; // = 2 — dequant iterations per thread for weight + constexpr short NL1 = NK / 8; // = 4 — load iterations per thread for activation + + threadgroup half sa[4096]; // NR0 * NK storage (strided by 64) + threadgroup half sb[4096]; // NR1 * NK storage (strided by 64) + + const ushort tid = thread_index_in_threadgroup; // 0..127 + const ushort sgitg = simdgroup_index_in_threadgroup; // 0..3 + + const uint r0 = thread_position_in_grid.y * NR0; // first weight row + const uint r1 = (thread_position_in_grid.x / 128u) * NR1; // first activation row + + // M (number of activation rows) read at runtime. + int M = 1; + for (uint d = 0; d + 1 < x_ndim; ++d) {{ M *= (int) x_shape[d]; }} + + const int nb = K / QK_K; + + // Clamp tile edges. + const short nr0 = (N - (int)r0 < NR0) ? (N - (int)r0) : NR0; + const short nr1 = (M - (int)r1 < NR1) ? (M - (int)r1) : NR1; + + // Thread → element mapping for cooperative loads. + const short lr0 = ((short)(tid / NL0) < nr0) ? (short)(tid / NL0) : (nr0 - 1); // 0..63 + const short lr1 = ((short)(tid / NL1) < nr1) ? (short)(tid / NL1) : (nr1 - 1); // 0..31 + + short il0 = tid % NL0; + short il = il0; // current dequant sub-block index within Q6_K block + + const short offset1 = il0 / NL; // always 0 for NL=16, NL0=2 + + // Pointer to weight block for this thread's assigned row. + device const block_q6_K * wblk = (device const block_q6_K *) weight + + (uint)(r0 + lr0) * nb + offset1; + + // Pointer to activation row for this thread. + const short iy = 8 * (tid % NL1); + device const InT * yp = x + (uint)(r1 + lr1) * (uint)K + iy; + + // Accumulator: 8 simdgroup 8x8 matrices (4 sgitg configs x 2 sub-tiles). + simdgroup_half8x8 ma[4]; + simdgroup_half8x8 mb[2]; + simdgroup_float8x8 mc[8]; + for (short i = 0; i < 8; ++i) {{ + mc[i] = make_filled_simdgroup_matrix(0.f); + }} + + for (int loop_k = 0; loop_k < K; loop_k += NK) {{ + // --- Cooperative load: dequantized weight tile (NR0 x NK) into sa --- + half4x4 temp_a; + dequantize_q6_K_16(wblk, il, temp_a); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short i = 0; i < 16; ++i) {{ + const short sx = 2 * il0 + i / 8; + const short sy = (tid / NL0) / 8; + const short lx = (tid / NL0) % 8; + const short ly = i % 8; + const short ib = 8 * sx + sy; + *(sa + 64 * ib + 8 * ly + lx) = temp_a[i / 4][i % 4]; + }} + + // --- Cooperative load: activation tile (NR1 x NK) into sb --- + const short sx_b = tid % NL1; + const short sy_b = (tid / NL1) / 8; + const short ly_b = (tid / NL1) % 8; + const short ib_b = 4 * sx_b + sy_b; + + for (short i = 0; i < 8; ++i) {{ + *(sb + 64 * ib_b + 8 * ly_b + i) = (half) *(yp + i); + }} + + // Advance weight pointer through Q6_K sub-blocks. + il = (il + 2 < NL) ? il + 2 : il % 2; + wblk = (il < 2) ? wblk + (2 + NL - 1) / NL : wblk; + + yp += NK; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // --- Simdgroup matmul on loaded tiles --- + threadgroup const half * lsma = sa + 4 * 64 * (sgitg % 2); + threadgroup const half * lsmb = sb + 2 * 64 * (sgitg / 2); + + for (short ik = 0; ik < NK / 8; ++ik) {{ + simdgroup_barrier(mem_flags::mem_none); + for (short i = 0; i < 4; ++i) {{ + simdgroup_load(ma[i], lsma + 64 * i, 8, ulong2(0, 0), false); + }} + simdgroup_barrier(mem_flags::mem_none); + for (short i = 0; i < 2; ++i) {{ + simdgroup_load(mb[i], lsmb + 64 * i, 8, ulong2(0, 0), false); + }} + simdgroup_barrier(mem_flags::mem_none); + for (short i = 0; i < 8; ++i) {{ + simdgroup_multiply_accumulate(mc[i], mb[i / 4], ma[i % 4], mc[i]); + }} + lsma += 8 * 64; + lsmb += 4 * 64; + }} + }} + + // --- Write results: always via threadgroup memory for float→InT cast --- + // Barrier needed: sa was used for weight tiles during the K-loop and is now + // reused as float staging for the output. Without this barrier, a fast + // simdgroup could start writing mc[] into sa while a slower one is still + // reading the last weight tile via simdgroup_load(ma[]). + // (Mirrors the barrier in llama.cpp kernel_mul_mm's bounds-checked write path.) + threadgroup_barrier(mem_flags::mem_threadgroup); + {{ + threadgroup float * temp_str = ((threadgroup float *) sa) + + 32 * (sgitg & 1) + (16 * (sgitg >> 1)) * NR0; + for (short i = 0; i < 8; ++i) {{ + simdgroup_store(mc[i], temp_str + 8 * (i % 4) + 8 * NR0 * (i / 4), + NR0, ulong2(0, 0), false); + }} + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) {{ + for (int j = tid; j < nr1; j += NR1) {{ + device InT * D = out + (uint)(r1 + j) * (uint)N + r0; + threadgroup float * Cp = ((threadgroup float *) sa) + j * NR0; + for (int i = 0; i < nr0; ++i) {{ + float v = Cp[i]; + D[i] = (InT)(v {bias_add}); + }} + }} + }} + }} +""" + + +# Number of simdgroups per threadgroup for the mat-vec kernel. +_Q6K_MV_NSG = 4 +# Tile sizes for the mat-mat kernel (from llama.cpp kernel_mul_mm). +_Q6K_MM_NR0 = 64 # weight/output rows (N dim) per threadgroup +_Q6K_MM_NR1 = 32 # activation rows (M dim) per threadgroup + + +def _emit_q6k_matvec( + P: MLXProgramBuilder, + x_node: Node, + x_slot: Slot, + weight_slot: Slot, + bias_slot: Optional[Slot], + N: int, + K: int, + out: Slot, +) -> None: + in_dtype_int = torch_dtype_to_scalar_type(x_node.meta["val"].dtype) + + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + M_iov = emit_product(P, leading) + out_shape_flat = leading + [IntOrVid.from_literal(N)] + + n_r0 = 2 + nsg = _Q6K_MV_NSG + num_row_groups = (N + nsg * n_r0 - 1) // (nsg * n_r0) + grid_x = num_row_groups * 32 * nsg + + has_bias = bias_slot is not None + inputs = [P.slot_to_tid(x_slot), P.slot_to_tid(weight_slot)] + input_names = ["x", "weight"] + if has_bias: + inputs.append(P.slot_to_tid(bias_slot)) + input_names.append("bias") + + P.emit( + MetalKernelNode( + name="gguf_q6k_matvec", + source=_q6k_matvec_source(has_bias), + header=_Q6K_HEADER, + inputs=inputs, + outputs=[P.slot_to_tid(out)], + grid=[ + IntOrVid.from_literal(grid_x), + M_iov, + IntOrVid.from_literal(1), + ], + threadgroup=[ + IntOrVid.from_literal(32 * nsg), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=input_names, + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[in_dtype_int], + template_arg_names=["InT", "N", "K", "NSG"], + template_arg_kinds=[2, 0, 0, 0], # dtype, int, int, int + template_arg_values=[in_dtype_int, N, K, nsg], + ) + ) + + +def _emit_q6k_matmul( + P: MLXProgramBuilder, + x_node: Node, + x_slot: Slot, + weight_slot: Slot, + bias_slot: Optional[Slot], + N: int, + K: int, + blocks_m_iov: IntOrVid, + out: Slot, +) -> None: + in_dtype_int = torch_dtype_to_scalar_type(x_node.meta["val"].dtype) + + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + out_shape_flat = leading + [IntOrVid.from_literal(N)] + + # grid.x = ceil(M / NR1) * 128 threads (activation tiles) + # grid.y = ceil(N / NR0) (weight tiles) + blocks_n = (N + _Q6K_MM_NR0 - 1) // _Q6K_MM_NR0 + + has_bias = bias_slot is not None + inputs = [P.slot_to_tid(x_slot), P.slot_to_tid(weight_slot)] + input_names = ["x", "weight"] + if has_bias: + inputs.append(P.slot_to_tid(bias_slot)) + input_names.append("bias") + + # blocks_m_iov = ceil(M / NR1); multiply by 128 for grid.x + _, grid_x_slot = P.make_tmp_value_slot() + P.emit( + MultiplyIntNode( + a=blocks_m_iov, + b=IntOrVid.from_literal(128), + out=P.slot_to_vid(grid_x_slot), + ) + ) + grid_x_iov = IntOrVid.from_vid(P.slot_to_vid(grid_x_slot)) + + P.emit( + MetalKernelNode( + name="gguf_q6k_matmul", + source=_q6k_matmul_source(has_bias), + header=_Q6K_HEADER, + inputs=inputs, + outputs=[P.slot_to_tid(out)], + grid=[ + grid_x_iov, + IntOrVid.from_literal(blocks_n), + IntOrVid.from_literal(1), + ], + threadgroup=[ + IntOrVid.from_literal(128), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=input_names, + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[in_dtype_int], + template_arg_names=["InT", "N", "K"], + template_arg_kinds=[2, 0, 0], + template_arg_values=[in_dtype_int, N, K], + ) + ) + + +def emit_linear( + P: MLXProgramBuilder, + head: Node, + x_node: Node, + weight_node: Node, + bias_node: Optional[Node], +) -> Slot: + """Lower a Q6_K ``dequantize_gguf`` -> ``linear`` pattern to fused kernels. + + ``weight_node`` is the raw GGUF blob (the dequantize op's weight input) and + ``head`` is the ``aten.linear`` node that owns the output slot. + """ + x_slot, weight_slot, bias_slot = P.slot_map([x_node, weight_node, bias_node]) + + weight_meta = weight_node.meta["val"] + if weight_meta.dim() != 2: + raise NotImplementedError( + f"gguf q6k linear: weight must be 2-D (N, row_bytes); got " + f"shape {tuple(weight_meta.shape)}" + ) + N = weight_meta.shape[0] + row_bytes = weight_meta.shape[1] + if not isinstance(N, int) or not isinstance(row_bytes, int): + raise NotImplementedError( + "gguf q6k linear: weight shape must be statically known" + ) + if row_bytes % Q6K_BLOCK_BYTES != 0: + raise ValueError( + f"gguf q6k linear: weight row bytes {row_bytes} must be a multiple of " + f"{Q6K_BLOCK_BYTES}" + ) + K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K + + # Determine M (product of x's leading dims). Static M lets us pick the + # optimal kernel and (for mat-mat) compute a literal launch grid. + x_meta = x_node.meta["val"] + leading_dims = x_meta.shape[:-1] + M: Optional[int] = 1 + for d in leading_dims: + if isinstance(d, int): + M *= d + else: + M = None # dynamic / symbolic + break + + out = P.make_or_get_slot(head) + tile = _Q6K_MM_NR1 # M-dimension tile (activation rows per threadgroup) + if M == 1: + # Static decode -> mat-vec. + _emit_q6k_matvec(P, x_node, x_slot, weight_slot, bias_slot, N, K, out) + elif M is not None: + # Static prefill -> tiled simdgroup mat-mat (literal grid). + blocks_m = (M + tile - 1) // tile + _emit_q6k_matmul( + P, + x_node, + x_slot, + weight_slot, + bias_slot, + N, + K, + IntOrVid.from_literal(blocks_m), + out, + ) + else: + # Dynamic seqlen -> emit both kernels in separate chains and select at + # runtime with an IfNode. cond = M - 1: nonzero (M>1) runs the mat-mat + # (then) chain, zero (M==1) runs the mat-vec (else) chain. + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + m_iov = emit_product(P, leading) + + _, cond_slot = P.make_tmp_value_slot() + P.emit( + SubtractIntNode( + a=m_iov, + b=IntOrVid.from_literal(1), + out=P.slot_to_vid(cond_slot), + ) + ) + cond_iov = IntOrVid.from_vid(P.slot_to_vid(cond_slot)) + + # blocks_m = (M + tile - 1) // tile (mat-mat grid.y). + _, sum_slot = P.make_tmp_value_slot() + P.emit( + AddIntNode( + a=m_iov, + b=IntOrVid.from_literal(tile - 1), + out=P.slot_to_vid(sum_slot), + ) + ) + _, blocks_m_slot = P.make_tmp_value_slot() + P.emit( + FloorDivideIntNode( + a=IntOrVid.from_vid(P.slot_to_vid(sum_slot)), + b=IntOrVid.from_literal(tile), + out=P.slot_to_vid(blocks_m_slot), + ) + ) + blocks_m_iov = IntOrVid.from_vid(P.slot_to_vid(blocks_m_slot)) + + with P.new_chain() as then_idx: # prefill / mat-mat + _emit_q6k_matmul( + P, + x_node, + x_slot, + weight_slot, + bias_slot, + N, + K, + blocks_m_iov, + out, + ) + with P.new_chain() as else_idx: # decode / mat-vec + _emit_q6k_matvec(P, x_node, x_slot, weight_slot, bias_slot, N, K, out) + + P.emit( + IfNode( + cond=cond_iov, + then_chain_idx=then_idx, + else_chain_idx=else_idx, + ) + ) + return out diff --git a/backends/mlx/custom_kernel_ops/gguf/test/__init__.py b/backends/mlx/custom_kernel_ops/gguf/test/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py new file mode 100644 index 00000000000..3f8e60b7aa8 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the GGUF Q6_K embedding lowering. + +An ``nn.Embedding`` whose weight is an ``ExportableGGUFTensor`` exports to +``embedding(torchao::dequantize_gguf(weight, "q6_k", ...), indices)``. The MLX +``GGUF_QUANTIZED_EMBEDDING`` pattern matches that subgraph and lowers it to the +fused Q6_K gather Metal kernel. These tests compare the kernel against the eager +reference (``gguf``-package dequant + ``F.embedding``) on the same packed table. + +Usage:: + + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding run + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding list +""" + +from typing import List, Tuple + +# Importing the patterns module registers GGUF_QUANTIZED_LINEAR / _EMBEDDING. +import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 +import torch +import torch.nn as nn +from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( + make_q6_k_blob, +) +from executorch.backends.mlx.test.test_utils import OpTestCase +from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + +def _make_gguf_embedding_model(vocab: int, K: int, seed: int = 0) -> nn.Module: + """An ``nn.Embedding`` whose weight is a Q6_K ``ExportableGGUFTensor``.""" + emb = nn.Embedding(vocab, K) + blob = make_q6_k_blob(vocab, K, seed=seed) + emb.weight = nn.Parameter( + ExportableGGUFTensor.from_raw(blob, "q6_k", torch.bfloat16), + requires_grad=False, + ) + return emb + + +class GGUFEmbeddingTest(OpTestCase): + name = "gguf_embedding" + # Reference dequant runs in fp32 (gguf) then casts to bf16; the kernel + # dequantizes per element to bf16, so allow bf16 tolerance. + rtol = 2e-2 + atol = 2e-2 + + def __init__( + self, + vocab: int = 512, + K: int = 256, + idx_shape: Tuple[int, ...] = (8,), + ): + self.vocab = vocab + self.K = K + self.idx_shape = idx_shape + shp = "x".join(str(d) for d in idx_shape) + self.name = f"gguf_embedding_v{vocab}_k{K}_idx{shp}" + + @classmethod + def get_test_configs(cls) -> List["GGUFEmbeddingTest"]: + return [ + cls(vocab=512, K=256, idx_shape=(1,)), + cls(vocab=512, K=256, idx_shape=(8,)), + cls(vocab=512, K=256, idx_shape=(64,)), + cls(vocab=512, K=512, idx_shape=(8,)), + cls(vocab=512, K=1024, idx_shape=(4,)), + cls(vocab=300, K=256, idx_shape=(16,)), # vocab not tile-aligned + cls(vocab=512, K=256, idx_shape=(2, 3)), # multi-dim indices + # Real Gemma-4-31B embed width (K=5376, 21 Q6_K blocks/row). Vocab is + # kept small so the packed weight fits CI-runner GPU buffer limits; the + # gather + per-row dequant path is identical regardless of vocab. + cls(vocab=2048, K=5376, idx_shape=(8,)), + ] + + def get_edge_compile_config(self): + from executorch.exir import EdgeCompileConfig + + # The dequantize_gguf custom op isn't a core ATen op; skip IR validity. + return EdgeCompileConfig(_check_ir_validity=False) + + def create_model(self) -> nn.Module: + return _make_gguf_embedding_model(self.vocab, self.K) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + torch.manual_seed(0) + indices = torch.randint(0, self.vocab, self.idx_shape, dtype=torch.int64) + return (indices,) + + +def _main() -> None: # noqa: C901 + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser(description="Test GGUF Q6_K embedding lowering") + parser.add_argument("action", choices=["generate", "compare", "run", "list"]) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--rebuild", action="store_true") + parser.add_argument("--config", type=str, default=None) + args = parser.parse_args() + + if args.rebuild and not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = GGUFEmbeddingTest.get_test_configs() + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + failed_names: List[str] = [] + + for test in configs: + if args.action == "generate": + pte_path, _, _ = test.generate_test_files(verbose=args.verbose) + print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + passed, failed = (passed + 1, failed) if ok else (passed, failed + 1) + if not ok: + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + passed, failed = (passed + 1, failed) if ok else (passed, failed + 1) + if not ok: + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: {passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) + + +if __name__ == "__main__": + _main() diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py new file mode 100644 index 00000000000..4a7defbe107 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py @@ -0,0 +1,394 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for the GGUF Q6_K linear lowering. + +A linear whose weight is an ``ExportableGGUFTensor`` (extension/llm/export/gguf) +exports to ``linear(x, torchao::dequantize_gguf(weight, "q6_k", ...), bias)``. +The MLX ``GGUF_QUANTIZED_LINEAR`` pattern (custom_kernel_ops/gguf/patterns.py) +matches that subgraph and lowers it to the fused Q6_K Metal kernels (mat-vec for +decode, mat-mat for prefill). These tests compare the fused kernels against the +eager reference (``gguf``-package dequant + ``F.linear``) on the same packed +weight, so quantization quality is irrelevant -- only kernel-vs-reference +numerics are checked. + +``GGUFLinearDynamicTest`` exports once with a symbolic seqlen and runs the same +.pte with M=1 and M>1 to exercise both branches of the runtime ``IfNode`` +(decode mat-vec vs prefill mat-mat). + +Usage:: + + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run -v + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear list +""" + +from typing import List, Tuple + +# Importing the patterns module registers GGUF_QUANTIZED_LINEAR / _EMBEDDING. +import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 +import torch +import torch.nn as nn +from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import Q6K_BLOCK_BYTES, QK_K +from executorch.backends.mlx.test.test_utils import OpTestCase +from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + +# --------------------------------------------------------------------------- +# GGUF Q6_K test fixtures. +# +# The Python ``gguf`` package can dequantize Q6_K but does NOT implement Q6_K +# quantization, so we build the packed weight here. Quantization quality is +# irrelevant: the tests only compare the kernel against the eager reference on +# the *same* bytes, so we just emit valid random blocks (random ql/qh/scales +# plus a small finite fp16 ``d`` -- the one field that must be finite). +# --------------------------------------------------------------------------- + + +def make_q6_k_blob(N: int, K: int, seed: int = 0) -> torch.Tensor: + """Build a ``(N, (K/256)*210)`` uint8 tensor of valid GGUF Q6_K blocks.""" + assert K % QK_K == 0, f"K={K} must be a multiple of {QK_K}" + nb = K // QK_K + g = torch.Generator().manual_seed(seed) + out = torch.empty(N, nb * Q6K_BLOCK_BYTES, dtype=torch.uint8) + blocks = out.view(N, nb, Q6K_BLOCK_BYTES) + # ql (0:128) + qh (128:192): any byte values are valid 6-bit quants. + blocks[..., :192] = torch.randint( + 0, 256, (N, nb, 192), dtype=torch.uint8, generator=g + ) + # scales (192:208): signed int8 scales (real Q6_K scales can be negative); + # a modest magnitude keeps dequantized values sane. + scales = torch.randint(-16, 17, (N, nb, 16), dtype=torch.int32, generator=g) + blocks[..., 192:208] = scales.to(torch.int8).view(torch.uint8) + # d (208:210): a small finite fp16 super-block scale. Chosen so dequantized + # element magnitudes (~ d * scale * (q-32)) are O(0.1), like real Q6_K + # weights -- the mat-mat kernel stores tiles in half precision (as in + # llama.cpp), so unrealistically large magnitudes would exceed bf16 tol. + blocks[..., 208:210] = torch.tensor([7e-4], dtype=torch.float16).view(torch.uint8) + return out + + +def make_q4_k_blob(N: int, K: int, seed: int = 0) -> torch.Tensor: + """Build a ``(N, (K/256)*144)`` uint8 tensor of valid GGUF Q4_K blocks.""" + assert K % QK_K == 0, f"K={K} must be a multiple of {QK_K}" + nb = K // QK_K + block_bytes = 144 # Q4_K: d(2) + dmin(2) + scales(12) + qs(128) + g = torch.Generator().manual_seed(seed) + out = torch.empty(N, nb * block_bytes, dtype=torch.uint8) + blocks = out.view(N, nb, block_bytes) + # d (0:2) / dmin (2:4): small finite fp16 super-block scale + min, so + # dequantized magnitudes stay O(0.1) like real Q4_K weights. + blocks[..., 0:2] = torch.tensor([7e-4], dtype=torch.float16).view(torch.uint8) + blocks[..., 2:4] = torch.tensor([7e-4], dtype=torch.float16).view(torch.uint8) + # scales+mins (4:16, 6-bit packed) and qs (16:144, 4-bit): any bytes valid. + blocks[..., 4:144] = torch.randint( + 0, 256, (N, nb, 140), dtype=torch.uint8, generator=g + ) + return out + + +_BLOB_MAKERS = {"q6_k": make_q6_k_blob, "q4_k": make_q4_k_blob} + + +def _make_gguf_linear_model( + N: int, + K: int, + dtype: torch.dtype, + bias: bool, + ggml_type: str = "q6_k", + seed: int = 0, +) -> nn.Module: + """An ``nn.Linear`` whose weight is a GGUF ``ExportableGGUFTensor``.""" + linear = nn.Linear(K, N, bias=bias).to(dtype) + blob = _BLOB_MAKERS[ggml_type](N, K, seed=seed) + linear.weight = nn.Parameter( + ExportableGGUFTensor.from_raw(blob, ggml_type, dtype), requires_grad=False + ) + return linear + + +class GGUFLinearModel(nn.Module): + """Wrapper so the forward arg is named ``x`` (for dynamic-shape specs).""" + + def __init__(self, linear: nn.Module): + super().__init__() + self.linear = linear + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +def _fp32_linear_reference(model: "GGUFLinearModel", x: torch.Tensor): + """fp32-accumulation reference matching the kernel. + + The kernels accumulate in fp32 and cast to the I/O dtype only at the end, so + a bf16 eager matmul is too noisy an oracle over large K. Dequantize in fp32, + matmul in fp32, then cast back -- differences collapse to ~1 output ULP. + + The reference weight must match the representation the kernel consumes: + Q6_K dequantizes the raw blob in-kernel at full precision (use the gguf-exact + dequant), while Q4_K is repacked into bf16 MLX qparams, so use that repacked + dequant (repack precision vs gguf is covered separately by test_gguf.py). + """ + lin = model.linear + weight = lin.weight + if getattr(weight, "ggml_type", None) == "q4_k": + # Q4_K is repacked into bf16 MLX affine qparams (S, Q, B); reconstruct + # exactly what the kernel dequantizes so the oracle isolates kernel + # accumulation (repack precision vs gguf is covered by test_gguf.py). + from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams + + intx = weight.to_intx_unpacked_to_int8_tensor() + gs = int(intx.block_size[-1]) + Q, B = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, 4) + qb = Q.view(torch.uint8) + nibbles = torch.stack([(qb & 0xF).float(), ((qb >> 4) & 0xF).float()], dim=-1) + q_unsigned = nibbles.reshape(intx.qdata.shape[0], -1) + scale = intx.scale.float().repeat_interleave(gs, dim=1) + bias_b = B.float().repeat_interleave(gs, dim=1) + w = scale * q_unsigned + bias_b + else: + w = weight.dequantize(torch.float32) + bias = lin.bias.float() if lin.bias is not None else None + out = torch.nn.functional.linear(x.float(), w, bias) + return [out.to(x.dtype)] + + +_DTYPE_TOL = { + torch.bfloat16: (2e-2, 2e-2), + # The mat-mat (prefill) kernel stores tiles in half precision (as in + # llama.cpp), so fp16 outputs are accurate to ~half precision (~4e-3). + torch.float16: (5e-3, 5e-3), + torch.float32: (1e-4, 1e-4), +} +_DTYPE_TAG = {torch.bfloat16: "bf16", torch.float16: "fp16", torch.float32: "fp32"} + + +def _edge_compile_config(): + from executorch.exir import EdgeCompileConfig + + # The dequantize_gguf custom op isn't a core ATen op; skip IR validity. + return EdgeCompileConfig(_check_ir_validity=False) + + +class GGUFLinearTest(OpTestCase): + name = "gguf_linear" + + def __init__( + self, + M: int = 1, + N: int = 256, + K: int = 256, + dtype: torch.dtype = torch.bfloat16, + bias: bool = True, + ggml_type: str = "q6_k", + ): + self.M = M + self.N = N + self.K = K + self.dtype = dtype + self.bias = bias + self.ggml_type = ggml_type + self.rtol, self.atol = _DTYPE_TOL[dtype] + tag = f"gguf_linear_{ggml_type}_m{M}_n{N}_k{K}_{_DTYPE_TAG[dtype]}" + self.name = tag if bias else tag + "_nobias" + + @classmethod + def get_test_configs(cls) -> List["GGUFLinearTest"]: + cfgs: List["GGUFLinearTest"] = [] + # Decode (mat-vec). + for K in (256, 512, 1024): + for N in (256, 512): + cfgs.append(cls(M=1, N=N, K=K, dtype=torch.bfloat16)) + cfgs.append(cls(M=1, N=256, K=256, dtype=torch.float16)) + cfgs.append(cls(M=1, N=256, K=256, dtype=torch.float32)) + cfgs.append(cls(M=1, N=256, K=256, dtype=torch.bfloat16, bias=False)) + # Prefill (mat-mat). + for M in (8, 64, 128): + cfgs.append(cls(M=M, N=512, K=512, dtype=torch.bfloat16)) + cfgs.append(cls(M=32, N=256, K=256, dtype=torch.float16)) + # Ragged shapes (M and N not multiples of the 32-wide tile / row group). + cfgs.append(cls(M=40, N=300, K=256, dtype=torch.bfloat16)) + cfgs.append(cls(M=1, N=300, K=256, dtype=torch.bfloat16)) + # Real Gemma-4-31B shapes (hidden=5376, ffn=21504) at production N/K. + cfgs.append(cls(M=1, N=4096, K=5376, dtype=torch.bfloat16)) # attn_v + cfgs.append(cls(M=1, N=5376, K=21504, dtype=torch.bfloat16)) # ffn_down + cfgs.append(cls(M=8, N=5376, K=21504, dtype=torch.bfloat16)) # ffn_down prefill + # lm_head: real vocab is 262144, but N is capped so the packed weight + # fits CI-runner GPU buffer limits; the mat-vec N-tiling path is the + # same at any N. + cfgs.append(cls(M=1, N=16384, K=5376, dtype=torch.bfloat16)) # lm_head + # Q4_K -> MLX native 4-bit quantized_matmul (group_size 32). + cfgs.append(cls(M=1, N=512, K=512, dtype=torch.bfloat16, ggml_type="q4_k")) + cfgs.append(cls(M=8, N=512, K=512, dtype=torch.bfloat16, ggml_type="q4_k")) + cfgs.append(cls(M=1, N=5376, K=5376, dtype=torch.bfloat16, ggml_type="q4_k")) + cfgs.append( + cls(M=1, N=512, K=512, dtype=torch.bfloat16, bias=False, ggml_type="q4_k") + ) + return cfgs + + def get_edge_compile_config(self): + return _edge_compile_config() + + def create_model(self) -> nn.Module: + return GGUFLinearModel( + _make_gguf_linear_model( + self.N, self.K, self.dtype, self.bias, self.ggml_type + ) + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + torch.manual_seed(0) + return (torch.randn(self.M, self.K, dtype=self.dtype),) + + def compute_expected_outputs(self, model, test_inputs): + return _fp32_linear_reference(model, test_inputs[0]) + + +class GGUFLinearDynamicTest(OpTestCase): + """Dynamic seqlen: export once with a symbolic M, run with M=1 (decode / + else chain) and M>1 (prefill / then chain) to exercise both IfNode branches. + """ + + name = "gguf_linear_dynamic" + + def __init__( + self, + export_M: int = 4, + test_M: int = 1, + N: int = 512, + K: int = 512, + dtype: torch.dtype = torch.bfloat16, + ): + self.export_M = export_M + self.test_M = test_M + self.N = N + self.K = K + self.dtype = dtype + self.rtol, self.atol = _DTYPE_TOL[dtype] + self.name = ( + f"gguf_linear_dyn_exp{export_M}_test{test_M}_n{N}_k{K}_" + f"{_DTYPE_TAG[dtype]}" + ) + + @classmethod + def get_test_configs(cls) -> List["GGUFLinearDynamicTest"]: + return [ + cls(export_M=4, test_M=1, dtype=torch.bfloat16), # decode / else + cls(export_M=4, test_M=8, dtype=torch.bfloat16), # prefill / then + cls(export_M=4, test_M=4, dtype=torch.bfloat16), # control + cls(export_M=4, test_M=1, dtype=torch.float16), + cls(export_M=4, test_M=40, N=300, K=256, dtype=torch.bfloat16), # ragged + ] + + def get_dynamic_shapes(self): + seq_dim = torch.export.Dim("seq_len", min=1, max=64) + return {"x": {0: seq_dim}} + + def get_edge_compile_config(self): + return _edge_compile_config() + + def create_model(self) -> nn.Module: + # Deterministic weight so export-time and run-time use the same model. + return GGUFLinearModel( + _make_gguf_linear_model(self.N, self.K, self.dtype, bias=True) + ) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + torch.manual_seed(0) + return (torch.randn(self.export_M, self.K, dtype=self.dtype),) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + torch.manual_seed(0) + return (torch.randn(self.test_M, self.K, dtype=self.dtype),) + + def compute_expected_outputs(self, model, test_inputs): + return _fp32_linear_reference(model, test_inputs[0]) + + +def _eager_sanity() -> None: + """Quick CPU check: the subclass linear exports to dequantize_gguf.""" + model = GGUFLinearModel(_make_gguf_linear_model(4, 512, torch.bfloat16, bias=True)) + x = torch.randn(3, 512, dtype=torch.bfloat16) + out = model(x) + print( + f"eager forward finite: {torch.isfinite(out).all().item()}, shape {tuple(out.shape)}" + ) + ep = torch.export.export(model, (x,)).run_decompositions({}) + targets = {str(n.target) for n in ep.graph.nodes if n.op == "call_function"} + assert "torchao.dequantize_gguf.default" in targets, targets + print("export contains torchao.dequantize_gguf: OK") + + +if __name__ == "__main__": # noqa: C901 + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser(description="Test GGUF Q6_K linear lowering") + parser.add_argument( + "action", choices=["generate", "compare", "run", "list", "eager"] + ) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--rebuild", action="store_true") + parser.add_argument("--config", type=str, default=None) + args = parser.parse_args() + + if args.action == "eager": + _eager_sanity() + sys.exit(0) + + if args.rebuild and not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = ( + GGUFLinearTest.get_test_configs() + GGUFLinearDynamicTest.get_test_configs() + ) + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + failed_names: List[str] = [] + + for test in configs: + if args.action == "generate": + pte_path, _, _ = test.generate_test_files(verbose=args.verbose) + print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: {passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) diff --git a/backends/mlx/custom_kernel_ops/test/__init__.py b/backends/mlx/custom_kernel_ops/test/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/model_ops/test_gated_delta_rule.py b/backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py similarity index 98% rename from backends/mlx/model_ops/test_gated_delta_rule.py rename to backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py index 10dceef14b1..0a7e6a687f9 100644 --- a/backends/mlx/model_ops/test_gated_delta_rule.py +++ b/backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py @@ -10,18 +10,18 @@ Usage: # Run all configs: - python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gated_delta_rule run # Run with verbose output: - python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gated_delta_rule run -v # Rebuild C++ runner first: - python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gated_delta_rule run --rebuild """ from typing import List, Tuple -import executorch.backends.mlx.model_ops.gated_delta_rule # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.gated_delta_rule # noqa: F401 import torch import torch.nn as nn diff --git a/backends/mlx/model_ops/test_tq4_compress.py b/backends/mlx/custom_kernel_ops/test/test_tq4_compress.py similarity index 94% rename from backends/mlx/model_ops/test_tq4_compress.py rename to backends/mlx/custom_kernel_ops/test/test_tq4_compress.py index c2aaa13afa7..ba114e67b23 100644 --- a/backends/mlx/model_ops/test_tq4_compress.py +++ b/backends/mlx/custom_kernel_ops/test/test_tq4_compress.py @@ -13,14 +13,14 @@ Usage:: - python -m executorch.backends.mlx.model_ops.test_tq4_compress run - python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v - python -m executorch.backends.mlx.model_ops.test_tq4_compress run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq4_compress run + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq4_compress run -v + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq4_compress run --rebuild """ from typing import List, Tuple -import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.tq4_compress # noqa: F401 import torch import torch.nn as nn diff --git a/backends/mlx/model_ops/test_tq_dequant.py b/backends/mlx/custom_kernel_ops/test/test_tq_dequant.py similarity index 93% rename from backends/mlx/model_ops/test_tq_dequant.py rename to backends/mlx/custom_kernel_ops/test/test_tq_dequant.py index 07d9deb895a..f50fad9b651 100644 --- a/backends/mlx/model_ops/test_tq_dequant.py +++ b/backends/mlx/custom_kernel_ops/test/test_tq_dequant.py @@ -15,14 +15,14 @@ Usage:: - python -m executorch.backends.mlx.model_ops.test_tq_dequant run - python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v - python -m executorch.backends.mlx.model_ops.test_tq_dequant run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_dequant run + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_dequant run -v + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_dequant run --rebuild """ from typing import List, Tuple -import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.tq_dequant # noqa: F401 import torch import torch.nn as nn diff --git a/backends/mlx/model_ops/test_tq_norm.py b/backends/mlx/custom_kernel_ops/test/test_tq_norm.py similarity index 93% rename from backends/mlx/model_ops/test_tq_norm.py rename to backends/mlx/custom_kernel_ops/test/test_tq_norm.py index 35c4491d8ae..4f3b93a945f 100644 --- a/backends/mlx/model_ops/test_tq_norm.py +++ b/backends/mlx/custom_kernel_ops/test/test_tq_norm.py @@ -13,14 +13,14 @@ Usage:: - python -m executorch.backends.mlx.model_ops.test_tq_norm run - python -m executorch.backends.mlx.model_ops.test_tq_norm run -v - python -m executorch.backends.mlx.model_ops.test_tq_norm run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_norm run + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_norm run -v + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_norm run --rebuild """ from typing import List, Tuple -import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.tq_norm # noqa: F401 import torch import torch.nn as nn diff --git a/backends/mlx/model_ops/tq4_compress.py b/backends/mlx/custom_kernel_ops/tq4_compress.py similarity index 98% rename from backends/mlx/model_ops/tq4_compress.py rename to backends/mlx/custom_kernel_ops/tq4_compress.py index f08d47b9a11..f957be379c0 100644 --- a/backends/mlx/model_ops/tq4_compress.py +++ b/backends/mlx/custom_kernel_ops/tq4_compress.py @@ -20,7 +20,7 @@ Usage:: - import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 + import executorch.backends.mlx.custom_kernel_ops.tq4_compress # noqa: F401 packed = torch.ops.mlx.tq4_compress(rotated, boundaries) # rotated: (..., D) float diff --git a/backends/mlx/model_ops/tq_dequant.py b/backends/mlx/custom_kernel_ops/tq_dequant.py similarity index 98% rename from backends/mlx/model_ops/tq_dequant.py rename to backends/mlx/custom_kernel_ops/tq_dequant.py index 28a168e9be0..0c1842712e4 100644 --- a/backends/mlx/model_ops/tq_dequant.py +++ b/backends/mlx/custom_kernel_ops/tq_dequant.py @@ -23,7 +23,7 @@ Usage:: - import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 + import executorch.backends.mlx.custom_kernel_ops.tq_dequant # noqa: F401 out = torch.ops.mlx.tq_dequant(packed, norms, centroids) # packed: (..., D/2) uint8 diff --git a/backends/mlx/model_ops/tq_norm.py b/backends/mlx/custom_kernel_ops/tq_norm.py similarity index 98% rename from backends/mlx/model_ops/tq_norm.py rename to backends/mlx/custom_kernel_ops/tq_norm.py index 7e6a4d657f3..e456c2f6aa4 100644 --- a/backends/mlx/model_ops/tq_norm.py +++ b/backends/mlx/custom_kernel_ops/tq_norm.py @@ -20,7 +20,7 @@ Usage:: - import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 + import executorch.backends.mlx.custom_kernel_ops.tq_norm # noqa: F401 norms = torch.ops.mlx.tq_norm(x) # x: (..., D) bf16 diff --git a/backends/mlx/llm/turboquant_cache.py b/backends/mlx/llm/turboquant_cache.py index 7f2109ba074..b262876c481 100644 --- a/backends/mlx/llm/turboquant_cache.py +++ b/backends/mlx/llm/turboquant_cache.py @@ -25,11 +25,12 @@ from typing import Optional, Tuple +import executorch.backends.mlx.custom_kernel_ops.tq4_compress # noqa: F401 mlx::tq4_compress +import executorch.backends.mlx.custom_kernel_ops.tq_dequant # noqa: F401 mlx::tq_dequant +import executorch.backends.mlx.custom_kernel_ops.tq_norm # noqa: F401 mlx::tq_norm + # Register the MLX custom ops used by this cache. import executorch.backends.mlx.custom_ops # noqa: F401 mlx::custom_sdpa, mlx::kv_cache_update -import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 mlx::tq4_compress -import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 mlx::tq_dequant -import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 mlx::tq_norm import torch diff --git a/backends/mlx/patterns.py b/backends/mlx/patterns.py index 5f74cbea643..dcc4f4d7d30 100644 --- a/backends/mlx/patterns.py +++ b/backends/mlx/patterns.py @@ -21,7 +21,9 @@ import torch from executorch.backends.mlx.builder.op_helpers import ( emit_quantized_biases, + emit_quantized_gather, emit_stop_position, + parse_dequant_int4_node, parse_dequant_node, parse_dequant_nvfp4_node, to_mlx_qparams, @@ -44,7 +46,6 @@ DequantizeNode, IndexCopyNode, IntOrVid, - IntOrVidOrTid, ModIntNode, MultiplyNode, QuantizedMatmulNode, @@ -53,13 +54,40 @@ SliceUpdateNode, SubtractIntNode, SymSizeNode, - TakeNode, TransposeNode, ) from torch.export.exported_program import ExportedProgram from torch.fx.node import Node +def _unpack_int4_to_intx_fields( + qdata_packed: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert ``Int4Tensor`` packed fields to the IntxUnpacked layout for + :func:`to_mlx_qparams`. + + Input is the torchao ``Int4Tensor`` layout: ``qdata_packed`` ``(N, K//2)`` uint8 + (two nibbles/byte, even index -> low nibble, unsigned [0, 15]) and ``scale`` / + ``zero_point`` ``(K // gs, N)`` (zero_point unsigned [0, 15]). + + Returns ``(qdata, scale, zero_point)`` where ``qdata`` is ``(N, K)`` int8 in + [-8, 7], and ``scale`` / ``zero_point`` are ``(N, K // gs)`` (zero_point + centered by -8). ``zero_point`` keeps its original (possibly fractional, e.g. + HQQ) dtype -- it is only used in :func:`to_mlx_qparams`'s float bias math, so + it must not be truncated to int. The affine identity ``scale * (q - z)`` is + preserved. + """ + p = qdata_packed.view(torch.uint8) + low = (p & 0x0F).to(torch.int8) + high = ((p >> 4) & 0x0F).to(torch.int8) + q = torch.stack([low, high], dim=-1).reshape(p.shape[0], -1) - 8 + scale_nk = scale.t().contiguous() + zero_point_nk = zero_point.t().contiguous() - 8 + return q, scale_nk, zero_point_nk + + @REGISTRY.register_pattern(name="INDEX_COPY") class IndexCopyHandler(PatternHandler): """ @@ -600,43 +628,18 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: [x_node, self.scale, self.per_tensor_scale, self.qdata] ) - ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(x)) - - # Gather quantized weights by indices - _, wq_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(qdata_slot), - index=ids_index, - out=P.slot_to_tid(wq_sel), - axis=0, - ) - ) - - # Gather scales by indices - _, sc_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(scales_slot), - index=ids_index, - out=P.slot_to_tid(sc_sel), - axis=0, - ) - ) - - # Dequantize the gathered slices out = P.make_or_get_slot(n) - P.emit( - DequantizeNode( - w=P.slot_to_tid(wq_sel), - scales=P.slot_to_tid(sc_sel), - out=P.slot_to_tid(out), - biases=None, - group_size=16, - bits=4, - mode="nvfp4", - dtype=torch_dtype_to_scalar_type(self.output_dtype), - ) + emit_quantized_gather( + P, + out, + x, + qdata_slot, + scales_slot, + None, + group_size=16, + bits=4, + mode="nvfp4", + out_dtype=self.output_dtype, ) if has_per_tensor_scale: @@ -1060,7 +1063,7 @@ def maybe_create( def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: assert n == self.head - w, x = n.args[0:2] + indices_node = n.args[1] qdata_target, qdata = P.get_placeholder_target_and_tensor(self.qdata) zero_point_target, zero_point = P.get_placeholder_target_and_tensor( @@ -1069,62 +1072,25 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: _, scale = P.get_placeholder_target_and_tensor(self.scale) Q, B = to_mlx_qparams(qdata, scale, zero_point, self.bits) - out_scalar_type = torch_dtype_to_scalar_type(self.out_dtype) - w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) - x, scale_slot = P.slot_map([x, self.scale]) + indices_slot, scale_slot = P.slot_map([indices_node, self.scale]) biases = emit_quantized_biases( P, zero_point_target, scale, zero_point, self.bits, B, scale_slot ) - ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(x)) - - # Gather quantized weights by ids - _, wq_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(w), - index=ids_index, - out=P.slot_to_tid(wq_sel), - axis=0, - ) - ) - - # Gather scales by ids - _, sc_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(scale_slot), - index=ids_index, - out=P.slot_to_tid(sc_sel), - axis=0, - ) - ) - - # Gather biases by ids - _, b_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(biases), - index=ids_index, - out=P.slot_to_tid(b_sel), - axis=0, - ) - ) - # Dequantize the gathered slices out = P.make_or_get_slot(n) - P.emit( - DequantizeNode( - w=P.slot_to_tid(wq_sel), - scales=P.slot_to_tid(sc_sel), - out=P.slot_to_tid(out), - biases=P.slot_to_tid(b_sel), - group_size=self.group_size, - bits=self.bits, - mode="affine", - dtype=out_scalar_type, - ) + emit_quantized_gather( + P, + out, + indices_slot, + w, + scale_slot, + biases, + group_size=self.group_size, + bits=self.bits, + mode="affine", + out_dtype=self.out_dtype, ) return out @@ -1228,3 +1194,174 @@ def __call__(self, P, n): ) return out + + +@REGISTRY.register_pattern(name="INT4_QUANTIZED_LINEAR") +class Int4QuantizedLinearHandler(PatternHandler): + """Fuse dequantize_int4_tensor + linear into QuantizedMatmulNode(mode="affine"). + + Matches:: + + linear(x, dequantize_int4_tensor(qdata, scale, zero_point, group_size), bias) + + The nibble-packed Int4 weight is unpacked and repacked into MLX 4-bit qparams + at export time. + """ + + def __init__(self, head, body, qdata, scale, zero_point, group_size, out_dtype): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + self.out_dtype = out_dtype + + _MIN_FUSED_GROUP_SIZE = 32 + + @staticmethod + def _allow_non_fused() -> bool: + return os.environ.get("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", "0") == "1" + + @classmethod + def maybe_create(cls, ep, head): + if not match_target(head, torch.ops.aten.linear.default): + return None + if len(head.args) < 2 or not isinstance(head.args[1], Node): + return None + dequant = head.args[1] + if not has_single_user(dequant): + return None + parsed = parse_dequant_int4_node(dequant) + if parsed is None: + return None + qdata, scale, zero_point, group_size, out_dtype = parsed + return cls(head, [dequant], qdata, scale, zero_point, group_size, out_dtype) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + x_node = n.args[0] + b_node = n.args[2] if len(n.args) > 2 else None + + qdata_target, qdata_packed = P.get_placeholder_target_and_tensor(self.qdata) + zp_target, zero_point = P.get_placeholder_target_and_tensor(self.zero_point) + _, scale = P.get_placeholder_target_and_tensor(self.scale) + + q, scale_nk, zp = _unpack_int4_to_intx_fields(qdata_packed, scale, zero_point) + Q, B = to_mlx_qparams(q, scale_nk, zp, 4) + + w = P.make_or_get_constant(f"{qdata_target}_int4_to_packed", Q) + scale_slot = P.make_or_get_constant(f"{qdata_target}_int4_scales", scale_nk) + biases = emit_quantized_biases(P, zp_target, scale_nk, zp, 4, B, scale_slot) + + x_slot, b_slot = P.slot_map([x_node, b_node]) + out_dtype = ( + x_node.meta["val"].dtype if self.out_dtype is None else self.out_dtype + ) + needs_cast = out_dtype != x_node.meta["val"].dtype + + if self.group_size < self._MIN_FUSED_GROUP_SIZE and not self._allow_non_fused(): + raise ValueError( + f"Int4 quantized linear with group_size={self.group_size} requires " + f"the non-fused path; set ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS=1." + ) + + out = P.make_or_get_slot(n) + P.emit( + QuantizedMatmulNode( + x=P.slot_to_tid(x_slot), + w=P.slot_to_tid(w), + scales=P.slot_to_tid(scale_slot), + biases=P.slot_to_tid(biases), + out=P.slot_to_tid(out), + group_size=self.group_size, + bits=4, + mode="affine", + transpose=True, + ) + ) + + if b_node is not None: + P.emit( + AddNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(b_slot), + out=P.slot_to_tid(out), + ) + ) + + if needs_cast: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(out_dtype), + ) + ) + + return out + + +@REGISTRY.register_pattern(name="INT4_QUANTIZED_EMBEDDING") +class Int4QuantizedEmbeddingHandler(PatternHandler): + """Fuse dequantize_int4_tensor + embedding into gather + DequantizeNode(affine). + + Matches:: + + embedding(dequantize_int4_tensor(qdata, scale, zero_point, group_size), ids) + """ + + def __init__(self, head, body, qdata, scale, zero_point, group_size, out_dtype): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + self.out_dtype = out_dtype + + @classmethod + def maybe_create(cls, ep, head): + if not match_target(head, torch.ops.aten.embedding.default): + return None + if len(head.args) < 2 or not isinstance(head.args[0], Node): + return None + dequant = head.args[0] + if not has_single_user(dequant): + return None + parsed = parse_dequant_int4_node(dequant) + if parsed is None: + return None + qdata, scale, zero_point, group_size, out_dtype = parsed + return cls(head, [dequant], qdata, scale, zero_point, group_size, out_dtype) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + indices_node = n.args[1] + + qdata_target, qdata_packed = P.get_placeholder_target_and_tensor(self.qdata) + zp_target, zero_point = P.get_placeholder_target_and_tensor(self.zero_point) + _, scale = P.get_placeholder_target_and_tensor(self.scale) + + q, scale_nk, zp = _unpack_int4_to_intx_fields(qdata_packed, scale, zero_point) + Q, B = to_mlx_qparams(q, scale_nk, zp, 4) + + w = P.make_or_get_constant(f"{qdata_target}_int4_to_packed", Q) + scale_slot = P.make_or_get_constant(f"{qdata_target}_int4_scales", scale_nk) + biases = emit_quantized_biases(P, zp_target, scale_nk, zp, 4, B, scale_slot) + + (indices_slot,) = P.slot_map([indices_node]) + out_dtype = scale.dtype if self.out_dtype is None else self.out_dtype + + out = P.make_or_get_slot(n) + emit_quantized_gather( + P, + out, + indices_slot, + w, + scale_slot, + biases, + group_size=self.group_size, + bits=4, + mode="affine", + out_dtype=out_dtype, + ) + return out diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 34fd8815ba8..8563ff339a7 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -990,8 +990,8 @@ inline void exec_metal_kernel( n.name, n.input_names, n.output_names, - n.source, - n.header, + n.source ? *n.source : std::string{}, + n.header ? *n.header : std::string{}, n.ensure_row_contiguous, n.atomic_outputs); @@ -1837,6 +1837,8 @@ class Interpreter { st.begin_op(idx, op_name(instr.op)); if (instr.op == OpCode::SCAN) { exec_scan(prog, std::get(instr.node), st, stream); + } else if (instr.op == OpCode::IF) { + exec_if(prog, std::get(instr.node), st, stream); } else { dispatch(instr, st, stream); } @@ -1846,6 +1848,20 @@ class Interpreter { } private: + void exec_if( + const MLXProgram& prog, + const IfNode& n, + ExecutionState& st, + StreamOrDevice s) const { + // Select one branch at runtime based on the integer condition. + // Nonzero -> then_chain, zero -> else_chain. The selected chain's + // instructions write the output slot(s) directly. + const int64_t cond = resolve_int(n.cond, st); + const uint32_t chain_idx = + (cond != 0) ? n.then_chain_idx : n.else_chain_idx; + run_chain(prog, chain_idx, st, s); + } + void exec_scan( const MLXProgram& prog, const ScanNode& n, diff --git a/backends/mlx/serialization/MLXLoader.cpp.tmpl b/backends/mlx/serialization/MLXLoader.cpp.tmpl index aa4716d7a4a..7017988d271 100644 --- a/backends/mlx/serialization/MLXLoader.cpp.tmpl +++ b/backends/mlx/serialization/MLXLoader.cpp.tmpl @@ -62,7 +62,8 @@ std::vector to_vector(const flatbuffers::Vector* fb_vec) { // load_instruction - AUTO-GENERATED switch statement // ============================================================================= -Instruction load_instruction(const mlx_delegate::Instruction* fb_instr) { +Instruction load_instruction( + const mlx_delegate::Instruction* fb_instr, StringPool& strpool) { Instruction instr; if (!fb_instr || !fb_instr->op()) { @@ -142,6 +143,10 @@ MLXProgram load_program(const void* data, size_t size) { check_collection_size(program.num_tensors(), "num_tensors()"); check_collection_size(program.num_values, "num_values"); + // Pool shared across all chains so identical kernel source/header blobs are + // interned once for the whole program. + StringPool strpool; + if (fb_graph->instruction_chains()) { check_collection_size(fb_graph->instruction_chains()->size(), "instruction_chains"); program.instruction_chains.reserve(fb_graph->instruction_chains()->size()); @@ -152,7 +157,7 @@ MLXProgram load_program(const void* data, size_t size) { check_collection_size(fb_chain->instructions()->size(), "instructions in chain"); chain.reserve(fb_chain->instructions()->size()); for (size_t i = 0; i < fb_chain->instructions()->size(); ++i) { - chain.push_back(load_instruction(fb_chain->instructions()->Get(static_cast(i)))); + chain.push_back(load_instruction(fb_chain->instructions()->Get(static_cast(i)), strpool)); } } program.instruction_chains.push_back(std::move(chain)); diff --git a/backends/mlx/serialization/MLXLoader.h.tmpl b/backends/mlx/serialization/MLXLoader.h.tmpl index 0930d5e00e1..8bee2c23bc8 100644 --- a/backends/mlx/serialization/MLXLoader.h.tmpl +++ b/backends/mlx/serialization/MLXLoader.h.tmpl @@ -4,9 +4,11 @@ #include #include +#include #include #include #include +#include #include #include @@ -330,8 +332,27 @@ inline SlotVariant convert_slot_variant(const mlx_delegate::SlotVariant* fb) { return SlotVariant{fb->idx(), convert_slot_type(fb->slot_type())}; } +// Interns FlatBuffer strings by pointer so identical kernel source/header +// blobs (deduplicated to a single offset by the serializer) share one +// std::string in memory. Buffers written without string sharing simply get +// one entry per node — correct, just not deduplicated. +struct StringPool { + std::unordered_map> map; + std::shared_ptr intern(const flatbuffers::String* s) { + if (!s) { + return nullptr; + } + auto& slot = map[static_cast(s)]; + if (!slot) { + slot = std::make_shared(s->str()); + } + return slot; + } +}; + // Load an instruction from FlatBuffer -Instruction load_instruction(const mlx_delegate::Instruction* fb_instr); +Instruction load_instruction( + const mlx_delegate::Instruction* fb_instr, StringPool& strpool); // Load the full MLXProgram from FlatBuffer data MLXProgram load_program(const void* data, size_t size); diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py index db3d4cd2d49..fd0b5b672b0 100755 --- a/backends/mlx/serialization/generate.py +++ b/backends/mlx/serialization/generate.py @@ -627,6 +627,16 @@ def generate_python_serializers(schema: FBSSchema) -> str: " return builder.EndVector()", "", "", + "def _shared_string(builder: flatbuffers.Builder, s):", + ' """CreateString with per-buffer dedup so identical strings share one offset."""', + " if s is None:", + " return None", + " # flatbuffers' Builder dedups identical strings via its built-in", + " # sharedStrings cache; fall back to CreateString on old flatbuffers.", + ' create = getattr(builder, "CreateSharedString", None) or builder.CreateString', + " return create(s)", + "", + "", "class GeneratedOpBuilders:", ' """Mixin class with auto-generated op builder methods."""', "", @@ -714,7 +724,7 @@ def generate_python_serializers(schema: FBSSchema) -> str: " self, builder: flatbuffers.Builder, vec: List[str]", " ) -> int:", ' """Pre-build a vector of strings (offsets must be created before table Start)."""', - " offsets = [builder.CreateString(s) for s in vec]", + " offsets = [_shared_string(builder, s) for s in vec]", " builder.StartVector(4, len(offsets), 4)", " for off in reversed(offsets):", " builder.PrependUOffsetTRelative(off)", @@ -800,12 +810,12 @@ def _generate_op_builder_method(table: FBSTable) -> str: } _PY_PREBUILD_OFFSET = { - "str": "builder.CreateString(op.{name})", + "str": "_shared_string(builder, op.{name})", "int_or_vid": "self._build_int_or_vid(builder, op.{name})", "float_or_vid": "self._build_float_or_vid(builder, op.{name})", "vid_or_tid": "self._build_vid_or_tid(builder, op.{name})", "int_or_vid_or_tid": "self._build_int_or_vid_or_tid(builder, op.{name})", - "optional_str": "builder.CreateString(op.{name}) if op.{name} is not None else None", + "optional_str": "_shared_string(builder, op.{name})", } @@ -996,6 +1006,19 @@ def generate_cpp_loader_h(schema: FBSSchema) -> str: return header + result +def _is_interned_str(table, field_name) -> bool: + """Whether a string field should be loaded as an interned shared_ptr. + + Only large, frequently-duplicated kernel blobs (MetalKernelNode source/ + header) are interned so identical text shares one std::string at runtime. + """ + return ( + table is not None + and getattr(table, "name", None) == "MetalKernelNode" + and field_name in ("source", "header") + ) + + def _fbs_type_to_cpp( fbs_type: str, required: bool, @@ -1023,6 +1046,10 @@ def _fbs_type_to_cpp( cpp_type = FBS_TO_CPP.get(fbs_type, fbs_type) + # Interned strings (deduped + shared at load time) use a shared_ptr handle. + if _is_interned_str(table, fld.name if fld is not None else None): + return "std::shared_ptr" + # Handle optional types if not required: if fbs_type == "Tid": @@ -1113,7 +1140,7 @@ def _generate_loader_case(table: FBSTable) -> List[str]: fb_field_name = fld.name kind = _get_field_kind(fld, table) - load_lines = _emit_cpp_load(kind, fld.name, fb_field_name) + load_lines = _emit_cpp_load(kind, fld.name, fb_field_name, table) if load_lines is None: raise ValueError( f"Unhandled field kind '{kind}' for field '{fld.name}' in table '{table.name}'. " @@ -1145,8 +1172,13 @@ def _generate_loader_case(table: FBSTable) -> List[str]: } -def _emit_cpp_load(kind: str, name: str, fb_name: str) -> "List[str] | None": +def _emit_cpp_load( + kind: str, name: str, fb_name: str, table=None +) -> "List[str] | None": """Emit C++ load lines for a field kind, or None if kind is unrecognized.""" + # Interned string fields share one std::string via the load-time pool. + if _is_interned_str(table, name) and kind in ("str", "optional_str"): + return [f" node.{name} = strpool.intern(fb->{fb_name}());"] # Required struct / compound via converter if kind in _CPP_CONVERTER: conv = _CPP_CONVERTER[kind] diff --git a/backends/mlx/serialization/mlx_graph_serialize.py b/backends/mlx/serialization/mlx_graph_serialize.py index db5acc9048f..26c562dd7e8 100644 --- a/backends/mlx/serialization/mlx_graph_serialize.py +++ b/backends/mlx/serialization/mlx_graph_serialize.py @@ -31,6 +31,7 @@ # Import auto-generated serializers from executorch.backends.mlx.serialization._generated_serializers import ( + _shared_string, GeneratedOpBuilders, ) from executorch.backends.mlx.serialization.mlx_graph_schema import ( # noqa: F401 @@ -85,7 +86,7 @@ def _build_int_or_vid(builder: flatbuffers.Builder, iov: IntOrVid) -> int: def _build_string(builder: flatbuffers.Builder, s: str) -> int: - return builder.CreateString(s) + return _shared_string(builder, s) def _build_int_vector(builder: flatbuffers.Builder, vec: List[int]) -> int: @@ -188,7 +189,7 @@ def _build_flatbuffer(self) -> bytes: tensor_meta_vec = self._build_offset_vector(builder, tensor_meta_offsets) # 5. Build version string (must be created before the table that uses it) - version_off = builder.CreateString(self.graph.version) + version_off = _shared_string(builder, self.graph.version) # 6. Build the root MLXGraph table from executorch.backends.mlx.serialization._generated.mlx_delegate import ( @@ -280,7 +281,7 @@ def _build_slot_variant( return FBSlotVariantModule.End(builder) def _build_named_slot(self, builder: flatbuffers.Builder, ns: NamedSlot) -> int: - name_off = builder.CreateString(ns.name) + name_off = _shared_string(builder, ns.name) slot_off = self._build_slot_variant(builder, ns.slot) from executorch.backends.mlx.serialization._generated.mlx_delegate import ( diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 3c02e5785ce..42c53e5172b 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -976,6 +976,15 @@ table ScanNode { scan_axis: int32 = 1; // dimension to iterate over } +// Runtime conditional: select one of two instruction chains based on a runtime +// integer condition. The selected branch writes its output slot(s) directly, so +// no `outputs` field is needed (unlike ScanNode, which post-processes/stacks). +table IfNode { + cond: IntOrVid (required); // nonzero -> then_chain, zero -> else_chain + then_chain_idx: uint32; // index into MLXGraph.instruction_chains + else_chain_idx: uint32; // index into MLXGraph.instruction_chains +} + // Custom Metal kernel execution via mlx::core::fast::metal_kernel(). // Two-phase API: // 1. Factory: metal_kernel(name, input_names, output_names, source, header, @@ -1151,7 +1160,8 @@ union OpNode { RollNode, BitwiseAndNode, BitwiseOrNode, - BitwiseXorNode + BitwiseXorNode, + IfNode // BC: Add new op nodes here (append only) } diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 9d07af84268..6ba17cccda7 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -7402,3 +7402,158 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: self.batch_size, self.seq_len, self.in_features, dtype=self.dtype ) return (x,) + + +def _make_int4_quantized_weight(weight: torch.Tensor, group_size: int) -> torch.Tensor: + """Groupwise affine 4-bit quantize a ``(N, K)`` weight into an + ``ExportableInt4Tensor`` (torchao ``Int4Tensor`` packed layout).""" + from executorch.extension.llm.export.int4 import ExportableInt4Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + N, K = weight.shape + dtype = weight.dtype + w = weight.float().reshape(N, K // group_size, group_size) + wmin = w.amin(dim=-1) + wmax = w.amax(dim=-1) + scale = ((wmax - wmin) / 15.0).clamp(min=1e-8) + # Fractional zero-point (HQQ-style), exercises the float zero_point repack path. + zero = (-wmin / scale).clamp(0, 15) + q = torch.round(w / scale.unsqueeze(-1) + zero.unsqueeze(-1)).clamp(0, 15) + q = q.reshape(N, K).to(torch.uint8) + # Two nibbles/byte: even index -> low nibble. + packed = (q[:, 0::2] | (q[:, 1::2] << 4)).to(torch.uint8) + it = Int4Tensor( + qdata=packed, + scale=scale.t().contiguous().to(dtype), + zero_point=zero.t().contiguous().to(dtype), + block_size=[1, group_size], + shape=torch.Size([N, K]), + ) + return ExportableInt4Tensor.from_int4_tensor(it) + + +class Int4QuantizedLinearModel(nn.Module): + """Linear layer whose weight is an ``ExportableInt4Tensor``.""" + + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +@register_test +class Int4QuantizedLinearTest(OpTestCase): + """ExportableInt4Tensor nn.Linear -> MLX 4-bit affine quantized matmul.""" + + name = "int4_quantized_linear" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + in_features: int = 64, + out_features: int = 128, + batch_size: int = 2, + seq_len: int = 16, + bias: bool = True, + group_size: int = 32, + dtype: torch.dtype = torch.bfloat16, + ): + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.seq_len = seq_len + self.bias = bias + self.group_size = group_size + self.dtype = dtype + + parts = ["int4_quantized_linear", f"g{group_size}"] + if not bias: + parts.append("no_bias") + if dtype != torch.bfloat16: + parts.append(str(dtype).split(".")[-1]) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["Int4QuantizedLinearTest"]: + return [ + cls(), + cls(bias=False), + cls(group_size=64), + cls(dtype=torch.float32), + ] + + def get_edge_compile_config(self): + from executorch.exir import EdgeCompileConfig + + return EdgeCompileConfig(_check_ir_validity=False) + + def create_model(self) -> nn.Module: + model = Int4QuantizedLinearModel( + self.in_features, self.out_features, bias=self.bias + ).to(self.dtype) + model.linear.weight = nn.Parameter( + _make_int4_quantized_weight(model.linear.weight.data, self.group_size), + requires_grad=False, + ) + return model + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, self.seq_len, self.in_features, dtype=self.dtype + ) + return (x,) + + +@register_test +class Int4QuantizedEmbeddingTest(OpTestCase): + """ExportableInt4Tensor nn.Embedding -> MLX 4-bit affine quantized gather.""" + + name = "int4_quantized_embedding" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + num_embeddings: int = 1000, + embedding_dim: int = 128, + batch_size: int = 2, + seq_len: int = 16, + group_size: int = 32, + dtype: torch.dtype = torch.bfloat16, + ): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.batch_size = batch_size + self.seq_len = seq_len + self.group_size = group_size + self.dtype = dtype + self.name = f"int4_quantized_embedding_g{group_size}" + + @classmethod + def get_test_configs(cls) -> List["Int4QuantizedEmbeddingTest"]: + return [ + cls(), + cls(group_size=64), + cls(group_size=128), + ] + + def get_edge_compile_config(self): + from executorch.exir import EdgeCompileConfig + + return EdgeCompileConfig(_check_ir_validity=False) + + def create_model(self) -> nn.Module: + model = EmbeddingModel(self.num_embeddings, self.embedding_dim) + model = model.to(self.dtype) + model.embedding.weight = nn.Parameter( + _make_int4_quantized_weight(model.embedding.weight.data, self.group_size), + requires_grad=False, + ) + return model + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randint(0, self.num_embeddings, (self.batch_size, self.seq_len)) + return (x,) diff --git a/backends/mlx/test/test_serialization_dedup.py b/backends/mlx/test/test_serialization_dedup.py new file mode 100644 index 00000000000..e28e4613384 --- /dev/null +++ b/backends/mlx/test/test_serialization_dedup.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Serializer string-dedup regression test. + +MetalKernelNode ``source``/``header`` blobs are large and repeated once per +layer. The serializer routes every string through ``_shared_string`` so +identical text is written into the FlatBuffer exactly once (multiple fields +share a single offset). The loader then interns those shared offsets into one +``std::shared_ptr`` per unique blob, so this dedup also +shrinks runtime memory for newly-produced ``.pte`` files. + +This test pins the serializer half of that behavior. +""" + +import unittest + +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + Instruction, + InstructionChain, + IntOrVid, + MetalKernelNode, + MLXGraph, + Tid, +) +from executorch.backends.mlx.serialization.mlx_graph_serialize import ( + serialize_mlx_graph, +) + + +def _graph(nodes): + chain = InstructionChain(instructions=[Instruction(op=n) for n in nodes]) + return MLXGraph( + instruction_chains=[chain], + version="test", + input_map=[], + output_map=[], + mutable_buffer_map=[], + named_slots=[], + tensor_meta=[], + ) + + +def _kernel(source, header=None): + return MetalKernelNode( + name="gguf_q6k_matmul", + source=source, + inputs=[Tid(0)], + outputs=[Tid(1)], + grid=[IntOrVid(literal=1)], + threadgroup=[IntOrVid(literal=1)], + header=header, + input_names=["x"], + output_names=["out"], + ) + + +class TestSerializationStringDedup(unittest.TestCase): + def test_identical_source_header_written_once(self): + source = "KERNEL_SOURCE_MARKER_" + "x" * 2000 + header = "KERNEL_HEADER_MARKER_" + "y" * 2000 + + nodes = [_kernel(source, header) for _ in range(5)] + buf = serialize_mlx_graph(_graph(nodes)) + + self.assertEqual(buf.count(source.encode()), 1) + self.assertEqual(buf.count(header.encode()), 1) + + def test_distinct_sources_not_merged(self): + base = "KERNEL_SOURCE_MARKER_" + "x" * 2000 + nodes = [_kernel(base + str(i)) for i in range(3)] + buf = serialize_mlx_graph(_graph(nodes)) + + # Each distinct source must still appear (the common prefix appears once + # per distinct string since the suffixes differ). + self.assertEqual(buf.count(base.encode()), 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py index 5dbc35b824d..1a964bea935 100644 --- a/backends/mlx/test/test_utils.py +++ b/backends/mlx/test/test_utils.py @@ -883,6 +883,16 @@ def get_test_dir(self) -> Path: test_dir.mkdir(parents=True, exist_ok=True) return test_dir + def compute_expected_outputs(self, model, test_inputs): + """Reference outputs the device result is compared against. + + Defaults to the eager ``model`` forward. Override to supply a + higher-precision reference -- e.g. fp32 accumulation matching a kernel + that accumulates in fp32, so bf16 reference noise doesn't dominate the + comparison. + """ + return model(*test_inputs) + def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: """ Generate .pte, input.bin, and expected_output.bin files. @@ -915,7 +925,7 @@ def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: with torch.no_grad(): if isinstance(test_inputs, torch.Tensor): test_inputs = (test_inputs,) - expected_outputs = model(*test_inputs) + expected_outputs = self.compute_expected_outputs(model, test_inputs) if isinstance(expected_outputs, torch.Tensor): expected_outputs = [expected_outputs] else: diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index ed3dcdba9c3..64e55319490 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -306,6 +306,12 @@ def _export_mlx( """ import gc + # Register the GGUF dequant op + MLX GGUF pattern handlers so quantized GGUF + # weights lower to the fused Q6_K kernels / Q4_K quantized matmul. + import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 + import executorch.extension.llm.export.gguf # noqa: F401 + import executorch.extension.llm.export.int4 # noqa: F401 + from executorch.backends.mlx import MLXPartitioner from executorch.backends.mlx.passes import get_default_passes @@ -471,18 +477,13 @@ def main() -> None: backend=args.backend, ) - if args.gguf and args.backend == "mlx": - os.environ["ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS"] = "1" - try: - export_and_lower( - model, - config, - args.output_dir, - backend=args.backend, - use_turboquant=args.turboquant, - ) - finally: - os.environ.pop("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", None) + export_and_lower( + model, + config, + args.output_dir, + backend=args.backend, + use_turboquant=args.turboquant, + ) if __name__ == "__main__": diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 35dddb5a0dc..5d7c5ec540d 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -6,9 +6,19 @@ """Load a GGUF file into a Gemma 4 31B model. -Streams tensors one at a time via ``iter_gguf_tensors`` for low peak -memory, remaps GGUF names to model FQNs, handles tied embed/lm_head, -and packs for the target backend. +Streams tensors one at a time via the shared loader in +``extension/llm/export/gguf.py`` (each quantized weight arrives as an +``ExportableGGUFTensor`` wrapping the raw GGUF blob), remaps GGUF names to model +FQNs, handles the tied embed/lm_head, and converts each weight for the target +backend: + +* **MLX**: every quantized weight stays an ``ExportableGGUFTensor`` and is lowered + by the MLX GGUF pattern (Q6_K custom kernels, Q4_K native affine ops) for both + linear and embedding. ``embed_tokens`` and ``lm_head`` stay tied -- they share + the one quantized tensor. +* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``IntxUnpackedToInt8Tensor``; + ``lm_head`` keeps the quantized tensor but the token embedding is dequantized to + bf16 (``Int4Tensor`` can't gather), so they are untied. Usage: model, config = load_gguf_model("model.gguf", backend="cuda") @@ -65,24 +75,6 @@ def gguf_to_model_key(gguf_key: str) -> Optional[str]: return None -def _resolve_tied_lm_head(model, embed_quant, packers): - """Handle tied embed/lm_head after streaming all tensors.""" - from executorch.examples.models.gemma4_31b.quant import pack_one - - lm_head = getattr(model.lm_head, "weight", None) - if lm_head is None or lm_head.device.type != "meta": - return - if embed_quant is not None: - pack_one(model, "lm_head.weight", embed_quant, packers) - else: - pack_one( - model, - "lm_head.weight", - model.embed_tokens.weight.data.clone(), - packers, - ) - - def _validate_no_meta(model): """Ensure all parameters have been loaded.""" for fqn, p in model.named_parameters(): @@ -95,28 +87,57 @@ def _validate_no_meta(model): p.requires_grad_(False) +def _convert_weight(model, model_key: str, gtensor, backend: str): + """Convert an ``ExportableGGUFTensor`` to the per-backend module weight.""" + if backend == "mlx": + return gtensor + # CUDA: native torchao quantized tensors. + if gtensor.ggml_type == "q4_k": + return gtensor.to_int4_tensor() + return gtensor.to_intx_unpacked_to_int8_tensor() + + +def _resolve_tied_lm_head(model, lm_head_weight, packers): + """Assign a tied lm_head (GGUF ties it to the token embedding).""" + from executorch.examples.models.gemma4_31b.quant import pack_one + + lm_head = getattr(model.lm_head, "weight", None) + if lm_head is None or lm_head.device.type != "meta": + return + if lm_head_weight is not None: + pack_one(model, "lm_head.weight", lm_head_weight, packers) + else: + pack_one( + model, "lm_head.weight", model.embed_tokens.weight.data.clone(), packers + ) + + def load_gguf_model( gguf_path: str, max_seq_len: int = 4096, backend: str = "cuda", + config=None, ) -> tuple: - """Load a GGUF file, remap keys, and pack for the target backend. + """Load a GGUF file, remap keys, and convert weights for the target backend. - Streams tensors one at a time for low peak memory. + Streams tensors one at a time for low peak memory. GGUF ties ``embed_tokens`` + and ``lm_head``: on MLX they stay tied (one shared quantized tensor); on CUDA + they are untied so the embedding can be dequantized for the gather while + ``lm_head`` keeps its quantization. See the module docstring for the + per-backend conversion details. - GGUF ties ``embed_tokens`` and ``lm_head`` into a single Q4_K tensor. - We untie them so ``lm_head`` keeps the original Q4_K quantization. - On CUDA, the embedding is dequantized to bf16 because ``Int4Tensor`` - does not support the gather op that ``nn.Embedding`` requires. On - MLX, the embedding stays quantized — ``QuantizedEmbeddingHandler`` - handles quantized gather natively. + ``config`` defaults to the full Gemma 4 31B config; pass a smaller + ``Gemma4_31BConfig`` (e.g. in tests) to load a GGUF for a tiny model. Returns ``(model, config)``. """ - from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig + from executorch.examples.models.gemma4_31b.model import ( + Gemma4_31B, + Gemma4_31BConfig, + materialize_runtime_buffers, + ) from executorch.examples.models.gemma4_31b.quant import dequantize_weight, pack_one - from executorch.examples.models.gemma4_31b.quant.gguf import iter_gguf_tensors - from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + from executorch.extension.llm.export.gguf import ExportableGGUFTensor, iter_gguf if backend == "cuda": from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS @@ -129,37 +150,46 @@ def load_gguf_model( else: raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda', 'mlx'.") - config = Gemma4_31BConfig(max_seq_len=max_seq_len) + if config is None: + config = Gemma4_31BConfig(max_seq_len=max_seq_len) print("Building model on meta device...") with torch.device("meta"): model = Gemma4_31B(config) - embed_quant = None + lm_head_weight = None # weight reused for a tied lm_head n_processed = 0 print(f"Streaming GGUF from {gguf_path}...") - for gguf_name, result in iter_gguf_tensors(gguf_path): + for gguf_name, value in iter_gguf(gguf_path): model_key = gguf_to_model_key(gguf_name) if model_key is None: continue - if type(result) is torch.Tensor and result.dtype == torch.float32: - result = result.to(torch.bfloat16) - - if model_key == "embed_tokens.weight" and isinstance(result, Int4Tensor): - embed_quant = result - if backend == "cuda": - result = dequantize_weight(result, torch.bfloat16) + if isinstance(value, ExportableGGUFTensor): + weight = _convert_weight(model, model_key, value, backend) + if model_key == "embed_tokens.weight": + # Tied lm_head reuses the embedding weight: MLX wants the raw + # ExportableGGUFTensor (linear pattern), CUDA the quant tensor. + lm_head_weight = value if backend == "mlx" else weight + if backend == "cuda": + weight = dequantize_weight(weight, torch.bfloat16) + value = weight + elif value.dtype == torch.float32: + value = value.to(torch.bfloat16) - pack_one(model, model_key, result, packers) + pack_one(model, model_key, value, packers) n_processed += 1 if n_processed % 100 == 0: print(f" Processed {n_processed} tensors...") - _resolve_tied_lm_head(model, embed_quant, packers) - del embed_quant + _resolve_tied_lm_head(model, lm_head_weight, packers) + + # Fill RoPE tables / KV caches / scalar constants (left on meta by the + # streaming load), matching load_prequantized_model so the CUDA and eager + # forward paths get bf16 runtime buffers instead of float32 defaults. + materialize_runtime_buffers(model, dtype=torch.bfloat16) _validate_no_meta(model) model.eval() diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md index 13207bdbb06..32f407c6b40 100644 --- a/examples/models/gemma4_31b/model.md +++ b/examples/models/gemma4_31b/model.md @@ -154,8 +154,10 @@ Modules in `quant/`: packers dispatch by module type (`nn.Linear`, `nn.Embedding`). CUDA passes Int4Tensor through (dispatch handled by `int4_dispatch.py`); MLX converts Int4Tensor → IntxUnpackedToInt8Tensor and regroups per-axis embeddings. -- **GGUF** (`gguf.py`): `unpack_gguf_tensor` / `iter_gguf_tensors` for - loading community-quantized GGUF files (Q4_K, Q6_K). +- **GGUF**: community-quantized GGUF files (Q4_K, Q6_K) are loaded by the + shared, backend-agnostic `extension/llm/export/gguf.py` (`load_gguf` / + `iter_gguf` → `ExportableGGUFTensor`); `gguf_loader.py` remaps GGUF names to + model FQNs and picks the per-backend weight representation. The quantize-once flow: diff --git a/examples/models/gemma4_31b/quant/README.md b/examples/models/gemma4_31b/quant/README.md index 92ddbf97243..8906a0faede 100644 --- a/examples/models/gemma4_31b/quant/README.md +++ b/examples/models/gemma4_31b/quant/README.md @@ -11,7 +11,9 @@ Quantization framework: **recipe → quantize → pack**. | `pack.py` | **Packing dispatch** — `pack_model` (bulk) and `pack_one` (streaming) | — | | `pack_cuda.py` | **CUDA packing** — passes Int4Tensor/IntxUnpacked through for CUDA dispatch | pack | | `pack_mlx.py` | **MLX packing** — converts Int4Tensor → IntxUnpacked, regroups per-axis embeddings | pack | -| `gguf.py` | **GGUF import** — unpacks Q4_K/Q6_K blocks to torchao subclasses | torchao | + +GGUF import (unpacking Q4_K/Q6_K blocks) now lives in the shared +`extension/llm/export/gguf.py`. ## Data flow @@ -49,4 +51,4 @@ The format is compatible with torchao's `save_pretrained` / `load_pretrained`. ## TODO - `pack_metal.py` — Metal backend packer. -- `gguf.py` — extend with Q5_K, Q8_0 GGUF quant types. +- GGUF quant types (Q5_K, Q8_0): extend `extension/llm/export/gguf.py`. diff --git a/examples/models/gemma4_31b/quant/gguf.py b/examples/models/gemma4_31b/quant/gguf.py deleted file mode 100644 index 78c3aa3d8f9..00000000000 --- a/examples/models/gemma4_31b/quant/gguf.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Unpack GGUF quantized tensors to torchao tensor subclasses. - -Supports Q4_K, Q6_K, F32, and F16 tensor types. Two public APIs: - - - ``unpack_gguf_tensor`` — convert a single tensor - - ``iter_gguf_tensors`` — stream all tensors from a file (low peak memory) - -Model-agnostic. For Gemma 4 31B key mapping and model loading, see -``gguf_loader.py``. -""" - -from collections.abc import Iterator - -import torch - -QK_K = 256 # super-block size for k-quants -Q4_K_GROUPS = 8 # sub-blocks per Q4_K super-block -Q4_K_GROUP_SIZE = QK_K // Q4_K_GROUPS # 32 -Q6_K_GROUPS = 16 # sub-blocks per Q6_K super-block -Q6_K_GROUP_SIZE = QK_K // Q6_K_GROUPS # 16 - - -def _raw_tensor(data: bytes) -> torch.Tensor: - """Wrap a numpy mmap view as a uint8 torch tensor (zero-copy).""" - return torch.frombuffer(memoryview(data), dtype=torch.uint8) - - -def _read_f16(raw: torch.Tensor, col_start: int, col_end: int) -> torch.Tensor: - """Read fp16 field from block bytes, return float32.""" - return raw[:, col_start:col_end].contiguous().view(torch.float16).float() - - -def _unpack_q4_k(data, shape: list[int]) -> torch.Tensor: - """Unpack Q4_K super-blocks into an ``Int4Tensor``. - - Q4_K block layout (144 bytes per 256 values): - - d (2B, fp16): super-block scale - - dmin (2B, fp16): super-block min - - scales (12B): 8 sub-block scales + 8 sub-block mins, 6-bit packed - - qs (128B): 256 4-bit values, two per byte - - Dequant: weight = d * sub_scale * q - dmin * sub_min - """ - from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor - - N, K = shape - assert K % QK_K == 0, f"Q4_K requires K divisible by {QK_K}, got {K}" - n_blocks = N * (K // QK_K) - block_bytes = 2 + 2 + 12 + QK_K // 2 # 144 - raw = _raw_tensor(data).reshape(n_blocks, block_bytes) - - d = _read_f16(raw, 0, 2) - dmin = _read_f16(raw, 2, 4) - s = raw[:, 4:16] - qs = raw[:, 16:144] - - sc = torch.empty(n_blocks, 8, dtype=torch.float32) - mn = torch.empty(n_blocks, 8, dtype=torch.float32) - sc[:, :4] = (s[:, :4] & 0x3F).float() - mn[:, :4] = (s[:, 4:8] & 0x3F).float() - sc[:, 4:] = ((s[:, 8:12] & 0xF) | ((s[:, :4] >> 6) << 4)).float() - mn[:, 4:] = ((s[:, 8:12] >> 4) | ((s[:, 4:8] >> 6) << 4)).float() - del s - - eff_scale = (d * sc).reshape(N, -1) - eff_min = (dmin * mn).reshape(N, -1) - del d, dmin, sc, mn - - zero_std = torch.where( - eff_scale != 0, eff_min / eff_scale, torch.zeros_like(eff_min) - ) - del eff_min - - # GGUF Q4_K nibble order: 32 lows then 32 highs per sub-block pair - low = (qs & 0x0F).to(torch.uint8) - high = ((qs >> 4) & 0x0F).to(torch.uint8) - qdata_unpacked = torch.cat( - [ - low[:, :32], - high[:, :32], - low[:, 32:64], - high[:, 32:64], - low[:, 64:96], - high[:, 64:96], - low[:, 96:128], - high[:, 96:128], - ], - dim=-1, - ).reshape(N, K) - del qs, low, high - - # Nibble-pack for Int4Tensor: even=LOW, odd=HIGH - packed = qdata_unpacked[:, ::2] | (qdata_unpacked[:, 1::2] << 4) - - # Int4Tensor scale/zero layout: (K//gs, N) — transposed - return Int4Tensor( - qdata=packed, - scale=eff_scale.to(torch.bfloat16).t().contiguous(), - zero_point=zero_std.to(torch.bfloat16).t().contiguous(), - block_size=[1, Q4_K_GROUP_SIZE], - shape=torch.Size([N, K]), - ) - - -def _unpack_q6_k(data, shape: list[int]) -> torch.Tensor: - """Unpack Q6_K super-blocks into an ``IntxUnpackedToInt8Tensor``. - - Q6_K block layout (210 bytes per 256 values): - - ql (128B): lower 4 bits of 256 6-bit values - - qh (64B): upper 2 bits of 256 6-bit values - - scales (16B): 16 int8 sub-block scales (groups of 16) - - d (2B, fp16): super-block scale - - Dequant: weight = d * scale_j * (q - 32) - Values are 6-bit [-32, 31], widened to INT8. - """ - from torchao.quantization import IntxUnpackedToInt8Tensor - - N, K = shape - assert K % QK_K == 0, f"Q6_K requires K divisible by {QK_K}, got {K}" - n_blocks = N * (K // QK_K) - block_bytes = 2 + QK_K // 2 + QK_K // 4 + QK_K // 16 # 210 - raw = _raw_tensor(data).reshape(n_blocks, block_bytes) - - ql = raw[:, 0:128] - qh = raw[:, 128:192] - sc = raw[:, 192:208] - d = _read_f16(raw, 208, 210) - - qh0 = qh[:, :32] - qh1 = qh[:, 32:64] - qdata = torch.empty(n_blocks, QK_K, dtype=torch.int16) - qdata[:, 0:32] = (ql[:, :32] & 0x0F) | ((qh0 & 0x03) << 4) - qdata[:, 32:64] = (ql[:, 32:64] & 0x0F) | (((qh0 >> 2) & 0x03) << 4) - qdata[:, 64:96] = ((ql[:, :32] >> 4) & 0x0F) | (((qh0 >> 4) & 0x03) << 4) - qdata[:, 96:128] = ((ql[:, 32:64] >> 4) & 0x0F) | (((qh0 >> 6) & 0x03) << 4) - qdata[:, 128:160] = (ql[:, 64:96] & 0x0F) | ((qh1 & 0x03) << 4) - qdata[:, 160:192] = (ql[:, 96:128] & 0x0F) | (((qh1 >> 2) & 0x03) << 4) - qdata[:, 192:224] = ((ql[:, 64:96] >> 4) & 0x0F) | (((qh1 >> 4) & 0x03) << 4) - qdata[:, 224:256] = ((ql[:, 96:128] >> 4) & 0x0F) | (((qh1 >> 6) & 0x03) << 4) - qdata -= 32 - del ql, qh, qh0, qh1 - - # sc bytes are signed int8 scales; reinterpret from uint8 - eff_scale = (d * sc.to(torch.int8).float()).reshape(N, -1) - del d, sc - - return IntxUnpackedToInt8Tensor( - qdata=qdata.reshape(N, K).to(torch.int8), - scale=eff_scale.to(torch.bfloat16), - zero_point=torch.zeros_like(eff_scale, dtype=torch.int8), - target_dtype=torch.int8, - block_size=(1, Q6_K_GROUP_SIZE), - dtype=torch.bfloat16, - activation_quantization=None, - ) - - -def unpack_gguf_tensor( - tensor_data, - tensor_type, - shape: list[int], -) -> torch.Tensor: - """Unpack a single GGUF tensor. - - Returns an ``Int4Tensor`` for Q4_K, ``IntxUnpackedToInt8Tensor`` for Q6_K, - or a plain ``torch.Tensor`` for F32/F16. - """ - from gguf import GGMLQuantizationType - - if tensor_type == GGMLQuantizationType.Q4_K: - return _unpack_q4_k(tensor_data, shape) - elif tensor_type == GGMLQuantizationType.Q6_K: - return _unpack_q6_k(tensor_data, shape) - elif tensor_type == GGMLQuantizationType.F32: - return _raw_tensor(tensor_data).view(torch.float32).reshape(shape).clone() - elif tensor_type == GGMLQuantizationType.F16: - return ( - _raw_tensor(tensor_data) - .view(torch.float16) - .reshape(shape) - .to(torch.bfloat16) - ) - else: - raise ValueError(f"Unsupported GGUF quant type: {tensor_type}") - - -def iter_gguf_tensors( - path: str, -) -> Iterator[tuple[str, torch.Tensor]]: - """Yield ``(name, result)`` for each tensor in a GGUF file. - - Processes one tensor at a time for low peak memory. Tensor names are - GGUF names (e.g., ``blk.0.attn_q.weight``); the caller handles key - remapping. GGUF shapes are reversed to PyTorch convention automatically. - """ - from gguf import GGUFReader - - reader = GGUFReader(path) - for tensor in reader.tensors: - shape = list(reversed(tensor.shape.tolist())) - result = unpack_gguf_tensor(tensor.data, tensor.tensor_type, shape) - yield tensor.name, result diff --git a/examples/models/gemma4_31b/quant/pack_mlx.py b/examples/models/gemma4_31b/quant/pack_mlx.py index d627c9c437c..22f525accd2 100644 --- a/examples/models/gemma4_31b/quant/pack_mlx.py +++ b/examples/models/gemma4_31b/quant/pack_mlx.py @@ -6,11 +6,11 @@ """MLX packer: convert quantized weights to MLX-compatible format. -MLX's ``QuantizedLinearHandler`` matches ``dequantize_affine → linear`` -in the exported graph. ``IntxUnpackedToInt8Tensor`` produces this -pattern naturally, but ``Int4Tensor`` does not (its dispatch calls -CUDA-specific mslk kernels). So INT4 weights are converted to -``IntxUnpackedToInt8Tensor(target_dtype=torch.int4)`` at pack time. +``Int4Tensor`` weights are wrapped as ``ExportableInt4Tensor`` so they export to +``dequantize_int4_tensor -> linear/embedding`` (matched by MLX's Int4 handlers). +``IntxUnpackedToInt8Tensor`` (e.g. int8 / Q6_K) already exports to +``dequantize_affine -> linear`` and is assigned directly, regrouped to an +MLX-compatible group size when needed. The backend-agnostic ``pack_model`` dispatcher lives in ``pack.py``. """ @@ -25,45 +25,6 @@ _MLX_SUPPORTED_GROUP_SIZES = (128, 64, 32, 16) -# --------------------------------------------------------------------------- -# Int4Tensor → IntxUnpackedToInt8Tensor conversion - - -def _int4_to_intx_unpacked(w: torch.Tensor) -> torch.Tensor: - """Convert an ``Int4Tensor`` to ``IntxUnpackedToInt8Tensor``. - - Int4Tensor stores qdata as nibble-packed uint8 ``(N, K/2)`` with - scale/zero transposed to ``(K//gs, N)``. IntxUnpackedToInt8Tensor - stores qdata as int8 ``(N, K)`` with scale/zero as ``(N, K//gs)``. - """ - from torchao.quantization import IntxUnpackedToInt8Tensor - - # Unpack nibbles: packed = even | (odd << 4), unsigned [0, 15] - p = w.qdata.to(torch.uint8) - low = (p & 0x0F).to(torch.int8) - high = ((p >> 4) & 0x0F).to(torch.int8) - qdata = torch.stack([low, high], dim=-1).reshape(w.shape) - - # Shift unsigned [0, 15] → signed [-8, 7] - qdata = qdata - 8 - - gs = w.block_size[-1] - - # Transpose scale/zero from (K//gs, N) → (N, K//gs) - scale = w.scale.t().contiguous() - zero_point = (w.zero_point - 8).t().contiguous() - - return IntxUnpackedToInt8Tensor( - qdata=qdata, - scale=scale, - zero_point=zero_point, - target_dtype=torch.int4, - block_size=(1, gs), - dtype=scale.dtype, - activation_quantization=None, - ) - - # --------------------------------------------------------------------------- # Embedding group_size regrouping @@ -122,21 +83,23 @@ def _regroup_intx(w: torch.Tensor, new_gs: int) -> torch.Tensor: def pack_for_mlx(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: """Pack a quantized weight for MLX. - ``Int4Tensor`` is converted to ``IntxUnpackedToInt8Tensor`` so the - default dispatch produces the ``dequantize_affine → linear`` pattern - MLX expects. Regroups to a compatible group_size when needed (e.g. - per-axis group_size=5376 → group_size=128) since MLX's - ``parse_dequant_node`` only accepts group_size in {16, 32, 64, 128}. - Group sizes ≥ 32 use the fused ``QuantizedMatmulNode``; group_size=16 - (e.g. GGUF Q6_K) falls back to ``DequantizeNode`` + matmul at export. + ``Int4Tensor`` is wrapped as ``ExportableInt4Tensor`` (exports to + ``dequantize_int4_tensor → linear/embedding``). ``IntxUnpackedToInt8Tensor`` + is assigned directly, regrouped to a compatible group_size when needed (e.g. + per-axis group_size=5376 → 128) since MLX accepts group_size in + {16, 32, 64, 128}. Group sizes ≥ 32 use the fused ``QuantizedMatmulNode``; + group_size=16 (e.g. GGUF Q6_K) falls back to ``DequantizeNode`` + matmul. """ + from executorch.extension.llm.export.int4 import ExportableInt4Tensor from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor w = weights["weight"] if isinstance(w, Int4Tensor): - w = _int4_to_intx_unpacked(w) - if isinstance(w, IntxUnpackedToInt8Tensor): + # Int4 group is MLX-native (32); wrap so it exports to + # dequantize_int4_tensor -> linear/embedding. + w = ExportableInt4Tensor.from_int4_tensor(w) + elif isinstance(w, IntxUnpackedToInt8Tensor): gs = w.block_size[-1] K = w.qdata.shape[-1] target_gs = _mlx_group_size(gs, K) diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py deleted file mode 100644 index 89a7099d6f0..00000000000 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ /dev/null @@ -1,282 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Unit tests for quant/gguf.py — Q4_K and Q6_K unpacking. - -Tests verify the API contract: dequantized weights match the original -GGUF dequantization formula. Uses synthetic blocks — no GGUF file required. -""" - -import os -import struct -import tempfile -import unittest - -import numpy as np -import torch - -try: - from gguf import GGMLQuantizationType - - _HAS_GGUF = True -except ImportError: - _HAS_GGUF = False - -if _HAS_GGUF: - from executorch.examples.models.gemma4_31b.quant.gguf import unpack_gguf_tensor - -from executorch.examples.models.gemma4_31b.quant.quantize import dequantize_weight -from safetensors import safe_open -from safetensors.torch import save_file -from torchao.prototype.safetensors.safetensors_support import ( - flatten_tensor_state_dict, - unflatten_tensor_state_dict, -) - - -def _make_q4_k_block(d, dmin, sub_scales, sub_mins, qvals): - """Build one Q4_K block (144 bytes) from components.""" - buf = bytearray(144) - struct.pack_into("> 4) << 6 - scales_bytes[j] |= (sub_mins[j] >> 4) << 6 - buf[4:16] = scales_bytes - # GGUF Q4_K nibble order: 32 lows then 32 highs per sub-block pair - for g in range(4): - for i in range(32): - lo_val = qvals[g * 64 + i] - hi_val = qvals[g * 64 + 32 + i] - buf[16 + g * 32 + i] = (lo_val & 0xF) | ((hi_val & 0xF) << 4) - return buf - - -def _make_q6_k_block(d, scales_16, qvals_256): - """Build one Q6_K block (210 bytes) from components. - - ggml processes 128 values at a time. For each 128-value half: - ql: 64 bytes (two groups of 32, low/high nibbles) - qh: 32 bytes (2 bits each for 4 sub-positions) - The qvals_256 array is in output order (position 0..255). - """ - buf = bytearray(210) - # First half (positions 0..127): ql bytes 0..63, qh bytes 0..31 - for i in range(32): - buf[i] = (qvals_256[i] & 0x0F) | ((qvals_256[i + 64] & 0x0F) << 4) - for i in range(32): - buf[32 + i] = (qvals_256[i + 32] & 0x0F) | ((qvals_256[i + 96] & 0x0F) << 4) - for i in range(32): - h0 = (qvals_256[i] >> 4) & 0x03 - h1 = (qvals_256[i + 32] >> 4) & 0x03 - h2 = (qvals_256[i + 64] >> 4) & 0x03 - h3 = (qvals_256[i + 96] >> 4) & 0x03 - buf[128 + i] = h0 | (h1 << 2) | (h2 << 4) | (h3 << 6) - # Second half (positions 128..255): ql bytes 64..127, qh bytes 32..63 - for i in range(32): - buf[64 + i] = (qvals_256[i + 128] & 0x0F) | ((qvals_256[i + 192] & 0x0F) << 4) - for i in range(32): - buf[96 + i] = (qvals_256[i + 160] & 0x0F) | ((qvals_256[i + 224] & 0x0F) << 4) - for i in range(32): - h0 = (qvals_256[i + 128] >> 4) & 0x03 - h1 = (qvals_256[i + 160] >> 4) & 0x03 - h2 = (qvals_256[i + 192] >> 4) & 0x03 - h3 = (qvals_256[i + 224] >> 4) & 0x03 - buf[160 + i] = h0 | (h1 << 2) | (h2 << 4) | (h3 << 6) - # Scales and d - for i in range(16): - buf[192 + i] = scales_16[i] & 0xFF - struct.pack_into(" CUDA load -> inference -> export (mirrors TestGgufLinearMlx).""" + + def setUp(self): + _require_cuda(self) + try: + import gguf # noqa: F401 + except ImportError: + self.skipTest("gguf package required") + + def _load(self, tmp): + path = os.path.join(tmp, "tiny.gguf") + build_gguf_checkpoint(path) + return load_gguf_model(path, backend="cuda", config=GGUF_CONFIG) + + def test_load_converts_weights(self): + """GGUF -> CUDA: Q4_K -> Int4Tensor, Q6_K -> IntxUnpacked, embedding bf16.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + with tempfile.TemporaryDirectory() as tmp: + model, _ = self._load(tmp) + + self.assertIsInstance(model.layers[0].self_attn.q_proj.weight.data, Int4Tensor) + self.assertIsInstance( + model.layers[0].mlp.down_proj.weight.data, IntxUnpackedToInt8Tensor + ) + # Token embedding is dequantized to bf16 (Int4/Intx can't gather). + self.assertEqual(model.embed_tokens.weight.dtype, torch.bfloat16) + + def test_generate(self): + """GGUF -> CUDA -> eager generate produces valid tokens (inference.py).""" + with tempfile.TemporaryDirectory() as tmp: + model, config = self._load(tmp) + _move_to_cuda(model, config) + model.eval() + tokenizer = MockTokenizer(GGUF_CONFIG.vocab_size) + + torch.manual_seed(0) + out = generate(model, tokenizer, prompt="hi", max_new_tokens=3, temperature=1.0) + self.assertIsInstance(out, str) + self.assertGreater(len(out), 0) + + def test_export(self): + """GGUF -> CUDA -> export_and_lower produces a .pte (export.py).""" + with tempfile.TemporaryDirectory() as tmp, tempfile.TemporaryDirectory() as out_dir: + model, config = self._load(tmp) + export_and_lower(model, config, out_dir) + self.assertTrue(os.path.exists(os.path.join(out_dir, "model.pte"))) + + if __name__ == "__main__": unittest.main() diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index 37f61fddb0f..b26e2783aa6 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -20,7 +20,6 @@ import torch import torch.nn as nn - from executorch.examples.models.gemma4_31b.model import Gemma4_31B from executorch.examples.models.gemma4_31b.quant import ( DEFAULT_MLX_PACKERS, @@ -31,8 +30,10 @@ QuantRule, ) from executorch.examples.models.gemma4_31b.tests.test_pipeline import ( + build_gguf_checkpoint, build_random_tiny_model, config_dict, + GGUF_CONFIG, save_checkpoint, TINY_CONFIG, ) @@ -323,5 +324,208 @@ def test_embedding_packing_preserves_values(self): ) +class TestGgufLinearMlx(unittest.TestCase): + """GGUF-quantized linears (Q6_K + Q4_K) lower through the MLX GGUF pattern.""" + + def _linear(self, N: int, K: int, ggml_type: str) -> nn.Module: + from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( + make_q4_k_blob, + make_q6_k_blob, + ) + from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + blob = (make_q6_k_blob if ggml_type == "q6_k" else make_q4_k_blob)(N, K) + lin = nn.Linear(K, N, bias=False).to(torch.bfloat16) + lin.weight = nn.Parameter( + ExportableGGUFTensor.from_raw(blob, ggml_type, torch.bfloat16), + requires_grad=False, + ) + return lin.eval() + + def _assert_delegated(self, model, example, leftovers): + import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 + from executorch.backends.mlx import MLXPartitioner + from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + from torch.export import Dim, export + + seq = Dim("seq", min=1, max=8) + ep = export(model, example, dynamic_shapes=({0: seq},), strict=True) + et = to_edge_transform_and_lower( + ep, + partitioner=[MLXPartitioner()], + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + remaining = [ + str(n.target) + for n in et.exported_program().graph.nodes + if n.op == "call_function" and any(t in str(n.target) for t in leftovers) + ] + self.assertEqual(remaining, [], f"not delegated to MLX: {remaining}") + + def test_q6k_linear_delegates(self): + self._assert_delegated( + self._linear(256, 512, "q6_k"), + (torch.randn(4, 512, dtype=torch.bfloat16),), + ("dequantize_gguf", "linear"), + ) + + def test_q4k_linear_delegates(self): + self._assert_delegated( + self._linear(512, 512, "q4_k"), + (torch.randn(4, 512, dtype=torch.bfloat16),), + ("dequantize_gguf", "linear"), + ) + + +class TestGgufEmbeddingMlx(unittest.TestCase): + """GGUF token embeddings (Q6_K + Q4_K) lower through the MLX GGUF pattern.""" + + def _assert_delegated(self, ggml_type: str): + import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 + from executorch.backends.mlx import MLXPartitioner + from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( + make_q4_k_blob, + make_q6_k_blob, + ) + from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + from executorch.extension.llm.export.gguf import ExportableGGUFTensor + from torch.export import Dim, export + + vocab, K = 512, 256 + blob = (make_q6_k_blob if ggml_type == "q6_k" else make_q4_k_blob)(vocab, K) + emb = nn.Embedding(vocab, K) + emb.weight = nn.Parameter( + ExportableGGUFTensor.from_raw(blob, ggml_type, torch.bfloat16), + requires_grad=False, + ) + emb = emb.eval() + seq = Dim("seq", min=1, max=8) + ep = export( + emb, + (torch.randint(0, vocab, (4,), dtype=torch.int64),), + dynamic_shapes=({0: seq},), + strict=True, + ) + et = to_edge_transform_and_lower( + ep, + partitioner=[MLXPartitioner()], + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + remaining = [ + str(n.target) + for n in et.exported_program().graph.nodes + if n.op == "call_function" + and any(t in str(n.target) for t in ("dequantize_gguf", "embedding")) + ] + self.assertEqual(remaining, [], f"not delegated to MLX: {remaining}") + + def test_q6k_embedding_delegates(self): + self._assert_delegated("q6_k") + + def test_q4k_embedding_delegates(self): + self._assert_delegated("q4_k") + + +class TestInt4Mlx(unittest.TestCase): + """ExportableInt4Tensor linear + embedding lower through the MLX Int4 pattern.""" + + def _make_int4(self, N, K, gs=32, seed=0): + from executorch.extension.llm.export.int4 import ExportableInt4Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + g = torch.Generator().manual_seed(seed) + q = torch.randint(0, 16, (N, K), generator=g, dtype=torch.int32) + packed = (q[:, 0::2] | (q[:, 1::2] << 4)).to(torch.uint8) + scale = (torch.randn(K // gs, N, generator=g) * 0.1).to(torch.bfloat16) + zero = torch.randint(0, 16, (K // gs, N), generator=g).to(torch.bfloat16) + it = Int4Tensor( + qdata=packed, + scale=scale, + zero_point=zero, + block_size=[1, gs], + shape=torch.Size([N, K]), + ) + return ExportableInt4Tensor.from_int4_tensor(it) + + def _assert_delegated(self, model, example, leftovers): + import executorch.backends.mlx.patterns # noqa: F401 + from executorch.backends.mlx import MLXPartitioner + from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + from torch.export import Dim, export + + seq = Dim("seq", min=1, max=8) + ep = export(model, example, dynamic_shapes=({0: seq},), strict=True) + et = to_edge_transform_and_lower( + ep, + partitioner=[MLXPartitioner()], + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + remaining = [ + str(n.target) + for n in et.exported_program().graph.nodes + if n.op == "call_function" and any(t in str(n.target) for t in leftovers) + ] + self.assertEqual(remaining, [], f"not delegated to MLX: {remaining}") + + def test_int4_linear_delegates(self): + lin = nn.Linear(512, 256, bias=False).to(torch.bfloat16) + lin.weight = nn.Parameter(self._make_int4(256, 512), requires_grad=False) + self._assert_delegated( + lin.eval(), + (torch.randn(4, 512, dtype=torch.bfloat16),), + ("dequantize_int4_tensor", "linear"), + ) + + def test_int4_embedding_delegates(self): + vocab, K = 512, 256 + emb = nn.Embedding(vocab, K) + emb.weight = nn.Parameter(self._make_int4(vocab, K), requires_grad=False) + self._assert_delegated( + emb.eval(), + (torch.randint(0, vocab, (4,), dtype=torch.int64),), + ("dequantize_int4_tensor", "embedding"), + ) + + +class TestGgufLoadMlx(unittest.TestCase): + """GGUF file -> load_gguf_model(mlx) -> export (parity with the CUDA test).""" + + def setUp(self): + try: + import gguf # noqa: F401 + except ImportError: + self.skipTest("gguf package required") + + def _load(self, tmp): + from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model + + path = os.path.join(tmp, "tiny.gguf") + build_gguf_checkpoint(path) + return load_gguf_model(path, backend="mlx", config=GGUF_CONFIG) + + def test_load_keeps_gguf_tensors_and_ties_lm_head(self): + """MLX keeps weights as ExportableGGUFTensor; lm_head stays tied.""" + from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + with tempfile.TemporaryDirectory() as tmp: + model, _ = self._load(tmp) + + self.assertIsInstance( + model.layers[0].self_attn.q_proj.weight.data, ExportableGGUFTensor + ) + self.assertIsInstance(model.embed_tokens.weight.data, ExportableGGUFTensor) + # GGUF ties embed/lm_head; on MLX they share the one quantized tensor. + self.assertIs(model.lm_head.weight.data, model.embed_tokens.weight.data) + + def test_export(self): + """GGUF -> MLX load -> export_and_lower produces a .pte (export.py).""" + from executorch.examples.models.gemma4_31b.export import export_and_lower + + with tempfile.TemporaryDirectory() as tmp, tempfile.TemporaryDirectory() as out_dir: + model, config = self._load(tmp) + export_and_lower(model, config, out_dir, backend="mlx") + self.assertTrue(os.path.exists(os.path.join(out_dir, "model.pte"))) + + if __name__ == "__main__": unittest.main() diff --git a/examples/models/gemma4_31b/tests/test_pipeline.py b/examples/models/gemma4_31b/tests/test_pipeline.py index a8d9d9cbe34..f81d68c623a 100644 --- a/examples/models/gemma4_31b/tests/test_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_pipeline.py @@ -158,6 +158,96 @@ def build_hf_checkpoint(output_dir: str) -> None: json.dump(config_dict(), f) +# GGUF-friendly tiny config: Q4_K/Q6_K need in-features that are multiples of 256, +# so hidden/intermediate are 256. Two layers exercises a sliding + a global layer. +GGUF_CONFIG = Gemma4_31BConfig( + vocab_size=256, + hidden_size=256, + intermediate_size=256, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=256, + global_head_dim=512, + sliding_window=16, + max_seq_len=64, +) + + +def _model_to_gguf_key(fqn: str): + """Invert ``gguf_loader._KEY_MAP`` (model FQN -> GGUF tensor name).""" + from executorch.examples.models.gemma4_31b.gguf_loader import _KEY_MAP + + for gguf_pat, model_pat in _KEY_MAP.items(): + if "{}" not in model_pat: + if fqn == model_pat: + return gguf_pat + continue + prefix, suffix = model_pat.split("{}") + if fqn.startswith(prefix) and fqn.endswith(suffix): + idx = fqn[len(prefix) : len(fqn) - len(suffix)] + if idx.isdigit(): + return gguf_pat.replace("{}", idx) + return None + + +def build_gguf_checkpoint(path: str, config: Gemma4_31BConfig = GGUF_CONFIG) -> None: + """Write a tiny GGUF file matching ``config``. + + Linears are Q4_K; ``ffn_down`` / ``token_embd`` are Q6_K (to exercise both + GGUF unpack paths); norms / scalars are F32. Tensor shapes are derived from + the instantiated model so per-layer-type differences (e.g. global layers + having no v_proj / q_norm) are handled automatically. ``output.weight`` is + omitted -- GGUF ties lm_head to the token embedding. Requires the ``gguf`` + package. + """ + import gguf + from executorch.extension.llm.export.gguf import QK_K + from executorch.extension.llm.export.test.test_gguf import ( + _make_q4k_raw, + _make_q6k_raw, + ) + + with torch.device("meta"): + model = Gemma4_31B(config) + + writer = gguf.GGUFWriter(path, "gemma") + for fqn, p in model.named_parameters(): + gguf_key = _model_to_gguf_key(fqn) + if gguf_key is None: + continue + if p.dim() == 2: + N, K = int(p.shape[0]), int(p.shape[1]) + nb = K // QK_K + use_q6 = gguf_key == "token_embd.weight" or gguf_key.endswith( + "ffn_down.weight" + ) + blob = (_make_q6k_raw(N, nb) if use_q6 else _make_q4k_raw(N, nb)).numpy() + raw_dtype = ( + gguf.GGMLQuantizationType.Q6_K + if use_q6 + else gguf.GGMLQuantizationType.Q4_K + ) + writer.add_tensor(gguf_key, blob, raw_dtype=raw_dtype) + else: + arr = (torch.randn(tuple(p.shape), dtype=torch.float32) * 0.1).numpy() + writer.add_tensor(gguf_key, arr) + # Per-layer scalars are buffers (not parameters) but are stored in real + # GGUFs (e.g. blk.N.layer_output_scale.weight). Write the ones that have a + # GGUF mapping so they load as bf16; runtime buffers (RoPE, KV cache, ...) + # map to None and are skipped. + for fqn, b in model.named_buffers(): + gguf_key = _model_to_gguf_key(fqn) + if gguf_key is None: + continue + arr = torch.ones(tuple(b.shape), dtype=torch.float32).numpy() + writer.add_tensor(gguf_key, arr) + writer.write_header_to_file() + writer.write_kv_data_to_file() + writer.write_tensors_to_file() + writer.close() + + # --------------------------------------------------------------------------- # Tests (CPU only, no backend dependency) diff --git a/examples/models/qwen3_5_moe/mlx_source_transformations.py b/examples/models/qwen3_5_moe/mlx_source_transformations.py index 25605fb6342..9a49f8a84f6 100644 --- a/examples/models/qwen3_5_moe/mlx_source_transformations.py +++ b/examples/models/qwen3_5_moe/mlx_source_transformations.py @@ -194,7 +194,7 @@ def _exportable_gated_delta_net_forward(self, x, input_pos): x = a + self.dt_bias g = (-self.A_log.exp() * torch.logaddexp(x, torch.zeros_like(x))).exp() - import executorch.backends.mlx.model_ops.gated_delta_rule as _ # noqa: ensure op registered + import executorch.backends.mlx.custom_kernel_ops.gated_delta_rule as _ # noqa: ensure op registered output = torch.ops.mlx.gated_delta_rule( q, diff --git a/extension/llm/export/gguf.py b/extension/llm/export/gguf.py new file mode 100644 index 00000000000..1ffb0435eb9 --- /dev/null +++ b/extension/llm/export/gguf.py @@ -0,0 +1,386 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Export-time GGUF quantized weights. + +``ExportableGGUFTensor`` wraps the *raw* GGUF block bytes for one tensor and +defers all unpacking, serving as the canonical GGUF loading representation: + +* ``load_gguf(path)`` -> ``{name -> ExportableGGUFTensor | Tensor}`` (quantized + tensors become subclasses; F32/F16 stay plain). No unpacking at load. +* As a weight, it dequantizes via the ``torchao::dequantize_gguf`` custom op + (gguf-package eager body) then a plain ``linear`` / ``embedding`` -- a backend + can pattern-match ``dequantize_gguf`` -> linear/embedding to fuse. +* ``.to_int4_tensor()`` / ``.to_intx_unpacked_to_int8_tensor()`` convert into + torchao subclasses (``Int4Tensor`` / ``IntxUnpackedToInt8Tensor``) instead. + +The quant type is a string (``"q4_k"`` / ``"q6_k"``); the ``gguf`` package's +integer ``GGMLQuantizationType`` ids are an internal lookup detail. Which tensors +to convert is the caller's policy. + +Attribution: Q4_K / Q6_K block layouts follow llama.cpp / gguf-py +(``ggml-common.h``), MIT-licensed (Copyright (c) 2023-2024 The ggml authors). +""" + +from __future__ import annotations + +from typing import Dict, Iterator, Optional, Tuple, Union + +import numpy as np +import torch +from torch import Tensor +from torchao.utils import TorchAOBaseTensor + +aten = torch.ops.aten + +# GGUF k-quant constants + +QK_K = 256 # super-block size for k-quants + +Q4_K_GROUP_SIZE = QK_K // 8 # 32 (8 sub-blocks per super-block) +Q6_K_GROUP_SIZE = QK_K // 16 # 16 (16 sub-blocks per super-block) + +_Q4_K_BLOCK_BYTES = 2 + 2 + 12 + QK_K // 2 # 144 +_Q6_K_BLOCK_BYTES = 2 + QK_K // 2 + QK_K // 4 + QK_K // 16 # 210 + +# ``gguf.GGMLQuantizationType`` integer ids. +GGML_F32 = 0 +GGML_F16 = 1 +GGML_Q4_K = 12 +GGML_Q6_K = 14 + +# String quant-type names are the user-facing identifier (op arg + subclass attr). +# These dicts map names to the internal ids / block sizes. +_GGML_ID_BY_TYPE = {"q4_k": GGML_Q4_K, "q6_k": GGML_Q6_K} +_TYPE_BY_GGML_ID = {v: k for k, v in _GGML_ID_BY_TYPE.items()} +_BLOCK_BYTES_BY_TYPE = {"q4_k": _Q4_K_BLOCK_BYTES, "q6_k": _Q6_K_BLOCK_BYTES} + + +def _read_f16(raw: Tensor, col_start: int, col_end: int) -> Tensor: + """Read an fp16 field from per-block bytes, return float32.""" + return raw[:, col_start:col_end].contiguous().view(torch.float16).float() + + +def _dequantize_gguf(raw: Tensor, ggml_type: str, output_dtype: torch.dtype) -> Tensor: + """Dequantize a raw GGUF block blob to a float tensor via the ``gguf`` package. + + ``raw`` is ``(N, row_bytes)`` uint8; the result is ``(N, K)`` in + ``output_dtype``. + """ + import gguf + + if ggml_type not in _GGML_ID_BY_TYPE: + raise NotImplementedError(f"unsupported GGUF quant type {ggml_type!r}") + qtype = gguf.GGMLQuantizationType(_GGML_ID_BY_TYPE[ggml_type]) + np_raw = raw.detach().cpu().contiguous().numpy() + deq = gguf.dequantize(np_raw, qtype) + return torch.from_numpy(np.ascontiguousarray(deq)).to( + device=raw.device, dtype=output_dtype + ) + + +# Fused ops (eager = gguf.dequantize + torch op; a backend may lower to kernels) + + +@torch.library.custom_op("torchao::dequantize_gguf", mutates_args=()) +def dequantize_gguf( + weight: Tensor, + ggml_type: str, + output_dtype: torch.dtype = torch.bfloat16, +) -> Tensor: + """Dequantize a raw GGUF block blob (``(N, row_bytes)`` uint8) to ``(N, K)``.""" + return _dequantize_gguf(weight, ggml_type, output_dtype) + + +@dequantize_gguf.register_fake +def _(weight, ggml_type, output_dtype=torch.bfloat16): + K = (weight.shape[1] // _BLOCK_BYTES_BY_TYPE[ggml_type]) * QK_K + return torch.empty((weight.shape[0], K), dtype=output_dtype, device=weight.device) + + +# Per-type field extraction (used by the to_*_tensor conversions) + + +def _q4_k_fields(raw: Tensor, N: int, K: int) -> Tuple[Tensor, Tensor, Tensor]: + """Decode Q4_K blocks for conversion to ``Int4Tensor``. + + Returns ``(q, eff_scale, eff_min)`` where ``q`` is ``(N, K)`` uint8 in + [0, 15], and ``eff_scale`` / ``eff_min`` are ``(N, K // 32)`` float32. + """ + n_blocks = N * (K // QK_K) + blk = raw.reshape(n_blocks, _Q4_K_BLOCK_BYTES) + + d = _read_f16(blk, 0, 2) + dmin = _read_f16(blk, 2, 4) + s = blk[:, 4:16] + qs = blk[:, 16:144] + + sc = torch.empty(n_blocks, 8, dtype=torch.float32) + mn = torch.empty(n_blocks, 8, dtype=torch.float32) + sc[:, :4] = (s[:, :4] & 0x3F).float() + mn[:, :4] = (s[:, 4:8] & 0x3F).float() + sc[:, 4:] = ((s[:, 8:12] & 0xF) | ((s[:, :4] >> 6) << 4)).float() + mn[:, 4:] = ((s[:, 8:12] >> 4) | ((s[:, 4:8] >> 6) << 4)).float() + + eff_scale = (d * sc).reshape(N, -1) + eff_min = (dmin * mn).reshape(N, -1) + + # GGUF Q4_K nibble order: 32 lows then 32 highs per sub-block pair. + low = (qs & 0x0F).to(torch.uint8) + high = ((qs >> 4) & 0x0F).to(torch.uint8) + q = torch.cat( + [ + low[:, :32], + high[:, :32], + low[:, 32:64], + high[:, 32:64], + low[:, 64:96], + high[:, 64:96], + low[:, 96:128], + high[:, 96:128], + ], + dim=-1, + ).reshape(N, K) + return q, eff_scale, eff_min + + +def _q6_k_fields(raw: Tensor, N: int, K: int) -> Tuple[Tensor, Tensor]: + """Decode Q6_K blocks for conversion to ``IntxUnpackedToInt8Tensor``. + + Returns ``(q, eff_scale)`` where ``q`` is ``(N, K)`` int8 in [-32, 31] and + ``eff_scale`` is ``(N, K // 16)`` float32. + """ + n_blocks = N * (K // QK_K) + blk = raw.reshape(n_blocks, _Q6_K_BLOCK_BYTES) + + ql = blk[:, 0:128] + qh = blk[:, 128:192] + sc = blk[:, 192:208] + d = _read_f16(blk, 208, 210) + + qh0 = qh[:, :32] + qh1 = qh[:, 32:64] + q = torch.empty(n_blocks, QK_K, dtype=torch.int16) + q[:, 0:32] = (ql[:, :32] & 0x0F) | ((qh0 & 0x03) << 4) + q[:, 32:64] = (ql[:, 32:64] & 0x0F) | (((qh0 >> 2) & 0x03) << 4) + q[:, 64:96] = ((ql[:, :32] >> 4) & 0x0F) | (((qh0 >> 4) & 0x03) << 4) + q[:, 96:128] = ((ql[:, 32:64] >> 4) & 0x0F) | (((qh0 >> 6) & 0x03) << 4) + q[:, 128:160] = (ql[:, 64:96] & 0x0F) | ((qh1 & 0x03) << 4) + q[:, 160:192] = (ql[:, 96:128] & 0x0F) | (((qh1 >> 2) & 0x03) << 4) + q[:, 192:224] = ((ql[:, 64:96] >> 4) & 0x0F) | (((qh1 >> 4) & 0x03) << 4) + q[:, 224:256] = ((ql[:, 96:128] >> 4) & 0x0F) | (((qh1 >> 6) & 0x03) << 4) + q -= 32 + + # ``sc`` bytes are signed int8 sub-block scales. + eff_scale = (d * sc.to(torch.int8).float()).reshape(N, -1) + return q.reshape(N, K).to(torch.int8), eff_scale + + +# Tensor subclass + + +class ExportableGGUFTensor(TorchAOBaseTensor): + """Wraps the raw GGUF block bytes for one quantized weight. + + Stores the exact GGUF ``block_q*_K`` byte layout (no repacking) plus the + quant type string (``"q4_k"`` / ``"q6_k"``). ``aten.linear`` / ``aten.embedding`` + dequantize via the ``torchao::dequantize_gguf`` op (then a plain + linear/embedding); :meth:`to_int4_tensor` / :meth:`to_intx_unpacked_to_int8_tensor` + convert to torchao subclasses instead. + """ + + tensor_data_names = ["raw"] + tensor_attribute_names = ["ggml_type", "orig_dtype"] + + def __new__(cls, raw: Tensor, ggml_type: str, orig_dtype: torch.dtype): + if raw.dim() != 2 or raw.dtype != torch.uint8: + raise ValueError( + f"ExportableGGUFTensor: raw must be 2-D uint8 (N, row_bytes); got " + f"shape {tuple(raw.shape)} dtype {raw.dtype}" + ) + if ggml_type not in _BLOCK_BYTES_BY_TYPE: + raise NotImplementedError( + f"ExportableGGUFTensor: unsupported quant type {ggml_type!r}; " + f"supported: {sorted(_BLOCK_BYTES_BY_TYPE)}" + ) + n, row_bytes = int(raw.shape[0]), int(raw.shape[1]) + block_bytes = _BLOCK_BYTES_BY_TYPE[ggml_type] + if row_bytes % block_bytes != 0: + raise ValueError( + f"ExportableGGUFTensor: row bytes {row_bytes} not a multiple of " + f"block bytes {block_bytes} for quant type {ggml_type!r}" + ) + K = (row_bytes // block_bytes) * QK_K + self = torch.Tensor._make_wrapper_subclass( + cls, (n, K), dtype=orig_dtype, device=raw.device, requires_grad=False + ) + self.raw = raw + self.ggml_type = ggml_type + self.orig_dtype = orig_dtype + return self + + @classmethod + def from_raw( + cls, + raw: Tensor, + ggml_type: str, + orig_dtype: torch.dtype = torch.bfloat16, + ) -> "ExportableGGUFTensor": + """Build from a ``(N, row_bytes)`` uint8 GGUF block blob.""" + return cls(raw.contiguous(), ggml_type, orig_dtype) + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> Tensor: + """Dequantize to a plain float tensor using the ``gguf`` package.""" + return torch.ops.torchao.dequantize_gguf( + self.raw, self.ggml_type, output_dtype or self.orig_dtype + ) + + def to_int4_tensor(self) -> Tensor: + """Convert a Q4_K tensor to a torchao ``Int4Tensor``.""" + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + if self.ggml_type != "q4_k": + raise NotImplementedError( + f"to_int4_tensor only supports q4_k; got {self.ggml_type!r}" + ) + N, K = int(self.shape[0]), int(self.shape[1]) + q, eff_scale, eff_min = _q4_k_fields(self.raw, N, K) + + zero = torch.where( + eff_scale != 0, eff_min / eff_scale, torch.zeros_like(eff_min) + ) + # Nibble-pack for Int4Tensor: even index -> low nibble, odd -> high. + packed = q[:, ::2] | (q[:, 1::2] << 4) + return Int4Tensor( + qdata=packed, + # Int4Tensor scale/zero layout is (K // gs, N) -- transposed. + scale=eff_scale.to(torch.bfloat16).t().contiguous(), + zero_point=zero.to(torch.bfloat16).t().contiguous(), + block_size=[1, Q4_K_GROUP_SIZE], + shape=torch.Size([N, K]), + ) + + def to_intx_unpacked_to_int8_tensor(self) -> Tensor: + """Convert to a torchao ``IntxUnpackedToInt8Tensor`` (Q4_K or Q6_K). + + Q6_K maps to a symmetric int8 tensor (values [-32, 31], zero-point 0). + Q4_K maps to a 4-bit tensor: values are centered to [-8, 7] and the + affine min is folded into a (float) zero-point, so the rewrite is exact. + """ + from torchao.quantization import IntxUnpackedToInt8Tensor + + N, K = int(self.shape[0]), int(self.shape[1]) + if self.ggml_type == "q6_k": + q, eff_scale = _q6_k_fields(self.raw, N, K) + return IntxUnpackedToInt8Tensor( + qdata=q, + scale=eff_scale.to(torch.bfloat16), + zero_point=torch.zeros_like(eff_scale, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, Q6_K_GROUP_SIZE), + dtype=torch.bfloat16, + activation_quantization=None, + ) + if self.ggml_type == "q4_k": + q, eff_scale, eff_min = _q4_k_fields(self.raw, N, K) + zero = torch.where( + eff_scale != 0, eff_min / eff_scale, torch.zeros_like(eff_min) + ) + # Center quants [0, 15] -> [-8, 7] and shift the zero-point to match + # (dequant = scale * (q - zp) is preserved). + return IntxUnpackedToInt8Tensor( + qdata=q.to(torch.int8) - 8, + scale=eff_scale.to(torch.bfloat16), + zero_point=(zero - 8).to(torch.bfloat16), + target_dtype=torch.int4, + block_size=(1, Q4_K_GROUP_SIZE), + dtype=torch.bfloat16, + activation_quantization=None, + ) + raise NotImplementedError( + f"to_intx_unpacked_to_int8_tensor supports q4_k/q6_k; " + f"got {self.ggml_type!r}" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + +implements = ExportableGGUFTensor.implements + + +@implements([aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight = args[0], args[1] + bias = args[2] if len(args) > 2 else None + return torch.nn.functional.linear( + input_tensor, weight.dequantize(input_tensor.dtype), bias + ) + + +@implements([aten.embedding.default]) +def _(func, types, args, kwargs): + weight, indices = args[0], args[1] + return torch.nn.functional.embedding(indices, weight.dequantize()) + + +@implements([aten.t.default]) +def _(func, types, args, kwargs): + return args[0].dequantize().t() + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return args[0] + + +@implements([aten._to_copy.default]) +def _(func, types, args, kwargs): + return args[0].dequantize(output_dtype=kwargs.get("dtype", args[0].orig_dtype)) + + +# Loader + + +def iter_gguf( + path: str, +) -> Iterator[Tuple[str, Union[ExportableGGUFTensor, Tensor]]]: + """Stream ``(name, value)`` for every tensor in a GGUF file (low peak mem). + + Quantized tensors (Q4_K, Q6_K) are wrapped as ``ExportableGGUFTensor`` with + the raw block bytes; F32/F16 are returned as plain float tensors (bf16 for + F16). GGUF shapes are reversed to PyTorch ``(N, K)`` convention. + """ + from gguf import GGMLQuantizationType, GGUFReader + + reader = GGUFReader(path) + for tensor in reader.tensors: + shape = list(reversed(tensor.shape.tolist())) + ttype = int(tensor.tensor_type) + flat = torch.frombuffer(memoryview(tensor.data), dtype=torch.uint8) + if ttype in _TYPE_BY_GGML_ID: + N = shape[0] + row_bytes = flat.numel() // N + raw = flat.reshape(N, row_bytes).clone() + yield tensor.name, ExportableGGUFTensor.from_raw( + raw, _TYPE_BY_GGML_ID[ttype] + ) + elif tensor.tensor_type == GGMLQuantizationType.F32: + yield tensor.name, flat.view(torch.float32).reshape(shape).clone() + elif tensor.tensor_type == GGMLQuantizationType.F16: + yield tensor.name, flat.view(torch.float16).reshape(shape).to( + torch.bfloat16 + ) + else: + raise ValueError(f"Unsupported GGUF quant type: {tensor.tensor_type}") + + +def load_gguf(path: str) -> Dict[str, Union[ExportableGGUFTensor, Tensor]]: + """Load a GGUF file into ``{name -> ExportableGGUFTensor | Tensor}``. + + Holds all tensors at once; use :func:`iter_gguf` for low peak memory. + """ + return dict(iter_gguf(path)) diff --git a/extension/llm/export/int4.py b/extension/llm/export/int4.py new file mode 100644 index 00000000000..59251ae1875 --- /dev/null +++ b/extension/llm/export/int4.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Int4 export-compatible quantization. + +Wraps a torchao ``Int4Tensor`` (nibble-packed 4-bit groupwise weight) so it +survives ``torch.export`` / ``run_decompositions``: a ``torchao::dequantize_int4_tensor`` +custom op carries the dequant, and ``aten.linear`` / ``aten.embedding`` desugar to +``dequantize_int4_tensor -> linear/embedding`` (mirroring ``dequantize_nvfp4`` / +``dequantize_gguf``). A backend may pattern-match the op to a low-bit kernel; the +eager body is a plain affine dequant so the representation is portable. + +The tensor stores the ``Int4Tensor`` layout verbatim: + * ``qdata`` ``(N, K // 2)`` uint8, two nibbles/byte (even index -> low nibble), + unsigned values in [0, 15]. + * ``scale`` ``(K // group_size, N)``. + * ``zero_point`` ``(K // group_size, N)``, unsigned values in [0, 15]. +Dequant is ``scale * (q - zero_point)`` per group. +""" + +import torch +from torch import Tensor +from torchao.utils import TorchAOBaseTensor + +aten = torch.ops.aten + + +def _dequantize_int4( + qdata: Tensor, + scale: Tensor, + zero_point: Tensor, + group_size: int, + output_dtype: torch.dtype, +) -> Tensor: + """Eager affine dequant of an ``Int4Tensor``-layout weight to ``(N, K)``.""" + p = qdata.view(torch.uint8) + low = (p & 0x0F).to(torch.int32) + high = ((p >> 4) & 0x0F).to(torch.int32) + # Two nibbles/byte: even index -> low, odd -> high. + q = torch.stack([low, high], dim=-1).reshape(p.shape[0], -1).to(torch.float32) + + # scale / zero_point are (K // gs, N) -> transpose to (N, K // gs) and expand. + s = scale.t().to(torch.float32).repeat_interleave(group_size, dim=-1) + z = zero_point.t().to(torch.float32).repeat_interleave(group_size, dim=-1) + return ((q - z) * s).to(output_dtype) + + +@torch.library.custom_op("torchao::dequantize_int4_tensor", mutates_args=()) +def dequantize_int4_tensor( + qdata: Tensor, + scale: Tensor, + zero_point: Tensor, + group_size: int, + output_dtype: torch.dtype = torch.bfloat16, +) -> Tensor: + """Dequantize a nibble-packed Int4 weight (``(N, K//2)`` uint8) to ``(N, K)``.""" + return _dequantize_int4(qdata, scale, zero_point, group_size, output_dtype) + + +@dequantize_int4_tensor.register_fake +def _(qdata, scale, zero_point, group_size, output_dtype=torch.bfloat16): + K = qdata.shape[1] * 2 # two 4-bit values per byte + return torch.empty(qdata.shape[0], K, dtype=output_dtype, device=qdata.device) + + +class ExportableInt4Tensor(TorchAOBaseTensor): + """Int4 tensor subclass that dequantizes via a registered custom op.""" + + tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_attribute_names = ["group_size", "orig_dtype"] + + def __new__(cls, qdata, scale, zero_point, group_size, orig_dtype): + K = qdata.shape[-1] * 2 # two 4-bit values per byte + shape = (qdata.shape[0], K) + self = torch.Tensor._make_wrapper_subclass( + cls, shape, dtype=orig_dtype, device=qdata.device, requires_grad=False + ) + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + self.orig_dtype = orig_dtype + return self + + @classmethod + def from_int4_tensor(cls, w: Tensor) -> "ExportableInt4Tensor": + """Build from a torchao ``Int4Tensor`` (copies its packed fields).""" + return cls( + w.qdata, + w.scale, + w.zero_point, + int(w.block_size[-1]), + w.dtype, + ) + + def dequantize(self, output_dtype=None): + return torch.ops.torchao.dequantize_int4_tensor( + self.qdata, + self.scale, + self.zero_point, + self.group_size, + output_dtype=output_dtype or self.orig_dtype, + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + +implements = ExportableInt4Tensor.implements + + +@implements([aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight = args[0], args[1] + bias = args[2] if len(args) > 2 else None + return torch.nn.functional.linear( + input_tensor, weight.dequantize(input_tensor.dtype), bias + ) + + +@implements([aten.embedding.default]) +def _(func, types, args, kwargs): + weight, indices = args[0], args[1] + return torch.nn.functional.embedding(indices, weight.dequantize()) + + +@implements([aten.t.default]) +def _(func, types, args, kwargs): + return args[0].dequantize().t() + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return args[0] + + +@implements([aten._to_copy.default]) +def _(func, types, args, kwargs): + return args[0].dequantize(output_dtype=kwargs.get("dtype", args[0].orig_dtype)) diff --git a/extension/llm/export/test/test_gguf.py b/extension/llm/export/test/test_gguf.py new file mode 100644 index 00000000000..13e2dff53fc --- /dev/null +++ b/extension/llm/export/test/test_gguf.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for ``extension/llm/export/gguf.py``. + +The reference oracle is the ``gguf`` package's own ``gguf.dequantize`` (which can +dequantize Q4_K / Q6_K). We validate that: + +* ``ExportableGGUFTensor.dequantize`` (and the ``torchao::dequantize_gguf`` op, + whose eager body uses ``gguf``) reproduces ``gguf.dequantize``; +* our hand-written ``to_int4_tensor`` / ``to_intx_unpacked_to_int8_tensor`` + unpack matches ``gguf.dequantize`` (within bf16 storage tolerance); +* using the subclass as a weight dispatches linear/embedding to the fused ops. + +Blocks are crafted with a small fp16 super-block scale and fixed mid-range +sub-scales so dequantized magnitudes are O(1) and bf16 round-trip error is small +and deterministic (random sub-scales can produce near-zero effective scales, +which blow up the bf16 zero-point error for Q4_K). +""" + +import unittest + +import numpy as np +import torch + +try: + import gguf + from gguf import GGMLQuantizationType + + _HAS_GGUF = True +except ImportError: + _HAS_GGUF = False + +from executorch.extension.llm.export.gguf import ( + _Q4_K_BLOCK_BYTES, + _Q6_K_BLOCK_BYTES, + ExportableGGUFTensor, + Q4_K_GROUP_SIZE, +) + + +def _fp16_bytes(x: float) -> torch.Tensor: + return torch.tensor([x], dtype=torch.float16).view(torch.uint8) + + +def _make_q4k_raw(N: int, nb: int, seed: int = 0) -> torch.Tensor: + """A ``(N, nb*144)`` uint8 Q4_K blob with sane, deterministic magnitudes.""" + g = torch.Generator().manual_seed(seed) + blk = torch.randint( + 0, 256, (N * nb, _Q4_K_BLOCK_BYTES), dtype=torch.uint8, generator=g + ) + blk[:, 0:2] = _fp16_bytes(0.01) # d + blk[:, 2:4] = _fp16_bytes(0.01) # dmin + blk[:, 4:16] = 0x21 # fixed mid-range 6-bit sub-scales/mins (non-zero) + return blk.reshape(N, nb * _Q4_K_BLOCK_BYTES) + + +def _make_q6k_raw(N: int, nb: int, seed: int = 0) -> torch.Tensor: + """A ``(N, nb*210)`` uint8 Q6_K blob with sane, deterministic magnitudes.""" + g = torch.Generator().manual_seed(seed) + blk = torch.randint( + 0, 256, (N * nb, _Q6_K_BLOCK_BYTES), dtype=torch.uint8, generator=g + ) + blk[:, 192:208] = 0x10 # fixed int8 sub-scales (non-zero) + blk[:, 208:210] = _fp16_bytes(0.01) # d + return blk.reshape(N, nb * _Q6_K_BLOCK_BYTES) + + +def _gguf_ref(raw: torch.Tensor, qtype) -> torch.Tensor: + return torch.from_numpy(np.asarray(gguf.dequantize(raw.numpy(), qtype))).float() + + +def _int4_to_float(w) -> torch.Tensor: + """Dequantize an ``Int4Tensor`` from its stored fields. + + ``Int4Tensor`` has no working ``dequantize()`` on CPU (``aten.dequantize`` is + unimplemented and the linear path needs fbgemm), so reconstruct directly + from its public fields (this still exercises our nibble-packing). + """ + N, K = int(w.shape[0]), int(w.shape[1]) + gs = w.block_size[1] + q = torch.empty(N, K, dtype=torch.float32) + q[:, ::2] = (w.qdata & 0x0F).float() + q[:, 1::2] = (w.qdata >> 4).float() + scale = w.scale.t().float().repeat_interleave(gs, dim=1) + zero = w.zero_point.t().float().repeat_interleave(gs, dim=1) + return scale * (q - zero) + + +@unittest.skipUnless(_HAS_GGUF, "gguf package not installed") +class TestExportableGGUFTensor(unittest.TestCase): + def test_dequantize_matches_gguf(self): + for ggml_type, qtype, make in ( + ("q4_k", GGMLQuantizationType.Q4_K, _make_q4k_raw), + ("q6_k", GGMLQuantizationType.Q6_K, _make_q6k_raw), + ): + raw = make(N=3, nb=2) + t = ExportableGGUFTensor.from_raw(raw, ggml_type) + self.assertEqual(tuple(t.shape), (3, 2 * 256)) + mine = t.dequantize(torch.float32) + ref = _gguf_ref(raw, qtype) + # .dequantize() routes through gguf, so it should match exactly. + self.assertTrue(torch.equal(mine, ref), f"{qtype}") + + def test_to_intx_unpacked_matches_reference(self): + # Reference is the gguf-package dequant (ExportableGGUFTensor.dequantize); + # the Intx tensor's dequantize exercises our unpacking. Covers Q4_K & Q6_K. + for ggml_type, make in (("q4_k", _make_q4k_raw), ("q6_k", _make_q6k_raw)): + raw = make(N=3, nb=2) + t = ExportableGGUFTensor.from_raw(raw, ggml_type) + ix = t.to_intx_unpacked_to_int8_tensor() + self.assertEqual(tuple(ix.shape), (3, 512)) + # bf16 storage tolerance. + self.assertTrue( + torch.allclose( + ix.dequantize().float(), + t.dequantize(torch.float32), + rtol=1e-2, + atol=5e-2, + ), + ggml_type, + ) + + def test_to_int4_tensor_matches_reference(self): + raw = _make_q4k_raw(N=3, nb=2) + t = ExportableGGUFTensor.from_raw(raw, "q4_k") + w = t.to_int4_tensor() + self.assertEqual(tuple(w.shape), (3, 512)) + self.assertEqual(list(w.block_size), [1, Q4_K_GROUP_SIZE]) + # Int4Tensor has no CPU dequantize(); reconstruct from its packed fields + # (this still exercises our nibble-packing) against the gguf reference. + self.assertTrue( + torch.allclose( + _int4_to_float(w), + t.dequantize(torch.float32), + rtol=1e-2, + atol=5e-2, + ) + ) + + def test_dequantize_gguf_op_matches_reference(self): + for ggml_type, make in (("q4_k", _make_q4k_raw), ("q6_k", _make_q6k_raw)): + raw = make(N=3, nb=2) + t = ExportableGGUFTensor.from_raw(raw, ggml_type) + out = torch.ops.torchao.dequantize_gguf(raw, ggml_type, torch.float32) + self.assertTrue(torch.equal(out, t.dequantize(torch.float32))) + + def test_subclass_linear_dispatches_to_dequant(self): + raw = _make_q6k_raw(N=4, nb=1) + t = ExportableGGUFTensor.from_raw(raw, "q6_k") + x = torch.randn(2, 256, dtype=torch.bfloat16) + out = torch.nn.functional.linear(x, t) + ref = torch.nn.functional.linear(x, t.dequantize(torch.bfloat16)) + self.assertTrue(torch.equal(out, ref)) + + def test_subclass_embedding_dispatches_to_dequant(self): + raw = _make_q6k_raw(N=8, nb=1) + t = ExportableGGUFTensor.from_raw(raw, "q6_k") + idx = torch.tensor([0, 3, 7, 1]) + out = torch.nn.functional.embedding(idx, t) + ref = torch.nn.functional.embedding(idx, t.dequantize(torch.bfloat16)) + self.assertTrue(torch.equal(out, ref)) + + def test_unsupported_type_raises(self): + raw = torch.zeros(1, _Q6_K_BLOCK_BYTES, dtype=torch.uint8) + with self.assertRaises(NotImplementedError): + ExportableGGUFTensor.from_raw(raw, "q5_k") + + +@unittest.skipUnless(_HAS_GGUF, "gguf package not installed") +class TestExportableGGUFTensorExport(unittest.TestCase): + """Exporting a module whose weight is an ``ExportableGGUFTensor`` should + lower linear/embedding through the ``torchao::dequantize_gguf`` op after + ``run_decompositions`` (the subclass dispatch fires during decomposition).""" + + @staticmethod + def _targets(ep): + return {str(n.target) for n in ep.graph.nodes if n.op == "call_function"} + + def test_linear_exports_with_dequantize_gguf(self): + t = ExportableGGUFTensor.from_raw(_make_q6k_raw(N=4, nb=1), "q6_k") + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(t, requires_grad=False) + + def forward(self, x): + return torch.nn.functional.linear(x, self.w) + + ep = torch.export.export( + M(), (torch.randn(2, 256, dtype=torch.bfloat16),) + ).run_decompositions({}) + self.assertIn("torchao.dequantize_gguf.default", self._targets(ep)) + + def test_embedding_exports_with_dequantize_gguf(self): + t = ExportableGGUFTensor.from_raw(_make_q6k_raw(N=8, nb=1), "q6_k") + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(t, requires_grad=False) + + def forward(self, idx): + return torch.nn.functional.embedding(idx, self.w) + + ep = torch.export.export(M(), (torch.tensor([0, 1, 2, 3]),)).run_decompositions( + {} + ) + self.assertIn("torchao.dequantize_gguf.default", self._targets(ep)) + + +if __name__ == "__main__": + unittest.main() diff --git a/extension/llm/export/test/test_int4.py b/extension/llm/export/test/test_int4.py new file mode 100644 index 00000000000..9414248d59a --- /dev/null +++ b/extension/llm/export/test/test_int4.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for ExportableInt4Tensor + the torchao::dequantize_int4_tensor op.""" + +import unittest + +import torch +from executorch.extension.llm.export.int4 import ExportableInt4Tensor + + +def _make_int4_tensor(N: int, K: int, gs: int, seed: int = 0): + """Build a synthetic ``Int4Tensor`` plus the (q, scale, zero_point) it encodes. + + Returns ``(int4_tensor, q_unsigned (N,K), scale (K//gs,N), zero (K//gs,N))``. + """ + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + g = torch.Generator().manual_seed(seed) + q = torch.randint(0, 16, (N, K), generator=g, dtype=torch.int32) # unsigned [0,15] + # Pack two nibbles/byte: even index -> low, odd -> high. + packed = (q[:, 0::2] | (q[:, 1::2] << 4)).to(torch.uint8) + scale = (torch.randn(K // gs, N, generator=g) * 0.1).to(torch.bfloat16) + zero = torch.randint(0, 16, (K // gs, N), generator=g).to(torch.bfloat16) + it = Int4Tensor( + qdata=packed, + scale=scale, + zero_point=zero, + block_size=[1, gs], + shape=torch.Size([N, K]), + ) + return it, q, scale, zero + + +def _reference_dequant(q, scale, zero, gs): + """Independent affine dequant: scale * (q - zero), groups expanded.""" + s = scale.t().to(torch.float32).repeat_interleave(gs, dim=-1) + z = zero.t().to(torch.float32).repeat_interleave(gs, dim=-1) + return (q.to(torch.float32) - z) * s + + +class TestDequantizeInt4Op(unittest.TestCase): + def test_op_matches_reference(self): + it, q, scale, zero = _make_int4_tensor(N=8, K=64, gs=32) + out = torch.ops.torchao.dequantize_int4_tensor( + it.qdata, it.scale, it.zero_point, 32, torch.float32 + ) + ref = _reference_dequant(q, scale, zero, 32) + self.assertEqual(tuple(out.shape), (8, 64)) + self.assertTrue(torch.allclose(out, ref, rtol=1e-2, atol=5e-2)) + + def test_subclass_dequantize_matches_op(self): + it, _, _, _ = _make_int4_tensor(N=8, K=64, gs=32) + t = ExportableInt4Tensor.from_int4_tensor(it) + ref = torch.ops.torchao.dequantize_int4_tensor( + it.qdata, it.scale, it.zero_point, 32, torch.bfloat16 + ) + self.assertTrue(torch.equal(t.dequantize(torch.bfloat16), ref)) + + def test_subclass_linear_dispatches_to_dequant(self): + it, _, _, _ = _make_int4_tensor(N=16, K=64, gs=32) + t = ExportableInt4Tensor.from_int4_tensor(it) + x = torch.randn(2, 64, dtype=torch.bfloat16) + out = torch.nn.functional.linear(x, t) + ref = torch.nn.functional.linear(x, t.dequantize(torch.bfloat16)) + self.assertTrue(torch.equal(out, ref)) + + def test_subclass_embedding_dispatches_to_dequant(self): + it, _, _, _ = _make_int4_tensor(N=16, K=64, gs=32) + t = ExportableInt4Tensor.from_int4_tensor(it) + idx = torch.tensor([0, 3, 7, 1]) + out = torch.nn.functional.embedding(idx, t) + ref = torch.nn.functional.embedding(idx, t.dequantize(torch.bfloat16)) + self.assertTrue(torch.equal(out, ref)) + + +class TestExportableInt4TensorExport(unittest.TestCase): + """Exporting a module whose weight is an ``ExportableInt4Tensor`` should lower + linear/embedding through ``torchao::dequantize_int4_tensor`` after + ``run_decompositions`` (the subclass dispatch fires during decomposition).""" + + @staticmethod + def _targets(ep): + return {str(n.target) for n in ep.graph.nodes if n.op == "call_function"} + + def test_linear_exports_with_dequantize_int4(self): + it, _, _, _ = _make_int4_tensor(N=16, K=64, gs=32) + t = ExportableInt4Tensor.from_int4_tensor(it) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(t, requires_grad=False) + + def forward(self, x): + return torch.nn.functional.linear(x, self.w) + + ep = torch.export.export( + M(), (torch.randn(2, 64, dtype=torch.bfloat16),) + ).run_decompositions({}) + self.assertIn("torchao.dequantize_int4_tensor.default", self._targets(ep)) + + def test_embedding_exports_with_dequantize_int4(self): + it, _, _, _ = _make_int4_tensor(N=16, K=64, gs=32) + t = ExportableInt4Tensor.from_int4_tensor(it) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(t, requires_grad=False) + + def forward(self, idx): + return torch.nn.functional.embedding(idx, self.w) + + ep = torch.export.export(M(), (torch.tensor([0, 1, 2, 3]),)).run_decompositions( + {} + ) + self.assertIn("torchao.dequantize_int4_tensor.default", self._targets(ep)) + + +if __name__ == "__main__": + unittest.main() diff --git a/requirements-dev.txt b/requirements-dev.txt index d2c3b5fcc20..71c68c968ec 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,4 +14,5 @@ lintrunner-adapters==0.14.0 pytest<9.0 pytest-xdist pytest-rerunfailures==15.1 -pytest-json-report \ No newline at end of file +pytest-json-report +gguf # For extension/llm/export/test/test_gguf.py (GGUF Q4_K/Q6_K dequant tests). \ No newline at end of file