Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions .github/workflows/mlx.yml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ on:
- backends/mlx/**
- extension/llm/export/**
- extension/audio/**
- examples/models/gemma4_31b/**
- examples/models/parakeet/**
- examples/models/voxtral_realtime/**
- examples/models/qwen3_5_moe/**
Expand Down Expand Up @@ -77,6 +78,8 @@ jobs:
backends/mlx/test/test_passes.py \
backends/mlx/test/test_pattern_utils.py \
backends/mlx/test/test_partitioner.py \
backends/mlx/test/test_serialization_dedup.py \
examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \
examples/models/gemma4_31b/tests/test_mlx_pipeline.py \
-v
echo "::endgroup::"
Expand All @@ -89,20 +92,16 @@ jobs:
./cmake-out/backends/mlx/test/multi_thread_test_runner
echo "::endgroup::"

echo "::group::Run gated_delta_rule op tests"
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v
echo "::endgroup::"

echo "::group::Run tq_norm op tests"
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_norm run -v
echo "::endgroup::"

echo "::group::Run tq4_compress op tests"
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v
echo "::endgroup::"

echo "::group::Run tq_dequant op tests"
${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v
echo "::group::Run custom_kernel_ops op tests"
# Run every custom_kernel_ops/**/test/test_*.py via its OpTestCase `run`
# CLI. Recurses into per-format subpackages (e.g. gguf/test), so adding a
# new op test file requires no change here.
set -e
for t in $(find backends/mlx/custom_kernel_ops -path '*/test/test_*.py' | sort); do
mod="executorch.$(echo "${t%.py}" | tr '/' '.')"
echo "--- ${mod} ---"
${CONDA_RUN} python -m "${mod}" run -v
done
echo "::endgroup::"

test-mlx-qwen35-moe:
Expand Down
101 changes: 101 additions & 0 deletions backends/mlx/builder/op_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,79 @@ def emit_quantized_biases(
return biases


def emit_quantized_gather(
    P: MLXProgramBuilder,
    out: Slot,
    indices_slot: Slot,
    qdata_slot: Slot,
    scales_slot: Slot,
    biases_slot: Optional[Slot],
    *,
    group_size: int,
    bits: int,
    mode: str,
    out_dtype: torch.dtype,
) -> None:
    """Emit a gather-then-dequantize sequence that writes into ``out``.

    One ``TakeNode`` (``axis=0``) is emitted for ``qdata_slot``, one for
    ``scales_slot`` and, only when ``biases_slot`` is not ``None``, one for
    ``biases_slot``. All of them use the same index built from
    ``indices_slot``. A single ``DequantizeNode`` then consumes the gathered
    tensors and writes to ``out``.

    Args:
        P: Program builder used to allocate temporaries and emit nodes.
        out: Destination slot of the ``DequantizeNode``.
        indices_slot: Slot holding the gather indices.
        qdata_slot: Slot holding the quantized data.
        scales_slot: Slot holding the scales.
        biases_slot: Slot holding the biases, or ``None`` to omit them.
        group_size: Forwarded unchanged to ``DequantizeNode``.
        bits: Forwarded unchanged to ``DequantizeNode``.
        mode: Forwarded unchanged to ``DequantizeNode``.
        out_dtype: Converted with ``torch_dtype_to_scalar_type`` and passed
            as the ``DequantizeNode`` ``dtype``.
    """
    # Imported lazily, as in the rest of this helper's call path.
    from executorch.backends.mlx.serialization.mlx_graph_schema import (
        DequantizeNode,
        IntOrVidOrTid,
        TakeNode,
    )

    # Built once and shared by every TakeNode below.
    row_index = IntOrVidOrTid.from_tid(P.slot_to_tid(indices_slot))

    def _take_rows(source: Slot) -> Slot:
        """Emit a TakeNode over ``source`` and return the fresh tmp slot."""
        _, gathered = P.make_tmp_slot()
        P.emit(
            TakeNode(
                x=P.slot_to_tid(source),
                index=row_index,
                out=P.slot_to_tid(gathered),
                axis=0,
            )
        )
        return gathered

    gathered_qdata = _take_rows(qdata_slot)
    gathered_scales = _take_rows(scales_slot)

    # Biases are optional; no TakeNode is emitted when they are absent.
    if biases_slot is None:
        gathered_biases_tid = None
    else:
        gathered_biases_tid = P.slot_to_tid(_take_rows(biases_slot))

    dequantize = DequantizeNode(
        w=P.slot_to_tid(gathered_qdata),
        scales=P.slot_to_tid(gathered_scales),
        out=P.slot_to_tid(out),
        biases=gathered_biases_tid,
        group_size=group_size,
        bits=bits,
        mode=mode,
        dtype=torch_dtype_to_scalar_type(out_dtype),
    )
    P.emit(dequantize)


def to_mlx_qparams(
qdata: torch.Tensor,
scale: torch.Tensor,
Expand Down Expand Up @@ -421,6 +494,34 @@ def parse_dequant_nvfp4_node(
return qdata, scale, per_tensor_scale, output_dtype


def parse_dequant_int4_node(
    node: Node,
) -> Optional[Tuple[Node, Node, Node, int, Optional[torch.dtype]]]:
    """Extract the operands of a ``torchao.dequantize_int4_tensor`` node.

    Returns:
        ``(qdata, scale, zero_point, group_size, output_dtype)`` taken from
        the node's first four positional arguments plus the optional output
        dtype (fifth positional argument, else the ``output_dtype`` keyword,
        else ``None``). Returns ``None`` when
        ``executorch.extension.llm.export.int4`` cannot be imported or when
        ``node`` does not target ``torchao.dequantize_int4_tensor.default``.
    """
    resolved_target = get_aten_target(node.target)
    try:
        import executorch.extension.llm.export.int4  # noqa: F401
    except ImportError:
        return None

    if resolved_target is not torch.ops.torchao.dequantize_int4_tensor.default:
        return None

    positional = node.args
    qdata, scale, zero_point, group_size = positional[0:4]

    # A positional output dtype takes precedence over the keyword form.
    if len(positional) > 4:
        output_dtype = positional[4]
    else:
        output_dtype = node.kwargs.get("output_dtype")

    return qdata, scale, zero_point, group_size, output_dtype


def parse_dequant_node(
node: Node,
) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]:
Expand Down
18 changes: 18 additions & 0 deletions backends/mlx/custom_kernel_ops/gguf/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

"""GGUF-quantized weight lowering for the MLX backend.

Import :mod:`.patterns` for its side effect to enable lowering of
``torchao::dequantize_gguf -> linear/embedding`` to the Q6_K / Q4_K kernels::

import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401

This ``__init__`` is side-effect free, so importing ``.q6k`` for the pure-torch
dequant does not pull in the MLX builder/registry.
"""
167 changes: 167 additions & 0 deletions backends/mlx/custom_kernel_ops/gguf/patterns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

"""MLX pattern handlers for GGUF-quantized weights.

``ExportableGGUFTensor`` (extension/llm/export/gguf.py) lowers a quantized
linear/embedding to::

linear(x, torchao::dequantize_gguf(weight, ggml_type, out_dtype), bias)
embedding(torchao::dequantize_gguf(weight, ggml_type, out_dtype), indices)

These handlers match that ``dequantize_gguf -> linear/embedding`` subgraph and
lower it without materializing the dequantized weight:

* **Q6_K** -> fused custom Metal kernels in :mod:`.q6k`.
* **Q4_K** -> MLX's native 4-bit affine ops via :mod:`.q4k` (GGUF blocks
repacked into MLX qparams at export time).

Both cover linear and embedding.

Other quant types are left unmatched (the caller is expected to convert them to a
torchao ``Int4Tensor`` / ``IntxUnpackedToInt8Tensor`` first).

Importing this module registers the patterns as a side effect.
"""

from __future__ import annotations

from typing import Optional, Tuple

import torch
from executorch.backends.mlx.builder.op_helpers import get_aten_target
from executorch.backends.mlx.builder.op_registry import PatternHandler, REGISTRY
from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder
from executorch.backends.mlx.builder.slot_manager import Slot
from executorch.backends.mlx.pattern_utils import has_single_user, match_target
from torch.export.exported_program import ExportedProgram
from torch.fx.node import Node

# Quant types each pattern can lower (Q6_K via custom Metal kernels, Q4_K via
# MLX-native affine ops).
# The handlers' ``maybe_create`` methods test the parsed ``ggml_type`` against
# these sets and decline to match (return ``None``) for any other value.
_LINEAR_TYPES = {"q4_k", "q6_k"}  # consulted by GGUFQuantizedLinearHandler
_EMBEDDING_TYPES = {"q4_k", "q6_k"}  # consulted by GGUFQuantizedEmbeddingHandler


def parse_dequantize_gguf_node(
    node: Node,
) -> Optional[Tuple[Node, str, torch.dtype]]:
    """Extract the operands of a ``torchao::dequantize_gguf`` node.

    Returns:
        ``(weight_node, ggml_type, output_dtype)``, where the first two come
        from the node's first two positional arguments. ``output_dtype`` is
        the third positional argument if present, else the ``output_dtype``
        keyword, else ``torch.bfloat16``. Returns ``None`` when
        ``executorch.extension.llm.export.gguf`` (which registers the op)
        cannot be imported, or when ``node`` does not target
        ``torchao.dequantize_gguf.default``.
    """
    try:
        import executorch.extension.llm.export.gguf  # noqa: F401 registers the op
    except ImportError:
        return None

    gguf_op = torch.ops.torchao.dequantize_gguf.default
    if get_aten_target(node.target) is not gguf_op:
        return None

    positional = node.args
    weight, ggml_type = positional[0], positional[1]

    # A positional output dtype wins; otherwise use the keyword, then bf16.
    if len(positional) > 2:
        output_dtype = positional[2]
    else:
        output_dtype = node.kwargs.get("output_dtype", torch.bfloat16)

    return weight, ggml_type, output_dtype


@REGISTRY.register_pattern(name="GGUF_QUANTIZED_LINEAR")
class GGUFQuantizedLinearHandler(PatternHandler):
    """Pattern handler for ``linear`` fed by a ``dequantize_gguf`` weight.

    The matched subgraph is
    ``linear(x, dequantize_gguf(weight, ggml_type, out_dtype), bias)``.
    Lowering is delegated to ``emit_linear`` from the ``gguf.q6k.linear``
    module when ``ggml_type`` is ``"q6_k"`` and from ``gguf.q4k.linear``
    otherwise.
    """

    def __init__(self, head, body, weight, ggml_type, output_dtype):
        """Store the parsed dequant operands next to the head/body nodes."""
        super().__init__(head, body)
        self.weight = weight
        self.ggml_type = ggml_type
        self.output_dtype = output_dtype

    @classmethod
    def maybe_create(cls, ep: ExportedProgram, head: Node):
        """Return a handler if ``head`` roots a supported pattern, else ``None``."""
        if not match_target(head, torch.ops.aten.linear.default):
            return None

        head_args = head.args
        # The weight operand (second argument) must be a graph node.
        if not (len(head_args) >= 2 and isinstance(head_args[1], Node)):
            return None

        dequant_node = head_args[1]
        # The dequant node goes into the handler body, so it must feed only
        # this linear.
        if not has_single_user(dequant_node):
            return None

        parsed = parse_dequantize_gguf_node(dequant_node)
        if parsed is None:
            return None

        weight, ggml_type, output_dtype = parsed
        if ggml_type in _LINEAR_TYPES:
            return cls(head, [dequant_node], weight, ggml_type, output_dtype)
        return None

    def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot:
        """Emit the lowering for head node ``n`` and return its output slot."""
        assert n == self.head
        activation_node = n.args[0]
        bias_node = None if len(n.args) <= 2 else n.args[2]

        # Lazy imports: only the selected format's lowering module is loaded.
        if self.ggml_type != "q6_k":  # q4_k
            from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import (
                emit_linear,
            )
        else:
            from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.linear import (
                emit_linear,
            )
        return emit_linear(P, n, activation_node, self.weight, bias_node)


@REGISTRY.register_pattern(name="GGUF_QUANTIZED_EMBEDDING")
class GGUFQuantizedEmbeddingHandler(PatternHandler):
    """Pattern handler for ``embedding`` fed by a ``dequantize_gguf`` weight.

    The matched subgraph is
    ``embedding(dequantize_gguf(weight, ggml_type, out_dtype), indices)``.
    Lowering is delegated to ``emit_embedding`` from the
    ``gguf.q6k.embedding`` module when ``ggml_type`` is ``"q6_k"`` and from
    ``gguf.q4k.embedding`` otherwise.
    """

    def __init__(self, head, body, weight, ggml_type, output_dtype):
        """Store the parsed dequant operands next to the head/body nodes."""
        super().__init__(head, body)
        self.weight = weight
        self.ggml_type = ggml_type
        self.output_dtype = output_dtype

    @classmethod
    def maybe_create(cls, ep: ExportedProgram, head: Node):
        """Return a handler if ``head`` roots a supported pattern, else ``None``."""
        if not match_target(head, torch.ops.aten.embedding.default):
            return None

        head_args = head.args
        # Both weight and indices must be present; the weight operand (first
        # argument) must be a graph node.
        if not (len(head_args) >= 2 and isinstance(head_args[0], Node)):
            return None

        dequant_node = head_args[0]
        # The dequant node goes into the handler body, so it must feed only
        # this embedding.
        if not has_single_user(dequant_node):
            return None

        parsed = parse_dequantize_gguf_node(dequant_node)
        if parsed is None:
            return None

        weight, ggml_type, output_dtype = parsed
        if ggml_type in _EMBEDDING_TYPES:
            return cls(head, [dequant_node], weight, ggml_type, output_dtype)
        return None

    def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot:
        """Emit the lowering for head node ``n`` and return its output slot."""
        assert n == self.head
        indices_node = n.args[1]

        # Lazy imports: only the selected format's lowering module is loaded.
        if self.ggml_type != "q6_k":  # q4_k
            from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import (
                emit_embedding,
            )
        else:
            from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.embedding import (
                emit_embedding,
            )
        return emit_embedding(P, n, self.weight, indices_node, self.output_dtype)
14 changes: 14 additions & 0 deletions backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
#
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#

"""GGUF Q4_K format lowering for the MLX backend (native affine 4-bit).
See :mod:`.linear` / :mod:`.embedding` for the ``emit_*`` lowerings (called by
``custom_kernel_ops.gguf.patterns``); they are not imported here to keep the
package import light.
"""
Loading
Loading