From 003ec56319715b84944d9176638cee461c5ea957 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 3 Jun 2026 16:53:21 -0700 Subject: [PATCH 01/18] up --- .github/workflows/mlx.yml | 24 +- backends/mlx/model_ops/gguf_linear.py | 686 ++++++++++++++++++ backends/mlx/model_ops/test_gguf_linear.py | 349 +++++++++ backends/mlx/runtime/MLXInterpreter.h | 16 + backends/mlx/serialization/schema.fbs | 12 +- examples/models/gemma4_31b/export.py | 19 +- examples/models/gemma4_31b/gguf_loader.py | 18 +- examples/models/gemma4_31b/mlx_gguf_linear.py | 89 +++ examples/models/gemma4_31b/quant/gguf.py | 37 +- .../gemma4_31b/quant/tests/test_gguf.py | 46 ++ .../gemma4_31b/tests/test_mlx_pipeline.py | 66 ++ 11 files changed, 1329 insertions(+), 33 deletions(-) create mode 100644 backends/mlx/model_ops/gguf_linear.py create mode 100644 backends/mlx/model_ops/test_gguf_linear.py create mode 100644 examples/models/gemma4_31b/mlx_gguf_linear.py diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 38914f7612b..69c767224fd 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -78,6 +78,7 @@ jobs: backends/mlx/test/test_pattern_utils.py \ backends/mlx/test/test_partitioner.py \ examples/models/gemma4_31b/tests/test_mlx_pipeline.py \ + examples/models/gemma4_31b/quant/tests/test_gguf.py \ -v echo "::endgroup::" @@ -89,20 +90,15 @@ jobs: ./cmake-out/backends/mlx/test/multi_thread_test_runner echo "::endgroup::" - echo "::group::Run gated_delta_rule op tests" - ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v - echo "::endgroup::" - - echo "::group::Run tq_norm op tests" - ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_norm run -v - echo "::endgroup::" - - echo "::group::Run tq4_compress op tests" - ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v - echo "::endgroup::" - - echo "::group::Run tq_dequant op tests" - ${CONDA_RUN} python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v + echo "::group::Run model_ops op tests" + # Run every model_ops/test_*.py via its OpTestCase `run` CLI. Adding a + # new op test file requires no change here. + set -e + for t in backends/mlx/model_ops/test_*.py; do + mod="executorch.backends.mlx.model_ops.$(basename "$t" .py)" + echo "--- ${mod} ---" + ${CONDA_RUN} python -m "${mod}" run -v + done echo "::endgroup::" test-mlx-qwen35-moe: diff --git a/backends/mlx/model_ops/gguf_linear.py b/backends/mlx/model_ops/gguf_linear.py new file mode 100644 index 00000000000..a58096aa131 --- /dev/null +++ b/backends/mlx/model_ops/gguf_linear.py @@ -0,0 +1,686 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +``mlx::gguf_linear``: linear layer against a GGUF-quantized weight. + + out = x @ dequant(weight)^T (+ bias) + +The weight is stored in the **exact GGUF packed block layout** (no repacking), +so weights converted by llama.cpp / gguf-py can be consumed directly. The +``format`` argument selects the GGUF quantization type; only ``"q6k"`` is +supported initially and anything else raises ``NotImplementedError``. + +Q6_K layout (per 256-element super-block, 210 bytes, see llama.cpp +``block_q6_K`` in ``ggml-common.h``):: + + uint8 ql[128] # quants, lower 4 bits + uint8 qh[64] # quants, upper 2 bits + int8 scales[16] # per-16-element sub-block scales (8-bit) + half d # super-block scale + +The dequantized value for a 6-bit code ``q`` (0..63) in sub-block ``s`` is +``d * scales[s] * (q - 32)``. + +Compute is keyed on the activation dtype (matching GGUF/llama.cpp): the Metal +kernels are templated on ``InT``, accumulate in ``float32``, read ``d`` as +``half``, and produce output in the activation dtype. + +Two kernels are emitted depending on the number of activation rows ``M``: + + * ``M == 1`` (decode): a fused mat-vec kernel ported from llama.cpp + ``kernel_mul_mv_q6_K_f32_impl``. + * static ``M > 1`` (prefill): a tiled simdgroup mat-mat kernel that + dequantizes weight tiles into threadgroup memory and reuses them across + the activation rows. + * dynamic/symbolic ``M`` (single program serving both prefill and decode): + both kernels are emitted into separate instruction chains and selected at + runtime via an ``IfNode`` on ``M`` (``M > 1`` -> mat-mat, ``M == 1`` -> + mat-vec). + +Usage:: + + import executorch.backends.mlx.model_ops.gguf_linear # noqa: F401 + + out = torch.ops.mlx.gguf_linear(x, weight, "q6k", bias) + # x: (..., K) bf16 / fp16 / fp32 + # weight: (N, (K/256)*210) uint8 GGUF q6_K blob + # bias: (N,) or None activation dtype + # out: (..., N) activation dtype +""" + +from __future__ import annotations + +from typing import Optional + +import torch +from torch import Tensor +from torch.fx.node import Node + + +# --------------------------------------------------------------------------- +# Q6_K constants and pure-torch reference (also used by the eager fallback) +# --------------------------------------------------------------------------- + +QK_K = 256 +# Per-super-block byte counts. +_Q6K_QL_BYTES = QK_K // 2 # 128 +_Q6K_QH_BYTES = QK_K // 4 # 64 +_Q6K_SCALES = QK_K // 16 # 16 +_Q6K_D_BYTES = 2 # one fp16 +Q6K_BLOCK_BYTES = _Q6K_QL_BYTES + _Q6K_QH_BYTES + _Q6K_SCALES + _Q6K_D_BYTES # 210 + + +def dequantize_q6_k(weight: Tensor, K: int) -> Tensor: + """Dequantize a GGUF Q6_K blob to float32. + + Args: + weight: ``(N, n_blocks * 210)`` uint8, GGUF ``block_q6_K`` layout. + K: number of logical input features (``n_blocks * 256``). + + Returns: + ``(N, K)`` float32 dequantized weight. + """ + if weight.dtype != torch.uint8: + raise ValueError(f"gguf_linear: weight must be uint8; got {weight.dtype}") + N = weight.shape[0] + nb = K // QK_K + if weight.shape[-1] != nb * Q6K_BLOCK_BYTES: + raise ValueError( + f"gguf_linear: weight row bytes {weight.shape[-1]} != " + f"{nb} blocks * {Q6K_BLOCK_BYTES}" + ) + + blocks = weight.view(N, nb, Q6K_BLOCK_BYTES) + ql = blocks[..., 0:_Q6K_QL_BYTES].to(torch.int32) + qh = blocks[..., _Q6K_QL_BYTES : _Q6K_QL_BYTES + _Q6K_QH_BYTES].to(torch.int32) + sc_off = _Q6K_QL_BYTES + _Q6K_QH_BYTES + scales = ( + blocks[..., sc_off : sc_off + _Q6K_SCALES] + .contiguous() + .view(torch.int8) + .to(torch.float32) + ) + d = ( + blocks[..., sc_off + _Q6K_SCALES : sc_off + _Q6K_SCALES + _Q6K_D_BYTES] + .contiguous() + .view(torch.float16) + .to(torch.float32) + ) # (N, nb, 1) + + y = torch.empty(N, nb, QK_K, dtype=torch.float32, device=weight.device) + # is = l // 16 over l in 0..31 -> selects which of the 8 half-scales. + is_idx = (torch.arange(32, device=weight.device) // 16).long() # (32,) + + for h in range(2): # two 128-element halves + ql_h = ql[..., h * 64 : h * 64 + 64] # (N, nb, 64) + qh_h = qh[..., h * 32 : h * 32 + 32] # (N, nb, 32) + sc_h = scales[..., h * 8 : h * 8 + 8] # (N, nb, 8) + + ql_lo = ql_h[..., 0:32] + ql_hi = ql_h[..., 32:64] + + q1 = (ql_lo & 0xF) | ((qh_h & 0x3) << 4) + q2 = (ql_hi & 0xF) | (((qh_h >> 2) & 0x3) << 4) + q3 = (ql_lo >> 4) | (((qh_h >> 4) & 0x3) << 4) + q4 = (ql_hi >> 4) | (((qh_h >> 6) & 0x3) << 4) + + sc0 = sc_h[..., is_idx + 0] + sc2 = sc_h[..., is_idx + 2] + sc4 = sc_h[..., is_idx + 4] + sc6 = sc_h[..., is_idx + 6] + + base = h * 128 + y[..., base + 0 : base + 32] = d * sc0 * (q1 - 32).to(torch.float32) + y[..., base + 32 : base + 64] = d * sc2 * (q2 - 32).to(torch.float32) + y[..., base + 64 : base + 96] = d * sc4 * (q3 - 32).to(torch.float32) + y[..., base + 96 : base + 128] = d * sc6 * (q4 - 32).to(torch.float32) + + return y.reshape(N, K) + + +# --------------------------------------------------------------------------- +# Custom op + eager fallback +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("mlx::gguf_linear", mutates_args=()) +def gguf_linear( + x: Tensor, + weight: Tensor, + format: str, + bias: Optional[Tensor] = None, +) -> Tensor: + """Linear against a GGUF-quantized weight. + + Args: + x: ``(..., K)`` activations (bf16 / fp16 / fp32). + weight: ``(N, (K/256)*210)`` uint8 GGUF ``q6_K`` blob. + format: GGUF quant type; only ``"q6k"`` supported. + bias: optional ``(N,)`` of activation dtype. + + Returns: + ``(..., N)`` of activation dtype. + """ + if format != "q6k": + raise NotImplementedError( + f"mlx::gguf_linear: unsupported format {format!r}; only 'q6k' is supported" + ) + if weight.dim() != 2: + raise ValueError( + f"mlx::gguf_linear: weight must be 2-D (N, row_bytes); got " + f"shape {tuple(weight.shape)}" + ) + N, row_bytes = weight.shape + if row_bytes % Q6K_BLOCK_BYTES != 0: + raise ValueError( + f"mlx::gguf_linear: weight row bytes {row_bytes} must be a multiple of " + f"{Q6K_BLOCK_BYTES} (one q6_K block per 256 features)" + ) + K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K + if x.shape[-1] != K: + raise ValueError( + f"mlx::gguf_linear: x last dim {x.shape[-1]} != K {K} implied by weight" + ) + + w_deq = dequantize_q6_k(weight, K) # (N, K) float32 + out = torch.matmul(x.to(torch.float32), w_deq.t()) # (..., N) float32 + if bias is not None: + out = out + bias.to(torch.float32) + return out.to(x.dtype) + + +@torch.library.register_fake("mlx::gguf_linear") +def gguf_linear_fake( + x: Tensor, + weight: Tensor, + format: str, + bias: Optional[Tensor] = None, +) -> Tensor: + N = weight.shape[0] + out_shape = list(x.shape) + out_shape[-1] = N + return x.new_empty(out_shape, dtype=x.dtype) + + +# --------------------------------------------------------------------------- +# MLX handler +# --------------------------------------------------------------------------- + +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddIntNode, + FloorDivideIntNode, + IfNode, + IntOrVid, + MetalKernelNode, + SubtractIntNode, +) + + +# Shared Metal header: the GGUF block_q6_K struct (matches llama.cpp +# ggml-common.h; sizeof == 210, no padding since max align is 2) plus a +# per-element dequant helper used by the mat-mat kernel. +_Q6K_HEADER = """ +#include +#include +using namespace metal; + +#define QK_K 256 + +typedef struct { + uint8_t ql[QK_K/2]; // lower 4 bits + uint8_t qh[QK_K/4]; // upper 2 bits + int8_t scales[QK_K/16]; // per-16-element sub-block scales + half d; // super-block scale +} block_q6_K; + +// Dequantize a single element at within-block position p (0..255) of a +// block_q6_K. Matches the canonical ggml dequantize_row_q6_K layout. +inline float dequant_q6k_elem(device const block_q6_K * blk, int p) { + const int h = p >> 7; // which 128-element half (0/1) + const int pp = p & 127; // position within half (0..127) + const int g = pp >> 5; // group: 0=q1, 1=q2, 2=q3, 3=q4 + const int l = pp & 31; // 0..31 + device const uint8_t * ql = blk->ql + h * 64; + device const uint8_t * qh = blk->qh + h * 32; + device const int8_t * sc = blk->scales + h * 8; + const int is = l >> 4; // 0/1 + const uint8_t qhb = qh[l]; + int q; + if (g == 0) { q = (ql[l] & 0xF) | ((qhb & 0x03) << 4); } + else if (g == 1) { q = (ql[l + 32] & 0xF) | ((qhb & 0x0C) << 2); } + else if (g == 2) { q = (ql[l] >> 4) | ((qhb & 0x30) << 0); } + else { q = (ql[l + 32] >> 4) | ((qhb & 0xC0) >> 2); } + const float scale = (float) sc[is + 2 * g]; + return (float) blk->d * scale * (float)(q - 32); +} +""" + + +# Decode mat-vec kernel, ported from llama.cpp kernel_mul_mv_q6_K_f32_impl. +# Threadgroup = (32 * NSG, 1, 1): NSG simdgroups, each computing N_R0 output +# rows for one activation row (grid.y). Accumulate in float, reduce via simd_sum. +def _q6k_matvec_source(has_bias: bool) -> str: + write = "out[(uint)m * N + r] = (InT)(tot" + write += " + (float)bias[r]);" if has_bias else ");" + return f""" + constexpr short N_R0 = 2; + + const ushort tiisg = thread_index_in_simdgroup; + const ushort sgitg = simdgroup_index_in_threadgroup; + const uint m = thread_position_in_grid.y; + const uint tgx = thread_position_in_grid.x / (32u * NSG); + const int nb = K / QK_K; + const int first_row = (int)(tgx * NSG + sgitg) * N_R0; + + const short tid = tiisg / 2; + const short ix = tiisg % 2; + const short ip = tid / 8; // 0 or 1 (which 128-half) + const short il = tid % 8; + const short l0 = 4 * il; + const short is = 8 * ip + l0 / 16; + + const short y_offset = 128 * ip + l0; + const short q_offset_l = 64 * ip + l0; + const short q_offset_h = 32 * ip + l0; + + device const block_q6_K * xrows = (device const block_q6_K *) weight; + device const InT * yy = x + (uint)m * (uint)K; + + float sumf[N_R0]; + for (short r = 0; r < N_R0; ++r) {{ sumf[r] = 0.f; }} + + float yl[16]; + for (int i = ix; i < nb; i += 2) {{ + device const InT * yb = yy + i * QK_K + y_offset; + for (short l = 0; l < 4; ++l) {{ + yl[4*l + 0] = (float) yb[l + 0]; + yl[4*l + 1] = (float) yb[l + 32]; + yl[4*l + 2] = (float) yb[l + 64]; + yl[4*l + 3] = (float) yb[l + 96]; + }} + + for (short row = 0; row < N_R0; ++row) {{ + const int r = first_row + row; + if (r >= N) {{ break; }} + device const block_q6_K * blk = xrows + (uint)r * nb + i; + device const uint8_t * q1 = blk->ql + q_offset_l; + device const uint8_t * q2 = q1 + 32; + device const uint8_t * qh = blk->qh + q_offset_h; + device const int8_t * sc = blk->scales + is; + const float d = (float) blk->d; + + float4 sums = {{0.f, 0.f, 0.f, 0.f}}; + for (short l = 0; l < 4; ++l) {{ + sums[0] += yl[4*l + 0] * (float)((int8_t)((q1[l] & 0xF) | ((qh[l] & 0x03) << 4)) - 32); + sums[1] += yl[4*l + 1] * (float)((int8_t)((q2[l] & 0xF) | ((qh[l] & 0x0C) << 2)) - 32); + sums[2] += yl[4*l + 2] * (float)((int8_t)((q1[l] >> 4) | ((qh[l] & 0x30) << 0)) - 32); + sums[3] += yl[4*l + 3] * (float)((int8_t)((q2[l] >> 4) | ((qh[l] & 0xC0) >> 2)) - 32); + }} + sumf[row] += d * (sums[0]*sc[0] + sums[1]*sc[2] + sums[2]*sc[4] + sums[3]*sc[6]); + }} + }} + + for (short row = 0; row < N_R0; ++row) {{ + const int r = first_row + row; + const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && r < N) {{ + {write} + }} + }} +""" + + +# Prefill mat-mat kernel (32x32x32 tiles, 4 simdgroups / 128 threads). +# Each threadgroup computes a BM x BN output tile; weights are dequantized +# on the fly into threadgroup memory and reused across the BM activation rows. +# C[m, n] = sum_k x[m, k] * dequant(weight)[n, k] (+ bias[n]). +def _q6k_matmul_source(has_bias: bool) -> str: + bias_line = " v += (float) bias[gn];\n" if has_bias else "" + return f""" + constexpr short BM = 32; // activation rows per tile (M) + constexpr short BN = 32; // output features per tile (N) + constexpr short BK = 32; // K-chunk per iteration + + threadgroup half As[BM * BK]; + threadgroup half Bs[BN * BK]; + threadgroup float Cs[BM * BN]; + + const ushort tid = thread_index_in_threadgroup; // 0..127 + const ushort sgitg = simdgroup_index_in_threadgroup; // 0..3 + const short sg_row = sgitg / 2; // 0/1 + const short sg_col = sgitg % 2; // 0/1 + + const uint tile_n_idx = thread_position_in_grid.x / 128u; + const uint tile_m_idx = thread_position_in_grid.y; + const int tile_m0 = (int)tile_m_idx * BM; + const int tile_n0 = (int)tile_n_idx * BN; + + // M (number of activation rows) read at runtime from the injected x_shape, + // so this kernel works for both static and symbolic seqlen. + int M = 1; + for (uint d = 0; d + 1 < x_ndim; ++d) {{ M *= (int) x_shape[d]; }} + + const int nb = K / QK_K; + device const block_q6_K * wrows = (device const block_q6_K *) weight; + + simdgroup_float8x8 mc[2][2]; + for (short a = 0; a < 2; ++a) {{ + for (short b = 0; b < 2; ++b) {{ + mc[a][b] = make_filled_simdgroup_matrix(0.f); + }} + }} + + for (int k0 = 0; k0 < K; k0 += BK) {{ + // Cooperative load: activation tile (BM x BK), row-major in As. + for (short i = 0; i < (BM * BK) / 128; ++i) {{ + const short idx = tid + i * 128; + const short mm = idx / BK; + const short kk = idx % BK; + const int gm = tile_m0 + mm; + As[mm * BK + kk] = (gm < M) ? (half) x[(uint)gm * (uint)K + (k0 + kk)] : (half)0; + }} + // Cooperative load: dequantized weight tile (BN x BK), row-major in Bs. + for (short i = 0; i < (BN * BK) / 128; ++i) {{ + const short idx = tid + i * 128; + const short nn = idx / BK; + const short kk = idx % BK; + const int gn = tile_n0 + nn; + const int gk = k0 + kk; + half val = (half)0; + if (gn < N) {{ + device const block_q6_K * blk = wrows + (uint)gn * nb + (gk / QK_K); + val = (half) dequant_q6k_elem(blk, gk % QK_K); + }} + Bs[nn * BK + kk] = val; + }} + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short kk = 0; kk < BK / 8; ++kk) {{ + simdgroup_half8x8 a[2], b[2]; + for (short sr = 0; sr < 2; ++sr) {{ + simdgroup_load(a[sr], As + (16 * sg_row + 8 * sr) * BK + 8 * kk, BK, ulong2(0, 0), false); + }} + for (short sc = 0; sc < 2; ++sc) {{ + // transpose=true yields b[k][n] = Bs[n][k] for C = A @ B^T. + simdgroup_load(b[sc], Bs + (16 * sg_col + 8 * sc) * BK + 8 * kk, BK, ulong2(0, 0), true); + }} + for (short sr = 0; sr < 2; ++sr) {{ + for (short sc = 0; sc < 2; ++sc) {{ + simdgroup_multiply_accumulate(mc[sr][sc], a[sr], b[sc], mc[sr][sc]); + }} + }} + }} + threadgroup_barrier(mem_flags::mem_threadgroup); + }} + + for (short sr = 0; sr < 2; ++sr) {{ + for (short sc = 0; sc < 2; ++sc) {{ + simdgroup_store(mc[sr][sc], Cs + (16 * sg_row + 8 * sr) * BN + (16 * sg_col + 8 * sc), BN, ulong2(0, 0), false); + }} + }} + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short i = 0; i < (BM * BN) / 128; ++i) {{ + const short idx = tid + i * 128; + const short mm = idx / BN; + const short nn = idx % BN; + const int gm = tile_m0 + mm; + const int gn = tile_n0 + nn; + if (gm < M && gn < N) {{ + float v = Cs[mm * BN + nn]; +{bias_line} out[(uint)gm * (uint)N + gn] = (InT) v; + }} + }} +""" + + +# Number of simdgroups per threadgroup for the mat-vec kernel. +_Q6K_MV_NSG = 4 +# Tile size for the mat-mat kernel (BM == BN); threadgroup handles a tile. +_Q6K_MM_TILE = 32 + + +def _emit_q6k_matvec( + P: MLXProgramBuilder, + n: Node, + x_node: Node, + x_slot: Slot, + weight_slot: Slot, + bias_slot: Optional[Slot], + N: int, + K: int, + out: Slot, +) -> None: + in_dtype_int = torch_dtype_to_scalar_type(x_node.meta["val"].dtype) + + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + M_iov = emit_product(P, leading) + out_shape_flat = leading + [IntOrVid.from_literal(N)] + + n_r0 = 2 + nsg = _Q6K_MV_NSG + num_row_groups = (N + nsg * n_r0 - 1) // (nsg * n_r0) + grid_x = num_row_groups * 32 * nsg + + has_bias = bias_slot is not None + inputs = [P.slot_to_tid(x_slot), P.slot_to_tid(weight_slot)] + input_names = ["x", "weight"] + if has_bias: + inputs.append(P.slot_to_tid(bias_slot)) + input_names.append("bias") + + P.emit( + MetalKernelNode( + name="gguf_q6k_matvec", + source=_q6k_matvec_source(has_bias), + header=_Q6K_HEADER, + inputs=inputs, + outputs=[P.slot_to_tid(out)], + grid=[ + IntOrVid.from_literal(grid_x), + M_iov, + IntOrVid.from_literal(1), + ], + threadgroup=[ + IntOrVid.from_literal(32 * nsg), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=input_names, + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[in_dtype_int], + template_arg_names=["InT", "N", "K", "NSG"], + template_arg_kinds=[2, 0, 0, 0], # dtype, int, int, int + template_arg_values=[in_dtype_int, N, K, nsg], + ) + ) + + +def _emit_q6k_matmul( + P: MLXProgramBuilder, + n: Node, + x_node: Node, + x_slot: Slot, + weight_slot: Slot, + bias_slot: Optional[Slot], + N: int, + K: int, + blocks_m_iov: IntOrVid, + out: Slot, +) -> None: + in_dtype_int = torch_dtype_to_scalar_type(x_node.meta["val"].dtype) + + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + out_shape_flat = leading + [IntOrVid.from_literal(N)] + + tile = _Q6K_MM_TILE + blocks_n = (N + tile - 1) // tile + grid_x = blocks_n * 128 # 128 threads (4 simdgroups) per threadgroup + + has_bias = bias_slot is not None + inputs = [P.slot_to_tid(x_slot), P.slot_to_tid(weight_slot)] + input_names = ["x", "weight"] + if has_bias: + inputs.append(P.slot_to_tid(bias_slot)) + input_names.append("bias") + + P.emit( + MetalKernelNode( + name="gguf_q6k_matmul", + source=_q6k_matmul_source(has_bias), + header=_Q6K_HEADER, + inputs=inputs, + outputs=[P.slot_to_tid(out)], + grid=[ + IntOrVid.from_literal(grid_x), + blocks_m_iov, + IntOrVid.from_literal(1), + ], + threadgroup=[ + IntOrVid.from_literal(128), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=input_names, + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[in_dtype_int], + template_arg_names=["InT", "N", "K"], + template_arg_kinds=[2, 0, 0], + template_arg_values=[in_dtype_int, N, K], + ) + ) + + +@REGISTRY.register(target=[torch.ops.mlx.gguf_linear.default]) +def _gguf_linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Lower ``mlx::gguf_linear`` to fused Q6_K Metal kernels.""" + args = P.args(n) + if len(args) != 4: + raise ValueError( + f"mlx::gguf_linear: expected 4 args (x, weight, format, bias); " + f"got {len(args)}" + ) + x_slot, weight_slot, fmt, bias_slot = args + x_node = n.args[0] + weight_node = n.args[1] + + if fmt != "q6k": + raise NotImplementedError( + f"mlx::gguf_linear: unsupported format {fmt!r}; only 'q6k' is supported" + ) + + weight_meta = weight_node.meta["val"] + if weight_meta.dim() != 2: + raise NotImplementedError( + f"mlx::gguf_linear: weight must be 2-D (N, row_bytes); got " + f"shape {tuple(weight_meta.shape)}" + ) + N = weight_meta.shape[0] + row_bytes = weight_meta.shape[1] + if not isinstance(N, int) or not isinstance(row_bytes, int): + raise NotImplementedError( + "mlx::gguf_linear: weight shape must be statically known" + ) + if row_bytes % Q6K_BLOCK_BYTES != 0: + raise ValueError( + f"mlx::gguf_linear: weight row bytes {row_bytes} must be a multiple of " + f"{Q6K_BLOCK_BYTES}" + ) + K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K + + # Determine M (product of x's leading dims). Static M lets us pick the + # optimal kernel and (for mat-mat) compute a literal launch grid. + x_meta = x_node.meta["val"] + leading_dims = x_meta.shape[:-1] + M: Optional[int] = 1 + for d in leading_dims: + if isinstance(d, int): + M *= d + else: + M = None # dynamic / symbolic + break + + out = P.make_or_get_slot(n) + tile = _Q6K_MM_TILE + if M == 1: + # Static decode -> mat-vec. + _emit_q6k_matvec(P, n, x_node, x_slot, weight_slot, bias_slot, N, K, out) + elif M is not None: + # Static prefill -> tiled simdgroup mat-mat (literal grid). + blocks_m = (M + tile - 1) // tile + _emit_q6k_matmul( + P, n, x_node, x_slot, weight_slot, bias_slot, N, K, + IntOrVid.from_literal(blocks_m), out, + ) + else: + # Dynamic seqlen -> emit both kernels in separate chains and select at + # runtime with an IfNode. cond = M - 1: nonzero (M>1) runs the mat-mat + # (then) chain, zero (M==1) runs the mat-vec (else) chain. + leading = emit_shape(P, x_node, x_slot, end_dim=-1) + m_iov = emit_product(P, leading) + + _, cond_slot = P.make_tmp_value_slot() + P.emit( + SubtractIntNode( + a=m_iov, + b=IntOrVid.from_literal(1), + out=P.slot_to_vid(cond_slot), + ) + ) + cond_iov = IntOrVid.from_vid(P.slot_to_vid(cond_slot)) + + # blocks_m = (M + tile - 1) // tile (mat-mat grid.y). + _, sum_slot = P.make_tmp_value_slot() + P.emit( + AddIntNode( + a=m_iov, + b=IntOrVid.from_literal(tile - 1), + out=P.slot_to_vid(sum_slot), + ) + ) + _, blocks_m_slot = P.make_tmp_value_slot() + P.emit( + FloorDivideIntNode( + a=IntOrVid.from_vid(P.slot_to_vid(sum_slot)), + b=IntOrVid.from_literal(tile), + out=P.slot_to_vid(blocks_m_slot), + ) + ) + blocks_m_iov = IntOrVid.from_vid(P.slot_to_vid(blocks_m_slot)) + + with P.new_chain() as then_idx: # prefill / mat-mat + _emit_q6k_matmul( + P, n, x_node, x_slot, weight_slot, bias_slot, N, K, + blocks_m_iov, out, + ) + with P.new_chain() as else_idx: # decode / mat-vec + _emit_q6k_matvec( + P, n, x_node, x_slot, weight_slot, bias_slot, N, K, out + ) + + P.emit( + IfNode( + cond=cond_iov, + then_chain_idx=then_idx, + else_chain_idx=else_idx, + ) + ) + return out diff --git a/backends/mlx/model_ops/test_gguf_linear.py b/backends/mlx/model_ops/test_gguf_linear.py new file mode 100644 index 00000000000..8235a6935a0 --- /dev/null +++ b/backends/mlx/model_ops/test_gguf_linear.py @@ -0,0 +1,349 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for ``mlx::gguf_linear`` (GGUF Q6_K linear). + +Compares the fused Metal kernels (mat-vec for decode, mat-mat for prefill) +against the eager pure-torch reference on the *same* packed Q6_K weight, so +quantization quality is irrelevant -- only the kernel-vs-reference numerics +are checked. Tolerances follow the activation dtype presets. + +``GGUFLinearDynamicTest`` additionally exports once with a symbolic seqlen and +runs the same .pte with M=1 and M>1 to exercise both branches of the runtime +``IfNode`` (decode mat-vec vs prefill mat-mat). + +Usage:: + + python -m executorch.backends.mlx.model_ops.test_gguf_linear run + python -m executorch.backends.mlx.model_ops.test_gguf_linear run -v + python -m executorch.backends.mlx.model_ops.test_gguf_linear run --rebuild + python -m executorch.backends.mlx.model_ops.test_gguf_linear eager +""" + +from typing import List, Tuple + +import executorch.backends.mlx.model_ops.gguf_linear # noqa: F401 + +import torch +import torch.nn as nn + +from executorch.backends.mlx.model_ops.gguf_linear import ( + dequantize_q6_k, + Q6K_BLOCK_BYTES, + QK_K, +) + +from executorch.backends.mlx.test.test_utils import OpTestCase + + +# --------------------------------------------------------------------------- +# GGUF Q6_K packing helper (mirrors quantize_row_q6_K_ref in ggml-quants.c). +# Quantization quality is not important here -- we only need valid, deterministic +# bytes that both the eager reference and the kernel will dequantize identically. +# --------------------------------------------------------------------------- + + +def pack_q6_k(w: torch.Tensor) -> torch.Tensor: + """Pack a ``(N, K)`` float weight into the GGUF ``block_q6_K`` byte layout. + + Returns ``(N, (K/256)*210)`` uint8. + """ + assert w.dim() == 2 + N, K = w.shape + assert K % QK_K == 0, f"K={K} must be a multiple of {QK_K}" + nb = K // QK_K + w = w.to(torch.float32).cpu() + + out = torch.empty(N, nb * Q6K_BLOCK_BYTES, dtype=torch.uint8) + + for row in range(N): + for b in range(nb): + block = w[row, b * QK_K : (b + 1) * QK_K] # (256,) + sub = block.view(QK_K // 16, 16) # (16, 16) + + # Per-sub-block scale via the simple max-abs / 32 scheme, then a + # super-block scale so sub-scales fit in int8 [-128, 127]. + sub_max = sub.abs().amax(dim=1) # (16,) + scales_f = sub_max / 32.0 # (16,) + max_scale = scales_f.abs().amax() + if max_scale < 1e-12: + # All-zero block. + continue + iscale = -128.0 / float(max_scale) + d = 1.0 / iscale + scales_i = torch.clamp( + torch.round(iscale * scales_f), max=127.0 + ).to(torch.int32) # (16,) + + # Quantize each element with the realized per-sub-block scale. + L = torch.zeros(QK_K, dtype=torch.int32) + for j in range(QK_K // 16): + dj = d * float(scales_i[j]) + seg = block[j * 16 : (j + 1) * 16] + if dj == 0.0: + q = torch.zeros(16, dtype=torch.int32) + else: + q = torch.clamp(torch.round(seg / dj), min=-32, max=31).to( + torch.int32 + ) + L[j * 16 : (j + 1) * 16] = q + 32 # 0..63 + + ql = torch.zeros(QK_K // 2, dtype=torch.int32) # 128 + qh = torch.zeros(QK_K // 4, dtype=torch.int32) # 64 + for half in range(2): # 128 elems each + jbase = half * 128 + qlb = half * 64 + qhb = half * 32 + for l in range(32): + l1 = L[jbase + l + 0] & 0xF + l2 = L[jbase + l + 32] & 0xF + l3 = L[jbase + l + 64] & 0xF + l4 = L[jbase + l + 96] & 0xF + ql[qlb + l + 0] = l1 | (l3 << 4) + ql[qlb + l + 32] = l2 | (l4 << 4) + qh[qhb + l] = ( + (L[jbase + l + 0] >> 4) + | ((L[jbase + l + 32] >> 4) << 2) + | ((L[jbase + l + 64] >> 4) << 4) + | ((L[jbase + l + 96] >> 4) << 6) + ) + + blk = out[row, b * Q6K_BLOCK_BYTES : (b + 1) * Q6K_BLOCK_BYTES] + blk[0:128] = ql.to(torch.uint8) + blk[128:192] = qh.to(torch.uint8) + blk[192:208] = scales_i.to(torch.int8).view(torch.uint8) + d_bytes = torch.tensor([d], dtype=torch.float16).view(torch.uint8) + blk[208:210] = d_bytes + + return out + + +# --------------------------------------------------------------------------- +# Test cases +# --------------------------------------------------------------------------- + + +class GGUFLinearModel(nn.Module): + def forward( + self, + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.mlx.gguf_linear(x, weight, "q6k", bias) + + +_DTYPE_TOL = { + torch.bfloat16: (2e-2, 2e-2), + torch.float16: (1e-3, 1e-3), + torch.float32: (1e-4, 1e-4), +} +_DTYPE_TAG = {torch.bfloat16: "bf16", torch.float16: "fp16", torch.float32: "fp32"} + + +class GGUFLinearTest(OpTestCase): + name = "gguf_linear" + + def __init__( + self, + M: int = 1, + N: int = 256, + K: int = 256, + dtype: torch.dtype = torch.bfloat16, + ): + self.M = M + self.N = N + self.K = K + self.dtype = dtype + self.rtol, self.atol = _DTYPE_TOL[dtype] + self.name = f"gguf_linear_m{M}_n{N}_k{K}_{_DTYPE_TAG[dtype]}" + + @classmethod + def get_test_configs(cls) -> List["GGUFLinearTest"]: + cfgs: List["GGUFLinearTest"] = [] + # Decode (mat-vec). + for K in (256, 512, 1024): + for N in (256, 512): + cfgs.append(cls(M=1, N=N, K=K, dtype=torch.bfloat16)) + cfgs.append(cls(M=1, N=256, K=256, dtype=torch.float16)) + cfgs.append(cls(M=1, N=256, K=256, dtype=torch.float32)) + # Prefill (mat-mat). + for M in (8, 64, 128): + cfgs.append(cls(M=M, N=512, K=512, dtype=torch.bfloat16)) + cfgs.append(cls(M=32, N=256, K=256, dtype=torch.float16)) + # Ragged shapes (M and N not multiples of the 32-wide tile / row group). + cfgs.append(cls(M=40, N=300, K=256, dtype=torch.bfloat16)) + cfgs.append(cls(M=1, N=300, K=256, dtype=torch.bfloat16)) + return cfgs + + def create_model(self) -> nn.Module: + return GGUFLinearModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + torch.manual_seed(0) + x = torch.randn(self.M, self.K, dtype=self.dtype) + w = torch.randn(self.N, self.K, dtype=torch.float32) * 0.1 + weight = pack_q6_k(w) + bias = torch.randn(self.N, dtype=self.dtype) + return (x, weight, bias) + + +class GGUFLinearDynamicTest(OpTestCase): + """Dynamic seqlen: export once with a symbolic M, run with M=1 (decode / + else chain) and M>1 (prefill / then chain) to exercise both IfNode branches. + """ + + name = "gguf_linear_dynamic" + + def __init__( + self, + export_M: int = 4, + test_M: int = 1, + N: int = 512, + K: int = 512, + dtype: torch.dtype = torch.bfloat16, + ): + self.export_M = export_M + self.test_M = test_M + self.seq_len = export_M # used by create_inputs (export tracing) + self.N = N + self.K = K + self.dtype = dtype + self.rtol, self.atol = _DTYPE_TOL[dtype] + self.name = ( + f"gguf_linear_dyn_exp{export_M}_test{test_M}_n{N}_k{K}_" + f"{_DTYPE_TAG[dtype]}" + ) + + @classmethod + def get_test_configs(cls) -> List["GGUFLinearDynamicTest"]: + return [ + cls(export_M=4, test_M=1, dtype=torch.bfloat16), # decode / else + cls(export_M=4, test_M=8, dtype=torch.bfloat16), # prefill / then + cls(export_M=4, test_M=4, dtype=torch.bfloat16), # control + cls(export_M=4, test_M=1, dtype=torch.float16), + cls(export_M=4, test_M=40, N=300, K=256, dtype=torch.bfloat16), # ragged + ] + + def get_dynamic_shapes(self): + seq_dim = torch.export.Dim("seq_len", min=1, max=64) + return {"x": {0: seq_dim}, "weight": None, "bias": None} + + def create_model(self) -> nn.Module: + return GGUFLinearModel() + + def _make_inputs(self, M: int) -> Tuple[torch.Tensor, ...]: + # Deterministic weight/bias so export-time and run-time (and the eager + # reference) all use the same quantized weight (it is a runtime input). + torch.manual_seed(0) + w = torch.randn(self.N, self.K, dtype=torch.float32) * 0.1 + weight = pack_q6_k(w) + bias = torch.randn(self.N, dtype=self.dtype) + x = torch.randn(M, self.K, dtype=self.dtype) + return (x, weight, bias) + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + return self._make_inputs(self.export_M) + + def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: + return self._make_inputs(self.test_M) + + +def _eager_sanity() -> None: + """Quick CPU check: pack -> dequant -> matmul matches a direct reference.""" + torch.manual_seed(0) + N, K = 4, 512 + w = torch.randn(N, K) * 0.1 + packed = pack_q6_k(w) + w_deq = dequantize_q6_k(packed, K) + # Round-trip error should be small for this benign distribution. + rel = (w_deq - w).norm() / w.norm() + print(f"pack/dequant relative error: {rel.item():.4f}") + x = torch.randn(3, K) + ref = x @ w_deq.t() + out = torch.ops.mlx.gguf_linear(x, packed, "q6k", None) + err = (out - ref).abs().max() + print(f"eager op vs reference max abs err: {err.item():.6e}") + assert err < 1e-3, err + # Unsupported format raises. + try: + torch.ops.mlx.gguf_linear(x, packed, "q4k", None) + raise AssertionError("expected NotImplementedError for q4k") + except (NotImplementedError, RuntimeError) as e: + print(f"q4k correctly rejected: {type(e).__name__}") + print("eager sanity OK") + + +if __name__ == "__main__": # noqa: C901 + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser(description="Test mlx::gguf_linear op") + parser.add_argument( + "action", choices=["generate", "compare", "run", "list", "eager"] + ) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--rebuild", action="store_true") + parser.add_argument("--config", type=str, default=None) + args = parser.parse_args() + + if args.action == "eager": + _eager_sanity() + sys.exit(0) + + if args.rebuild and not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = ( + GGUFLinearTest.get_test_configs() + + GGUFLinearDynamicTest.get_test_configs() + ) + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + failed_names: List[str] = [] + + for test in configs: + if args.action == "generate": + pte_path, _, _ = test.generate_test_files(verbose=args.verbose) + print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + if ok: + passed += 1 + else: + failed += 1 + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: {passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 34fd8815ba8..44cb3d8056a 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -1837,6 +1837,8 @@ class Interpreter { st.begin_op(idx, op_name(instr.op)); if (instr.op == OpCode::SCAN) { exec_scan(prog, std::get(instr.node), st, stream); + } else if (instr.op == OpCode::IF) { + exec_if(prog, std::get(instr.node), st, stream); } else { dispatch(instr, st, stream); } @@ -1846,6 +1848,20 @@ class Interpreter { } private: + void exec_if( + const MLXProgram& prog, + const IfNode& n, + ExecutionState& st, + StreamOrDevice s) const { + // Select one branch at runtime based on the integer condition. + // Nonzero -> then_chain, zero -> else_chain. The selected chain's + // instructions write the output slot(s) directly. + const int64_t cond = resolve_int(n.cond, st); + const uint32_t chain_idx = + (cond != 0) ? n.then_chain_idx : n.else_chain_idx; + run_chain(prog, chain_idx, st, s); + } + void exec_scan( const MLXProgram& prog, const ScanNode& n, diff --git a/backends/mlx/serialization/schema.fbs b/backends/mlx/serialization/schema.fbs index 3c02e5785ce..42c53e5172b 100644 --- a/backends/mlx/serialization/schema.fbs +++ b/backends/mlx/serialization/schema.fbs @@ -976,6 +976,15 @@ table ScanNode { scan_axis: int32 = 1; // dimension to iterate over } +// Runtime conditional: select one of two instruction chains based on a runtime +// integer condition. The selected branch writes its output slot(s) directly, so +// no `outputs` field is needed (unlike ScanNode, which post-processes/stacks). +table IfNode { + cond: IntOrVid (required); // nonzero -> then_chain, zero -> else_chain + then_chain_idx: uint32; // index into MLXGraph.instruction_chains + else_chain_idx: uint32; // index into MLXGraph.instruction_chains +} + // Custom Metal kernel execution via mlx::core::fast::metal_kernel(). // Two-phase API: // 1. Factory: metal_kernel(name, input_names, output_names, source, header, @@ -1151,7 +1160,8 @@ union OpNode { RollNode, BitwiseAndNode, BitwiseOrNode, - BitwiseXorNode + BitwiseXorNode, + IfNode // BC: Add new op nodes here (append only) } diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index ed3dcdba9c3..e46b51a8411 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -471,18 +471,13 @@ def main() -> None: backend=args.backend, ) - if args.gguf and args.backend == "mlx": - os.environ["ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS"] = "1" - try: - export_and_lower( - model, - config, - args.output_dir, - backend=args.backend, - use_turboquant=args.turboquant, - ) - finally: - os.environ.pop("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", None) + export_and_lower( + model, + config, + args.output_dir, + backend=args.backend, + use_turboquant=args.turboquant, + ) if __name__ == "__main__": diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 35dddb5a0dc..ac7ad5ebf89 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -138,8 +138,12 @@ def load_gguf_model( embed_quant = None n_processed = 0 + from gguf import GGMLQuantizationType + print(f"Streaming GGUF from {gguf_path}...") - for gguf_name, result in iter_gguf_tensors(gguf_path): + for gguf_name, result, gguf_type in iter_gguf_tensors( + gguf_path, q6k_raw=(backend == "mlx") + ): model_key = gguf_to_model_key(gguf_name) if model_key is None: continue @@ -152,6 +156,18 @@ def load_gguf_model( if backend == "cuda": result = dequantize_weight(result, torch.bfloat16) + # MLX Q6_K: keep the raw GGUF blob and swap the nn.Linear for a + # GGUFLinear that dispatches to the fused mlx::gguf_linear kernel, + # bypassing the slow group_size=16 non-fused affine path. + if backend == "mlx" and gguf_type == GGMLQuantizationType.Q6_K: + from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( + replace_with_gguf_linear, + ) + + replace_with_gguf_linear(model, model_key, result, format="q6k") + n_processed += 1 + continue + pack_one(model, model_key, result, packers) n_processed += 1 diff --git a/examples/models/gemma4_31b/mlx_gguf_linear.py b/examples/models/gemma4_31b/mlx_gguf_linear.py new file mode 100644 index 00000000000..53a4d236bf6 --- /dev/null +++ b/examples/models/gemma4_31b/mlx_gguf_linear.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""MLX carrier module for GGUF Q6_K linears. + +Wraps the raw GGUF ``block_q6_K`` weight blob and dispatches to the fused +``mlx::gguf_linear`` Metal kernel (decode mat-vec / prefill mat-mat), instead of +the slow non-fused dequantize+matmul path that group_size=16 affine quant takes +through the MLX ``QUANTIZED_LINEAR`` pattern. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn + +# Importing the op module registers ``torch.ops.mlx.gguf_linear``. +import executorch.backends.mlx.model_ops.gguf_linear # noqa: F401 +from executorch.backends.mlx.model_ops.gguf_linear import Q6K_BLOCK_BYTES, QK_K + + +class GGUFLinear(nn.Module): + """``y = gguf_linear(x, weight_blob, format)`` for a GGUF-quantized linear. + + The weight is the **raw** GGUF block blob, stored as a uint8 buffer of shape + ``(out_features, n_blocks * block_bytes)``. Gemma linears are bias-free, so + bias is always ``None``. + """ + + def __init__(self, weight_blob: torch.Tensor, format: str = "q6k"): + super().__init__() + if weight_blob.dim() != 2 or weight_blob.dtype != torch.uint8: + raise ValueError( + f"GGUFLinear: weight_blob must be 2-D uint8; got " + f"shape {tuple(weight_blob.shape)} dtype {weight_blob.dtype}" + ) + if format != "q6k": + raise NotImplementedError( + f"GGUFLinear: unsupported format {format!r}; only 'q6k' supported" + ) + row_bytes = int(weight_blob.shape[1]) + if row_bytes % Q6K_BLOCK_BYTES != 0: + raise ValueError( + f"GGUFLinear: weight row bytes {row_bytes} must be a multiple of " + f"{Q6K_BLOCK_BYTES}" + ) + self.format = format + self.out_features = int(weight_blob.shape[0]) + self.in_features = (row_bytes // Q6K_BLOCK_BYTES) * QK_K + # uint8 cannot be a grad-requiring Parameter; store as a buffer so it is + # serialized as a constant in the exported program. + self.register_buffer("weight", weight_blob) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.mlx.gguf_linear(x, self.weight, self.format, None) + + def extra_repr(self) -> str: + return ( + f"in_features={self.in_features}, out_features={self.out_features}, " + f"format={self.format!r}" + ) + + +def replace_with_gguf_linear( + model: nn.Module, + weight_fqn: str, + weight_blob: torch.Tensor, + format: str = "q6k", +) -> None: + """Replace the ``nn.Linear`` owning ``weight_fqn`` with a ``GGUFLinear``. + + ``weight_fqn`` is the fully-qualified name of the ``.weight`` tensor + (e.g. ``model.layers.0.mlp.down_proj.weight``). The parent linear module is + swapped in place on its grandparent module. + """ + parts = weight_fqn.rsplit(".", 1) + if len(parts) != 2 or parts[1] != "weight": + raise ValueError( + f"replace_with_gguf_linear: expected a '*.weight' fqn; got {weight_fqn!r}" + ) + linear_fqn = parts[0] + grandparent_fqn, _, child_name = linear_fqn.rpartition(".") + grandparent = ( + model.get_submodule(grandparent_fqn) if grandparent_fqn else model + ) + setattr(grandparent, child_name, GGUFLinear(weight_blob, format=format)) diff --git a/examples/models/gemma4_31b/quant/gguf.py b/examples/models/gemma4_31b/quant/gguf.py index 78c3aa3d8f9..5c58c86cde6 100644 --- a/examples/models/gemma4_31b/quant/gguf.py +++ b/examples/models/gemma4_31b/quant/gguf.py @@ -162,6 +162,22 @@ def _unpack_q6_k(data, shape: list[int]) -> torch.Tensor: ) +def _raw_q6_k(data, shape: list[int]) -> torch.Tensor: + """Return the raw GGUF Q6_K bytes as ``(N, n_blocks*210)`` uint8. + + Unlike ``_unpack_q6_k`` (which dequantizes and deinterleaves into an + ``IntxUnpackedToInt8Tensor``), this preserves the exact GGUF ``block_q6_K`` + byte layout so it can be consumed directly by the fused + ``mlx::gguf_linear`` Metal kernel. + """ + N, K = shape + assert K % QK_K == 0, f"Q6_K requires K divisible by {QK_K}, got {K}" + block_bytes = 2 + QK_K // 2 + QK_K // 4 + QK_K // 16 # 210 + n_blocks = N * (K // QK_K) + raw = _raw_tensor(data).reshape(n_blocks, block_bytes) + return raw.reshape(N, -1).clone() + + def unpack_gguf_tensor( tensor_data, tensor_type, @@ -193,17 +209,28 @@ def unpack_gguf_tensor( def iter_gguf_tensors( path: str, -) -> Iterator[tuple[str, torch.Tensor]]: - """Yield ``(name, result)`` for each tensor in a GGUF file. + q6k_raw: bool = False, +) -> Iterator[tuple[str, torch.Tensor, object]]: + """Yield ``(name, result, tensor_type)`` for each tensor in a GGUF file. Processes one tensor at a time for low peak memory. Tensor names are GGUF names (e.g., ``blk.0.attn_q.weight``); the caller handles key remapping. GGUF shapes are reversed to PyTorch convention automatically. + ``tensor_type`` is the ``gguf.GGMLQuantizationType`` so callers can branch + on the quant format. + + If ``q6k_raw`` is set, Q6_K tensors are yielded as the **raw** + ``(N, n_blocks*210)`` uint8 GGUF block blob (for the fused + ``mlx::gguf_linear`` path) instead of being dequantized/deinterleaved. """ - from gguf import GGUFReader + from gguf import GGMLQuantizationType, GGUFReader reader = GGUFReader(path) for tensor in reader.tensors: shape = list(reversed(tensor.shape.tolist())) - result = unpack_gguf_tensor(tensor.data, tensor.tensor_type, shape) - yield tensor.name, result + if q6k_raw and tensor.tensor_type == GGMLQuantizationType.Q6_K: + # Keep the raw GGUF block bytes for the fused mlx::gguf_linear path. + result = _raw_q6_k(tensor.data, shape) + else: + result = unpack_gguf_tensor(tensor.data, tensor.tensor_type, shape) + yield tensor.name, result, tensor.tensor_type diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py index 89a7099d6f0..4a5dd4da68b 100644 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -278,5 +278,51 @@ def test_unsupported_type_raises(self): ) +@unittest.skipUnless(_HAS_GGUF, "gguf package not installed") +class TestQ6KRawBlob(unittest.TestCase): + """``iter_gguf_tensors(q6k_raw=True)`` keeps the GGUF block bytes for + mlx::gguf_linear; ``_raw_q6_k`` produces them verbatim.""" + + def _block(self): + d = 0.5 + scales_16 = list(range(1, 17)) + qvals = [(i * 7 + 3) % 64 for i in range(256)] + return d, scales_16, qvals, _make_q6_k_block(d, scales_16, qvals) + + def test_raw_blob_preserves_bytes(self): + from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k + + _, _, _, block = self._block() + data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 210) + raw = _raw_q6_k(data, [1, 256]) + self.assertEqual(raw.dtype, torch.uint8) + self.assertEqual(tuple(raw.shape), (1, 210)) + self.assertTrue( + torch.equal(raw[0], torch.tensor(list(block), dtype=torch.uint8)) + ) + + def test_raw_blob_dequant_matches_gguf_reference(self): + from executorch.backends.mlx.model_ops.gguf_linear import dequantize_q6_k + from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k + + d, scales_16, qvals, block = self._block() + data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 210) + raw = _raw_q6_k(data, [1, 256]) + deq = dequantize_q6_k(raw, 256)[0] + ref = torch.tensor(_q6_k_reference_dequant(d, scales_16, qvals)) + self.assertTrue( + torch.allclose(deq, ref, atol=1e-3), + f"max diff: {(deq - ref).abs().max():.6g}", + ) + + def test_default_still_unpacks_to_intx(self): + from torchao.quantization import IntxUnpackedToInt8Tensor + + _, _, _, block = self._block() + data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 210) + result = unpack_gguf_tensor(data, GGMLQuantizationType.Q6_K, [1, 256]) + self.assertIsInstance(result, IntxUnpackedToInt8Tensor) + + if __name__ == "__main__": unittest.main() diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index 37f61fddb0f..e4eb1393e67 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -323,5 +323,71 @@ def test_embedding_packing_preserves_values(self): ) +class TestGgufLinearMlx(unittest.TestCase): + """Q6_K weights route to the fused mlx::gguf_linear op (raw-blob path).""" + + def _make_blob(self, N: int, K: int) -> torch.Tensor: + from executorch.backends.mlx.model_ops.test_gguf_linear import pack_q6_k + + torch.manual_seed(0) + w = torch.randn(N, K, dtype=torch.float32) * 0.1 + return pack_q6_k(w) + + def test_replace_with_gguf_linear_swaps_module(self): + from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( + GGUFLinear, + replace_with_gguf_linear, + ) + + model = build_random_tiny_model() + # Pick a Linear whose in_features is a multiple of 256 (Q6_K requires + # K % 256 == 0) -- e.g. down_proj (in = intermediate_size). + target_fqn = None + for name, mod in model.named_modules(): + if isinstance(mod, nn.Linear) and mod.in_features % 256 == 0: + target_fqn = name + N, K = mod.out_features, mod.in_features + break + self.assertIsNotNone(target_fqn, "no Linear with in_features % 256 == 0") + + blob = self._make_blob(N, K) + replace_with_gguf_linear(model, target_fqn + ".weight", blob, format="q6k") + + swapped = model.get_submodule(target_fqn) + self.assertIsInstance(swapped, GGUFLinear) + self.assertEqual(swapped.in_features, K) + self.assertEqual(swapped.out_features, N) + self.assertEqual(swapped.weight.dtype, torch.uint8) + + # Forward dispatches to the op (eager fallback) and matches it exactly. + x = torch.randn(2, K, dtype=torch.bfloat16) + y = swapped(x) + ref = torch.ops.mlx.gguf_linear(x, blob, "q6k", None) + self.assertEqual(y.shape, torch.Size([2, N])) + self.assertTrue(torch.equal(y, ref)) + + def test_gguf_linear_appears_in_exported_graph(self): + from executorch.examples.models.gemma4_31b.mlx_gguf_linear import GGUFLinear + from torch.export import Dim, export + + N, K = 256, 512 + blob = self._make_blob(N, K) + m = GGUFLinear(blob, format="q6k").eval() + + # Dynamic seq_len exercises the runtime-routed (IfNode) lowering path. + seq = Dim("seq", min=1, max=8) + ep = export( + m, + (torch.randn(4, K, dtype=torch.bfloat16),), + dynamic_shapes=({0: seq},), + strict=True, + ) + targets = [str(n.target) for n in ep.graph.nodes if n.op == "call_function"] + self.assertTrue( + any("mlx.gguf_linear" in t for t in targets), + f"gguf_linear not in exported graph: {targets}", + ) + + if __name__ == "__main__": unittest.main() From 6c546d52fa16cc6994c58928a06a0971996e6782 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 3 Jun 2026 17:08:53 -0700 Subject: [PATCH 02/18] up --- backends/mlx/model_ops/gguf_linear.py | 32 +++-- backends/mlx/model_ops/test_gguf_linear.py | 120 +++++------------- examples/models/gemma4_31b/mlx_gguf_linear.py | 10 +- .../gemma4_31b/quant/tests/test_gguf.py | 20 +++ .../gemma4_31b/tests/test_mlx_pipeline.py | 6 +- 5 files changed, 83 insertions(+), 105 deletions(-) diff --git a/backends/mlx/model_ops/gguf_linear.py b/backends/mlx/model_ops/gguf_linear.py index a58096aa131..a255bd71643 100644 --- a/backends/mlx/model_ops/gguf_linear.py +++ b/backends/mlx/model_ops/gguf_linear.py @@ -14,7 +14,7 @@ The weight is stored in the **exact GGUF packed block layout** (no repacking), so weights converted by llama.cpp / gguf-py can be consumed directly. The ``format`` argument selects the GGUF quantization type; only ``"q6k"`` is -supported initially and anything else raises ``NotImplementedError``. +supported and anything else raises ``NotImplementedError``. Q6_K layout (per 256-element super-block, 210 bytes, see llama.cpp ``block_q6_K`` in ``ggml-common.h``):: @@ -64,7 +64,7 @@ # --------------------------------------------------------------------------- -# Q6_K constants and pure-torch reference (also used by the eager fallback) +# Q6_K constants and pure-torch dequant reference # --------------------------------------------------------------------------- QK_K = 256 @@ -627,8 +627,16 @@ def _gguf_linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: # Static prefill -> tiled simdgroup mat-mat (literal grid). blocks_m = (M + tile - 1) // tile _emit_q6k_matmul( - P, n, x_node, x_slot, weight_slot, bias_slot, N, K, - IntOrVid.from_literal(blocks_m), out, + P, + n, + x_node, + x_slot, + weight_slot, + bias_slot, + N, + K, + IntOrVid.from_literal(blocks_m), + out, ) else: # Dynamic seqlen -> emit both kernels in separate chains and select at @@ -668,13 +676,19 @@ def _gguf_linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: with P.new_chain() as then_idx: # prefill / mat-mat _emit_q6k_matmul( - P, n, x_node, x_slot, weight_slot, bias_slot, N, K, - blocks_m_iov, out, + P, + n, + x_node, + x_slot, + weight_slot, + bias_slot, + N, + K, + blocks_m_iov, + out, ) with P.new_chain() as else_idx: # decode / mat-vec - _emit_q6k_matvec( - P, n, x_node, x_slot, weight_slot, bias_slot, N, K, out - ) + _emit_q6k_matvec(P, n, x_node, x_slot, weight_slot, bias_slot, N, K, out) P.emit( IfNode( diff --git a/backends/mlx/model_ops/test_gguf_linear.py b/backends/mlx/model_ops/test_gguf_linear.py index 8235a6935a0..98d62d82305 100644 --- a/backends/mlx/model_ops/test_gguf_linear.py +++ b/backends/mlx/model_ops/test_gguf_linear.py @@ -42,84 +42,36 @@ # --------------------------------------------------------------------------- -# GGUF Q6_K packing helper (mirrors quantize_row_q6_K_ref in ggml-quants.c). -# Quantization quality is not important here -- we only need valid, deterministic -# bytes that both the eager reference and the kernel will dequantize identically. +# GGUF Q6_K test fixtures. +# +# The Python ``gguf`` package can dequantize Q6_K but does NOT implement Q6_K +# quantization, so we build the packed weight here. Quantization quality is +# irrelevant: the tests only compare the kernel against the eager op on the +# *same* bytes, so we just emit valid random blocks (random ql/qh/scales plus a +# small finite fp16 ``d`` -- the one field that must be finite). # --------------------------------------------------------------------------- -def pack_q6_k(w: torch.Tensor) -> torch.Tensor: - """Pack a ``(N, K)`` float weight into the GGUF ``block_q6_K`` byte layout. - - Returns ``(N, (K/256)*210)`` uint8. - """ - assert w.dim() == 2 - N, K = w.shape +def make_q6_k_blob(N: int, K: int, seed: int = 0) -> torch.Tensor: + """Build a ``(N, (K/256)*210)`` uint8 tensor of valid GGUF Q6_K blocks.""" assert K % QK_K == 0, f"K={K} must be a multiple of {QK_K}" nb = K // QK_K - w = w.to(torch.float32).cpu() - + g = torch.Generator().manual_seed(seed) out = torch.empty(N, nb * Q6K_BLOCK_BYTES, dtype=torch.uint8) - - for row in range(N): - for b in range(nb): - block = w[row, b * QK_K : (b + 1) * QK_K] # (256,) - sub = block.view(QK_K // 16, 16) # (16, 16) - - # Per-sub-block scale via the simple max-abs / 32 scheme, then a - # super-block scale so sub-scales fit in int8 [-128, 127]. - sub_max = sub.abs().amax(dim=1) # (16,) - scales_f = sub_max / 32.0 # (16,) - max_scale = scales_f.abs().amax() - if max_scale < 1e-12: - # All-zero block. - continue - iscale = -128.0 / float(max_scale) - d = 1.0 / iscale - scales_i = torch.clamp( - torch.round(iscale * scales_f), max=127.0 - ).to(torch.int32) # (16,) - - # Quantize each element with the realized per-sub-block scale. - L = torch.zeros(QK_K, dtype=torch.int32) - for j in range(QK_K // 16): - dj = d * float(scales_i[j]) - seg = block[j * 16 : (j + 1) * 16] - if dj == 0.0: - q = torch.zeros(16, dtype=torch.int32) - else: - q = torch.clamp(torch.round(seg / dj), min=-32, max=31).to( - torch.int32 - ) - L[j * 16 : (j + 1) * 16] = q + 32 # 0..63 - - ql = torch.zeros(QK_K // 2, dtype=torch.int32) # 128 - qh = torch.zeros(QK_K // 4, dtype=torch.int32) # 64 - for half in range(2): # 128 elems each - jbase = half * 128 - qlb = half * 64 - qhb = half * 32 - for l in range(32): - l1 = L[jbase + l + 0] & 0xF - l2 = L[jbase + l + 32] & 0xF - l3 = L[jbase + l + 64] & 0xF - l4 = L[jbase + l + 96] & 0xF - ql[qlb + l + 0] = l1 | (l3 << 4) - ql[qlb + l + 32] = l2 | (l4 << 4) - qh[qhb + l] = ( - (L[jbase + l + 0] >> 4) - | ((L[jbase + l + 32] >> 4) << 2) - | ((L[jbase + l + 64] >> 4) << 4) - | ((L[jbase + l + 96] >> 4) << 6) - ) - - blk = out[row, b * Q6K_BLOCK_BYTES : (b + 1) * Q6K_BLOCK_BYTES] - blk[0:128] = ql.to(torch.uint8) - blk[128:192] = qh.to(torch.uint8) - blk[192:208] = scales_i.to(torch.int8).view(torch.uint8) - d_bytes = torch.tensor([d], dtype=torch.float16).view(torch.uint8) - blk[208:210] = d_bytes - + blocks = out.view(N, nb, Q6K_BLOCK_BYTES) + # ql (0:128) + qh (128:192): any byte values are valid 6-bit quants. + blocks[..., :192] = torch.randint( + 0, 256, (N, nb, 192), dtype=torch.uint8, generator=g + ) + # scales (192:208): modest positive int8 scales keep magnitudes sane. + blocks[..., 192:208] = torch.randint( + 1, 17, (N, nb, 16), dtype=torch.uint8, generator=g + ) + # d (208:210): a small finite fp16 super-block scale. Chosen so dequantized + # element magnitudes (~ d * scale * (q-32)) are O(0.1), like real Q6_K + # weights -- the mat-mat kernel stores tiles in half precision (as in + # llama.cpp), so unrealistically large magnitudes would exceed bf16 tol. + blocks[..., 208:210] = torch.tensor([7e-4], dtype=torch.float16).view(torch.uint8) return out @@ -140,7 +92,9 @@ def forward( _DTYPE_TOL = { torch.bfloat16: (2e-2, 2e-2), - torch.float16: (1e-3, 1e-3), + # The mat-mat (prefill) kernel stores tiles in half precision (as in + # llama.cpp), so fp16 outputs are accurate to ~half precision (~4e-3). + torch.float16: (5e-3, 5e-3), torch.float32: (1e-4, 1e-4), } _DTYPE_TAG = {torch.bfloat16: "bf16", torch.float16: "fp16", torch.float32: "fp32"} @@ -187,8 +141,7 @@ def create_model(self) -> nn.Module: def create_inputs(self) -> Tuple[torch.Tensor, ...]: torch.manual_seed(0) x = torch.randn(self.M, self.K, dtype=self.dtype) - w = torch.randn(self.N, self.K, dtype=torch.float32) * 0.1 - weight = pack_q6_k(w) + weight = make_q6_k_blob(self.N, self.K) bias = torch.randn(self.N, dtype=self.dtype) return (x, weight, bias) @@ -241,8 +194,7 @@ def _make_inputs(self, M: int) -> Tuple[torch.Tensor, ...]: # Deterministic weight/bias so export-time and run-time (and the eager # reference) all use the same quantized weight (it is a runtime input). torch.manual_seed(0) - w = torch.randn(self.N, self.K, dtype=torch.float32) * 0.1 - weight = pack_q6_k(w) + weight = make_q6_k_blob(self.N, self.K) bias = torch.randn(self.N, dtype=self.dtype) x = torch.randn(M, self.K, dtype=self.dtype) return (x, weight, bias) @@ -255,15 +207,12 @@ def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: def _eager_sanity() -> None: - """Quick CPU check: pack -> dequant -> matmul matches a direct reference.""" + """Quick CPU check: dequant + matmul matches the eager op on the same bytes.""" torch.manual_seed(0) N, K = 4, 512 - w = torch.randn(N, K) * 0.1 - packed = pack_q6_k(w) + packed = make_q6_k_blob(N, K) w_deq = dequantize_q6_k(packed, K) - # Round-trip error should be small for this benign distribution. - rel = (w_deq - w).norm() / w.norm() - print(f"pack/dequant relative error: {rel.item():.4f}") + print(f"dequant finite: {torch.isfinite(w_deq).all().item()}") x = torch.randn(3, K) ref = x @ w_deq.t() out = torch.ops.mlx.gguf_linear(x, packed, "q6k", None) @@ -274,7 +223,7 @@ def _eager_sanity() -> None: try: torch.ops.mlx.gguf_linear(x, packed, "q4k", None) raise AssertionError("expected NotImplementedError for q4k") - except (NotImplementedError, RuntimeError) as e: + except RuntimeError as e: print(f"q4k correctly rejected: {type(e).__name__}") print("eager sanity OK") @@ -302,8 +251,7 @@ def _eager_sanity() -> None: sys.exit(1) configs = ( - GGUFLinearTest.get_test_configs() - + GGUFLinearDynamicTest.get_test_configs() + GGUFLinearTest.get_test_configs() + GGUFLinearDynamicTest.get_test_configs() ) if args.action == "list": diff --git a/examples/models/gemma4_31b/mlx_gguf_linear.py b/examples/models/gemma4_31b/mlx_gguf_linear.py index 53a4d236bf6..e22b3bde2bc 100644 --- a/examples/models/gemma4_31b/mlx_gguf_linear.py +++ b/examples/models/gemma4_31b/mlx_gguf_linear.py @@ -14,11 +14,11 @@ from __future__ import annotations -import torch -import torch.nn as nn - # Importing the op module registers ``torch.ops.mlx.gguf_linear``. import executorch.backends.mlx.model_ops.gguf_linear # noqa: F401 + +import torch +import torch.nn as nn from executorch.backends.mlx.model_ops.gguf_linear import Q6K_BLOCK_BYTES, QK_K @@ -83,7 +83,5 @@ def replace_with_gguf_linear( ) linear_fqn = parts[0] grandparent_fqn, _, child_name = linear_fqn.rpartition(".") - grandparent = ( - model.get_submodule(grandparent_fqn) if grandparent_fqn else model - ) + grandparent = model.get_submodule(grandparent_fqn) if grandparent_fqn else model setattr(grandparent, child_name, GGUFLinear(weight_blob, format=format)) diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py index 4a5dd4da68b..dd426aab73b 100644 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -315,6 +315,26 @@ def test_raw_blob_dequant_matches_gguf_reference(self): f"max diff: {(deq - ref).abs().max():.6g}", ) + def test_raw_blob_dequant_matches_gguf_lib(self): + # Cross-check our dequant against gguf's own Q6_K dequantizer (gguf can + # dequantize Q6_K even though it cannot quantize to it). + import gguf + + from executorch.backends.mlx.model_ops.gguf_linear import dequantize_q6_k + from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k + + _, _, _, block = self._block() + data = np.frombuffer(bytes(block), dtype=np.uint8).reshape(1, 210) + raw = _raw_q6_k(data, [1, 256]) + ours = dequantize_q6_k(raw, 256)[0] + theirs = torch.tensor( + np.asarray(gguf.dequantize(data, GGMLQuantizationType.Q6_K))[0] + ) + self.assertTrue( + torch.allclose(ours, theirs, atol=1e-3), + f"max diff: {(ours - theirs).abs().max():.6g}", + ) + def test_default_still_unpacks_to_intx(self): from torchao.quantization import IntxUnpackedToInt8Tensor diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index e4eb1393e67..2eba09f2e17 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -327,11 +327,9 @@ class TestGgufLinearMlx(unittest.TestCase): """Q6_K weights route to the fused mlx::gguf_linear op (raw-blob path).""" def _make_blob(self, N: int, K: int) -> torch.Tensor: - from executorch.backends.mlx.model_ops.test_gguf_linear import pack_q6_k + from executorch.backends.mlx.model_ops.test_gguf_linear import make_q6_k_blob - torch.manual_seed(0) - w = torch.randn(N, K, dtype=torch.float32) * 0.1 - return pack_q6_k(w) + return make_q6_k_blob(N, K) def test_replace_with_gguf_linear_swaps_module(self): from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( From a0e24550cdc01a21a5cd7dae04b0d30b36cc9efb Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Wed, 3 Jun 2026 22:36:15 -0700 Subject: [PATCH 03/18] up --- backends/mlx/model_ops/gguf_embedding.py | 195 ++++++++++++++++++ backends/mlx/model_ops/gguf_linear.py | 10 +- backends/mlx/model_ops/test_gguf_embedding.py | 130 ++++++++++++ backends/mlx/model_ops/test_gguf_linear.py | 8 +- examples/models/gemma4_31b/gguf_loader.py | 67 ++++-- examples/models/gemma4_31b/mlx_gguf_linear.py | 78 ++++++- examples/models/gemma4_31b/quant/gguf.py | 6 +- .../gemma4_31b/tests/test_mlx_pipeline.py | 78 ++++++- 8 files changed, 535 insertions(+), 37 deletions(-) create mode 100644 backends/mlx/model_ops/gguf_embedding.py create mode 100644 backends/mlx/model_ops/test_gguf_embedding.py diff --git a/backends/mlx/model_ops/gguf_embedding.py b/backends/mlx/model_ops/gguf_embedding.py new file mode 100644 index 00000000000..81795a11754 --- /dev/null +++ b/backends/mlx/model_ops/gguf_embedding.py @@ -0,0 +1,195 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +""" +``mlx::gguf_embedding``: embedding gather against a GGUF-quantized table. + + out = dequant(weight[indices]) + +The embedding table is the raw GGUF ``block_q6_K`` blob (one quantized row per +vocab entry). This is the gather counterpart to ``mlx::gguf_linear`` and exists +because MLX's affine dequantize has no group_size=16 Metal kernel, so a Q6_K +embedding (group_size 16) cannot use the generic quantized-embedding path. + +``format`` selects the GGUF quant type; only ``"q6k"`` is supported. Output is +bfloat16. + +Usage:: + + import executorch.backends.mlx.model_ops.gguf_embedding # noqa: F401 + + out = torch.ops.mlx.gguf_embedding(weight, indices, "q6k") + # weight: (vocab, (K/256)*210) uint8 GGUF q6_K blob + # indices: (...) int + # out: (..., K) bfloat16 +""" + +from __future__ import annotations + +import torch +from torch import Tensor +from torch.fx.node import Node + +from executorch.backends.mlx.model_ops.gguf_linear import ( + _Q6K_HEADER, + dequantize_q6_k, + Q6K_BLOCK_BYTES, + QK_K, +) + + +# --------------------------------------------------------------------------- +# Custom op + eager fallback +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("mlx::gguf_embedding", mutates_args=()) +def gguf_embedding(weight: Tensor, indices: Tensor, format: str) -> Tensor: + """Gather + dequantize rows of a GGUF-quantized embedding table. + + Args: + weight: ``(vocab, (K/256)*210)`` uint8 GGUF ``q6_K`` blob. + indices: integer token ids of any shape. + format: GGUF quant type; only ``"q6k"`` supported. + + Returns: + ``(*indices.shape, K)`` bfloat16. + """ + if format != "q6k": + raise NotImplementedError( + f"mlx::gguf_embedding: unsupported format {format!r}; only 'q6k' " + f"is supported" + ) + if weight.dim() != 2: + raise ValueError( + f"mlx::gguf_embedding: weight must be 2-D (vocab, row_bytes); got " + f"shape {tuple(weight.shape)}" + ) + row_bytes = weight.shape[1] + if row_bytes % Q6K_BLOCK_BYTES != 0: + raise ValueError( + f"mlx::gguf_embedding: weight row bytes {row_bytes} must be a " + f"multiple of {Q6K_BLOCK_BYTES}" + ) + K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K + + rows = weight[indices.reshape(-1).long()] # (num, row_bytes) + deq = dequantize_q6_k(rows, K) # (num, K) float32 + return deq.reshape(*indices.shape, K).to(torch.bfloat16) + + +@torch.library.register_fake("mlx::gguf_embedding") +def gguf_embedding_fake(weight: Tensor, indices: Tensor, format: str) -> Tensor: + row_bytes = weight.shape[1] + K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K + return indices.new_empty((*indices.shape, K), dtype=torch.bfloat16) + + +# --------------------------------------------------------------------------- +# MLX handler +# --------------------------------------------------------------------------- + +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, +) + + +# One thread per output element. grid = (K, num_idx, 1): x picks the feature j, +# y picks the gathered row; each thread dequantizes a single Q6_K element. +_Q6K_EMBED_SOURCE = """ + const uint j = thread_position_in_grid.x; // 0..K-1 + const uint r = thread_position_in_grid.y; // gathered row + const int row = (int) indices[r]; + const int nb = K / QK_K; + device const block_q6_K * blk = + ((device const block_q6_K *) weight) + (uint)row * nb + (j / QK_K); + out[r * (uint)K + j] = (OutT) dequant_q6k_elem(blk, j % QK_K); +""" + + +@REGISTRY.register(target=[torch.ops.mlx.gguf_embedding.default]) +def _gguf_embedding_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Lower ``mlx::gguf_embedding`` to a fused Q6_K gather Metal kernel.""" + args = P.args(n) + if len(args) != 3: + raise ValueError( + f"mlx::gguf_embedding: expected 3 args (weight, indices, format); " + f"got {len(args)}" + ) + weight_slot, indices_slot, fmt = args + weight_node = n.args[0] + indices_node = n.args[1] + + if fmt != "q6k": + raise NotImplementedError( + f"mlx::gguf_embedding: unsupported format {fmt!r}; only 'q6k' " + f"is supported" + ) + + weight_meta = weight_node.meta["val"] + if weight_meta.dim() != 2: + raise NotImplementedError( + f"mlx::gguf_embedding: weight must be 2-D (vocab, row_bytes); got " + f"shape {tuple(weight_meta.shape)}" + ) + row_bytes = weight_meta.shape[1] + if not isinstance(row_bytes, int): + raise NotImplementedError( + "mlx::gguf_embedding: weight shape must be statically known" + ) + if row_bytes % Q6K_BLOCK_BYTES != 0: + raise ValueError( + f"mlx::gguf_embedding: weight row bytes {row_bytes} must be a " + f"multiple of {Q6K_BLOCK_BYTES}" + ) + K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K + + out_dtype_int = torch_dtype_to_scalar_type(torch.bfloat16) + + out = P.make_or_get_slot(n) + leading = emit_shape(P, indices_node, indices_slot, end_dim=None) + num_idx_iov = emit_product(P, leading) + out_shape_flat = leading + [IntOrVid.from_literal(K)] + + # threadgroup.x must divide grid.x (= K, a multiple of 256). + tg_x = 256 if K % 256 == 0 else K + + P.emit( + MetalKernelNode( + name="gguf_q6k_embedding", + source=_Q6K_EMBED_SOURCE, + header=_Q6K_HEADER, + inputs=[P.slot_to_tid(weight_slot), P.slot_to_tid(indices_slot)], + outputs=[P.slot_to_tid(out)], + grid=[IntOrVid.from_literal(K), num_idx_iov, IntOrVid.from_literal(1)], + threadgroup=[ + IntOrVid.from_literal(tg_x), + IntOrVid.from_literal(1), + IntOrVid.from_literal(1), + ], + input_names=["weight", "indices"], + output_names=["out"], + output_shapes_flat=out_shape_flat, + output_shape_lengths=[len(out_shape_flat)], + output_dtypes=[out_dtype_int], + template_arg_names=["OutT", "K"], + template_arg_kinds=[2, 0], # dtype, int + template_arg_values=[out_dtype_int, K], + ) + ) + + return out diff --git a/backends/mlx/model_ops/gguf_linear.py b/backends/mlx/model_ops/gguf_linear.py index a255bd71643..4730683c330 100644 --- a/backends/mlx/model_ops/gguf_linear.py +++ b/backends/mlx/model_ops/gguf_linear.py @@ -573,12 +573,16 @@ def _emit_q6k_matmul( def _gguf_linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: """Lower ``mlx::gguf_linear`` to fused Q6_K Metal kernels.""" args = P.args(n) - if len(args) != 4: + if len(args) == 4: + x_slot, weight_slot, fmt, bias_slot = args + elif len(args) == 3: + x_slot, weight_slot, fmt = args + bias_slot = None + else: raise ValueError( - f"mlx::gguf_linear: expected 4 args (x, weight, format, bias); " + f"mlx::gguf_linear: expected 3 or 4 args (x, weight, format[, bias]); " f"got {len(args)}" ) - x_slot, weight_slot, fmt, bias_slot = args x_node = n.args[0] weight_node = n.args[1] diff --git a/backends/mlx/model_ops/test_gguf_embedding.py b/backends/mlx/model_ops/test_gguf_embedding.py new file mode 100644 index 00000000000..8c8cc97b41e --- /dev/null +++ b/backends/mlx/model_ops/test_gguf_embedding.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Tests for ``mlx::gguf_embedding`` (GGUF Q6_K embedding gather). + +Compares the fused gather Metal kernel against the eager reference on the same +packed Q6_K table. The kernel and reference run identical per-element float +dequant, so the bf16 outputs match exactly. + +Usage:: + + python -m executorch.backends.mlx.model_ops.test_gguf_embedding run + python -m executorch.backends.mlx.model_ops.test_gguf_embedding run --rebuild +""" + +from typing import List, Tuple + +import executorch.backends.mlx.model_ops.gguf_embedding # noqa: F401 + +import torch +import torch.nn as nn + +from executorch.backends.mlx.model_ops.test_gguf_linear import make_q6_k_blob +from executorch.backends.mlx.test.test_utils import OpTestCase + + +class GGUFEmbeddingModel(nn.Module): + def forward(self, weight: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: + return torch.ops.mlx.gguf_embedding(weight, indices, "q6k") + + +class GGUFEmbeddingTest(OpTestCase): + name = "gguf_embedding" + rtol = 0.0 + atol = 0.0 + + def __init__( + self, + vocab: int = 512, + K: int = 256, + idx_shape: Tuple[int, ...] = (8,), + ): + self.vocab = vocab + self.K = K + self.idx_shape = idx_shape + shp = "x".join(str(d) for d in idx_shape) + self.name = f"gguf_embedding_v{vocab}_k{K}_idx{shp}" + + @classmethod + def get_test_configs(cls) -> List["GGUFEmbeddingTest"]: + return [ + cls(vocab=512, K=256, idx_shape=(1,)), + cls(vocab=512, K=256, idx_shape=(8,)), + cls(vocab=512, K=256, idx_shape=(64,)), + cls(vocab=512, K=512, idx_shape=(8,)), + cls(vocab=512, K=1024, idx_shape=(4,)), + cls(vocab=300, K=256, idx_shape=(16,)), # vocab not tile-aligned + cls(vocab=512, K=256, idx_shape=(2, 3)), # multi-dim indices + ] + + def create_model(self) -> nn.Module: + return GGUFEmbeddingModel() + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + torch.manual_seed(0) + weight = make_q6_k_blob(self.vocab, self.K) + indices = torch.randint(0, self.vocab, self.idx_shape, dtype=torch.int32) + return (weight, indices) + + +if __name__ == "__main__": + import argparse + import sys + + from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner + + parser = argparse.ArgumentParser(description="Test mlx::gguf_embedding op") + parser.add_argument("action", choices=["generate", "compare", "run", "list"]) + parser.add_argument("--verbose", "-v", action="store_true") + parser.add_argument("--rebuild", action="store_true") + parser.add_argument("--config", type=str, default=None) + args = parser.parse_args() + + if args.rebuild and not rebuild_op_test_runner(verbose=args.verbose): + sys.exit(1) + + configs = GGUFEmbeddingTest.get_test_configs() + + if args.action == "list": + for cfg in configs: + print(f" {cfg.name}") + sys.exit(0) + + if args.config: + configs = [c for c in configs if c.name == args.config] + if not configs: + print(f"No config matching '{args.config}'") + sys.exit(1) + + passed = 0 + failed = 0 + failed_names: List[str] = [] + + for test in configs: + if args.action == "generate": + pte_path, _, _ = test.generate_test_files(verbose=args.verbose) + print(f"Generated: {pte_path}") + elif args.action == "compare": + actual_path = test.get_test_dir() / "actual_output.bin" + ok, msg = test.compare_with_actual(actual_path) + print(f"{'✓' if ok else '✗'} {test.name}: {msg}") + passed, failed = (passed + 1, failed) if ok else (passed, failed + 1) + if not ok: + failed_names.append(test.name) + elif args.action == "run": + ok = test.run_test(verbose=args.verbose) + passed, failed = (passed + 1, failed) if ok else (passed, failed + 1) + if not ok: + failed_names.append(test.name) + + if args.action in ("run", "compare"): + print(f"\nPassed: {passed}, Failed: {failed}") + if failed_names: + print(f"Failed: {', '.join(failed_names)}") + sys.exit(0 if failed == 0 else 1) diff --git a/backends/mlx/model_ops/test_gguf_linear.py b/backends/mlx/model_ops/test_gguf_linear.py index 98d62d82305..11cc6b2c4ce 100644 --- a/backends/mlx/model_ops/test_gguf_linear.py +++ b/backends/mlx/model_ops/test_gguf_linear.py @@ -63,10 +63,10 @@ def make_q6_k_blob(N: int, K: int, seed: int = 0) -> torch.Tensor: blocks[..., :192] = torch.randint( 0, 256, (N, nb, 192), dtype=torch.uint8, generator=g ) - # scales (192:208): modest positive int8 scales keep magnitudes sane. - blocks[..., 192:208] = torch.randint( - 1, 17, (N, nb, 16), dtype=torch.uint8, generator=g - ) + # scales (192:208): signed int8 scales (real Q6_K scales can be negative); + # a modest magnitude keeps dequantized values sane. + scales = torch.randint(-16, 17, (N, nb, 16), dtype=torch.int32, generator=g) + blocks[..., 192:208] = scales.to(torch.int8).view(torch.uint8) # d (208:210): a small finite fp16 super-block scale. Chosen so dequantized # element magnitudes (~ d * scale * (q-32)) are O(0.1), like real Q6_K # weights -- the mat-mat kernel stores tiles in half precision (as in diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index ac7ad5ebf89..b3ad89fc2e3 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -65,14 +65,22 @@ def gguf_to_model_key(gguf_key: str) -> Optional[str]: return None -def _resolve_tied_lm_head(model, embed_quant, packers): +def _resolve_tied_lm_head(model, embed_quant, embed_q6k_raw, packers): """Handle tied embed/lm_head after streaming all tensors.""" from executorch.examples.models.gemma4_31b.quant import pack_one lm_head = getattr(model.lm_head, "weight", None) if lm_head is None or lm_head.device.type != "meta": return - if embed_quant is not None: + if embed_q6k_raw is not None: + # Tied Q6_K weights: lm_head is a matmul, so use the fused gguf_linear + # op (the embedding itself stays a quantized gather). + from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( + replace_with_gguf_linear, + ) + + replace_with_gguf_linear(model, "lm_head.weight", embed_q6k_raw, format="q6k") + elif embed_quant is not None: pack_one(model, "lm_head.weight", embed_quant, packers) else: pack_one( @@ -115,7 +123,11 @@ def load_gguf_model( """ from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig from executorch.examples.models.gemma4_31b.quant import dequantize_weight, pack_one - from executorch.examples.models.gemma4_31b.quant.gguf import iter_gguf_tensors + from executorch.examples.models.gemma4_31b.quant.gguf import ( + _unpack_q6_k, + iter_gguf_tensors, + ) + from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor if backend == "cuda": @@ -136,6 +148,7 @@ def load_gguf_model( model = Gemma4_31B(config) embed_quant = None + embed_q6k_raw = None # raw Q6_K embedding blob, reused for a tied lm_head n_processed = 0 from gguf import GGMLQuantizationType @@ -151,22 +164,46 @@ def load_gguf_model( if type(result) is torch.Tensor and result.dtype == torch.float32: result = result.to(torch.bfloat16) - if model_key == "embed_tokens.weight" and isinstance(result, Int4Tensor): - embed_quant = result - if backend == "cuda": - result = dequantize_weight(result, torch.bfloat16) - - # MLX Q6_K: keep the raw GGUF blob and swap the nn.Linear for a - # GGUFLinear that dispatches to the fused mlx::gguf_linear kernel, - # bypassing the slow group_size=16 non-fused affine path. + # MLX Q6_K: Linear weights use the fused gguf_linear op, the token + # embedding uses the fused gguf_embedding op -- both consume the raw + # GGUF blob (group_size=16 has no MLX affine kernel). The embedding's + # raw blob is also kept so a tied lm_head can use gguf_linear. if backend == "mlx" and gguf_type == GGMLQuantizationType.Q6_K: from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( + replace_with_gguf_embedding, replace_with_gguf_linear, ) - replace_with_gguf_linear(model, model_key, result, format="q6k") - n_processed += 1 - continue + parent = model.get_submodule(model_key.rsplit(".", 1)[0]) + if isinstance(parent, torch.nn.Linear): + replace_with_gguf_linear(model, model_key, result, format="q6k") + n_processed += 1 + continue + if isinstance(parent, torch.nn.Embedding): + if model_key == "embed_tokens.weight": + embed_q6k_raw = result + replace_with_gguf_embedding(model, model_key, result, format="q6k") + n_processed += 1 + continue + + # Any other Q6_K module: fall back to a quantized tensor. + from executorch.backends.mlx.model_ops.gguf_linear import ( + Q6K_BLOCK_BYTES, + QK_K, + ) + + n_rows, row_bytes = result.shape + result = _unpack_q6_k( + result.reshape(-1), + [n_rows, (row_bytes // Q6K_BLOCK_BYTES) * QK_K], + ) + + if model_key == "embed_tokens.weight" and isinstance( + result, (Int4Tensor, IntxUnpackedToInt8Tensor) + ): + embed_quant = result + if backend == "cuda": + result = dequantize_weight(result, torch.bfloat16) pack_one(model, model_key, result, packers) @@ -174,7 +211,7 @@ def load_gguf_model( if n_processed % 100 == 0: print(f" Processed {n_processed} tensors...") - _resolve_tied_lm_head(model, embed_quant, packers) + _resolve_tied_lm_head(model, embed_quant, embed_q6k_raw, packers) del embed_quant _validate_no_meta(model) diff --git a/examples/models/gemma4_31b/mlx_gguf_linear.py b/examples/models/gemma4_31b/mlx_gguf_linear.py index e22b3bde2bc..f7555f6150d 100644 --- a/examples/models/gemma4_31b/mlx_gguf_linear.py +++ b/examples/models/gemma4_31b/mlx_gguf_linear.py @@ -4,17 +4,18 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -"""MLX carrier module for GGUF Q6_K linears. +"""MLX carrier modules for GGUF Q6_K weights. -Wraps the raw GGUF ``block_q6_K`` weight blob and dispatches to the fused -``mlx::gguf_linear`` Metal kernel (decode mat-vec / prefill mat-mat), instead of -the slow non-fused dequantize+matmul path that group_size=16 affine quant takes -through the MLX ``QUANTIZED_LINEAR`` pattern. +Wrap raw GGUF ``block_q6_K`` blobs and dispatch to the fused ``mlx::gguf_linear`` +(matmul) and ``mlx::gguf_embedding`` (gather) Metal kernels, instead of the slow +non-fused dequantize paths that group_size=16 affine quant takes through the MLX +``QUANTIZED_LINEAR`` / quantized-embedding patterns. """ from __future__ import annotations -# Importing the op module registers ``torch.ops.mlx.gguf_linear``. +# Importing the op modules registers the custom ops. +import executorch.backends.mlx.model_ops.gguf_embedding # noqa: F401 import executorch.backends.mlx.model_ops.gguf_linear # noqa: F401 import torch @@ -83,5 +84,68 @@ def replace_with_gguf_linear( ) linear_fqn = parts[0] grandparent_fqn, _, child_name = linear_fqn.rpartition(".") - grandparent = model.get_submodule(grandparent_fqn) if grandparent_fqn else model + grandparent = ( + model.get_submodule(grandparent_fqn) if grandparent_fqn else model + ) setattr(grandparent, child_name, GGUFLinear(weight_blob, format=format)) + + +class GGUFEmbedding(nn.Module): + """``y = gguf_embedding(weight_blob, indices, format)`` for a GGUF table. + + The weight is the **raw** GGUF block blob, stored as a uint8 buffer of shape + ``(num_embeddings, n_blocks * block_bytes)``. ``forward`` returns bfloat16, + matching the model's embedding dtype. + """ + + def __init__(self, weight_blob: torch.Tensor, format: str = "q6k"): + super().__init__() + if weight_blob.dim() != 2 or weight_blob.dtype != torch.uint8: + raise ValueError( + f"GGUFEmbedding: weight_blob must be 2-D uint8; got " + f"shape {tuple(weight_blob.shape)} dtype {weight_blob.dtype}" + ) + if format != "q6k": + raise NotImplementedError( + f"GGUFEmbedding: unsupported format {format!r}; only 'q6k' supported" + ) + row_bytes = int(weight_blob.shape[1]) + if row_bytes % Q6K_BLOCK_BYTES != 0: + raise ValueError( + f"GGUFEmbedding: weight row bytes {row_bytes} must be a multiple of " + f"{Q6K_BLOCK_BYTES}" + ) + self.format = format + self.num_embeddings = int(weight_blob.shape[0]) + self.embedding_dim = (row_bytes // Q6K_BLOCK_BYTES) * QK_K + self.register_buffer("weight", weight_blob) + + def forward(self, indices: torch.Tensor) -> torch.Tensor: + return torch.ops.mlx.gguf_embedding(self.weight, indices, self.format) + + def extra_repr(self) -> str: + return ( + f"num_embeddings={self.num_embeddings}, " + f"embedding_dim={self.embedding_dim}, format={self.format!r}" + ) + + +def replace_with_gguf_embedding( + model: nn.Module, + weight_fqn: str, + weight_blob: torch.Tensor, + format: str = "q6k", +) -> None: + """Replace the ``nn.Embedding`` owning ``weight_fqn`` with a ``GGUFEmbedding``.""" + parts = weight_fqn.rsplit(".", 1) + if len(parts) != 2 or parts[1] != "weight": + raise ValueError( + f"replace_with_gguf_embedding: expected a '*.weight' fqn; " + f"got {weight_fqn!r}" + ) + module_fqn = parts[0] + grandparent_fqn, _, child_name = module_fqn.rpartition(".") + grandparent = ( + model.get_submodule(grandparent_fqn) if grandparent_fqn else model + ) + setattr(grandparent, child_name, GGUFEmbedding(weight_blob, format=format)) diff --git a/examples/models/gemma4_31b/quant/gguf.py b/examples/models/gemma4_31b/quant/gguf.py index 5c58c86cde6..9eeec1e74d1 100644 --- a/examples/models/gemma4_31b/quant/gguf.py +++ b/examples/models/gemma4_31b/quant/gguf.py @@ -111,6 +111,9 @@ def _unpack_q4_k(data, shape: list[int]) -> torch.Tensor: def _unpack_q6_k(data, shape: list[int]) -> torch.Tensor: """Unpack Q6_K super-blocks into an ``IntxUnpackedToInt8Tensor``. + ``data`` may be a raw byte buffer or an already-materialized uint8 tensor of + the block bytes. + Q6_K block layout (210 bytes per 256 values): - ql (128B): lower 4 bits of 256 6-bit values - qh (64B): upper 2 bits of 256 6-bit values @@ -126,7 +129,8 @@ def _unpack_q6_k(data, shape: list[int]) -> torch.Tensor: assert K % QK_K == 0, f"Q6_K requires K divisible by {QK_K}, got {K}" n_blocks = N * (K // QK_K) block_bytes = 2 + QK_K // 2 + QK_K // 4 + QK_K // 16 # 210 - raw = _raw_tensor(data).reshape(n_blocks, block_bytes) + raw_u8 = data if isinstance(data, torch.Tensor) else _raw_tensor(data) + raw = raw_u8.reshape(n_blocks, block_bytes) ql = raw[:, 0:128] qh = raw[:, 128:192] diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index 2eba09f2e17..63eeea157f5 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -364,15 +364,18 @@ def test_replace_with_gguf_linear_swaps_module(self): self.assertEqual(y.shape, torch.Size([2, N])) self.assertTrue(torch.equal(y, ref)) - def test_gguf_linear_appears_in_exported_graph(self): + def test_gguf_linear_delegates_to_mlx(self): + from executorch.backends.mlx import MLXPartitioner + from executorch.examples.models.gemma4_31b.mlx_gguf_linear import GGUFLinear + from executorch.exir import to_edge_transform_and_lower from torch.export import Dim, export N, K = 256, 512 blob = self._make_blob(N, K) + # GGUFLinear passes bias=None, so the lowered node has 3 args (no bias); + # dynamic seq_len exercises the runtime-routed (IfNode) lowering path. m = GGUFLinear(blob, format="q6k").eval() - - # Dynamic seq_len exercises the runtime-routed (IfNode) lowering path. seq = Dim("seq", min=1, max=8) ep = export( m, @@ -380,11 +383,72 @@ def test_gguf_linear_appears_in_exported_graph(self): dynamic_shapes=({0: seq},), strict=True, ) - targets = [str(n.target) for n in ep.graph.nodes if n.op == "call_function"] - self.assertTrue( - any("mlx.gguf_linear" in t for t in targets), - f"gguf_linear not in exported graph: {targets}", + et = to_edge_transform_and_lower(ep, partitioner=[MLXPartitioner()]) + remaining = [ + str(n.target) + for n in et.exported_program().graph.nodes + if n.op == "call_function" and "gguf_linear" in str(n.target) + ] + self.assertEqual(remaining, [], "gguf_linear was not delegated to MLX") + + +class TestGgufEmbeddingMlx(unittest.TestCase): + """Q6_K token embedding routes to the fused mlx::gguf_embedding op.""" + + def _make_blob(self, vocab: int, K: int) -> torch.Tensor: + from executorch.backends.mlx.model_ops.test_gguf_linear import make_q6_k_blob + + return make_q6_k_blob(vocab, K) + + def test_replace_with_gguf_embedding_swaps_module(self): + from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( + GGUFEmbedding, + replace_with_gguf_embedding, + ) + + model = build_random_tiny_model() + # Q6_K needs embedding_dim % 256 == 0, so use a fixed valid dim (the + # tiny model's hidden_size is smaller); this test exercises the swapped + # module directly, not the full model forward. + vocab, K = 512, 256 + blob = self._make_blob(vocab, K) + replace_with_gguf_embedding(model, "embed_tokens.weight", blob, format="q6k") + + swapped = model.get_submodule("embed_tokens") + self.assertIsInstance(swapped, GGUFEmbedding) + self.assertEqual(swapped.num_embeddings, vocab) + self.assertEqual(swapped.embedding_dim, K) + + idx = torch.randint(0, vocab, (2, 4), dtype=torch.int32) + y = swapped(idx) + ref = torch.ops.mlx.gguf_embedding(blob, idx, "q6k") + self.assertEqual(y.shape, torch.Size([2, 4, K])) + self.assertTrue(torch.equal(y, ref)) + + def test_gguf_embedding_delegates_to_mlx(self): + from executorch.backends.mlx import MLXPartitioner + + from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( + GGUFEmbedding, + ) + from executorch.exir import to_edge_transform_and_lower + from torch.export import Dim, export + + m = GGUFEmbedding(self._make_blob(512, 256), format="q6k").eval() + seq = Dim("seq", min=1, max=8) + ep = export( + m, + (torch.randint(0, 512, (4,), dtype=torch.int32),), + dynamic_shapes=({0: seq},), + strict=True, ) + et = to_edge_transform_and_lower(ep, partitioner=[MLXPartitioner()]) + remaining = [ + str(n.target) + for n in et.exported_program().graph.nodes + if n.op == "call_function" and "gguf_embedding" in str(n.target) + ] + self.assertEqual(remaining, [], "gguf_embedding was not delegated to MLX") if __name__ == "__main__": From 73b8f68c2143a3d29e772a360f04dc34a171e928 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 4 Jun 2026 13:03:51 -0700 Subject: [PATCH 04/18] up --- .github/workflows/mlx.yml | 9 +++++---- .../__init__.py | 0 .../gated_delta_rule.py | 0 .../gguf_embedding.py | 4 ++-- .../gguf_linear.py | 2 +- .../mlx/custom_kernel_ops/test/__init__.py | 5 +++++ .../test}/test_gated_delta_rule.py | 8 ++++---- .../test}/test_gguf_embedding.py | 9 +++++---- .../test}/test_gguf_linear.py | 18 ++++++++++++------ .../test}/test_tq4_compress.py | 8 ++++---- .../test}/test_tq_dequant.py | 8 ++++---- .../test}/test_tq_norm.py | 8 ++++---- .../tq4_compress.py | 2 +- .../tq_dequant.py | 2 +- .../tq_norm.py | 2 +- backends/mlx/llm/turboquant_cache.py | 6 +++--- examples/models/gemma4_31b/gguf_loader.py | 2 +- examples/models/gemma4_31b/mlx_gguf_linear.py | 6 +++--- .../models/gemma4_31b/quant/tests/test_gguf.py | 4 ++-- .../gemma4_31b/tests/test_mlx_pipeline.py | 4 ++-- .../qwen3_5_moe/mlx_source_transformations.py | 2 +- 21 files changed, 61 insertions(+), 48 deletions(-) rename backends/mlx/{model_ops => custom_kernel_ops}/__init__.py (100%) rename backends/mlx/{model_ops => custom_kernel_ops}/gated_delta_rule.py (100%) rename backends/mlx/{model_ops => custom_kernel_ops}/gguf_embedding.py (97%) rename backends/mlx/{model_ops => custom_kernel_ops}/gguf_linear.py (99%) create mode 100644 backends/mlx/custom_kernel_ops/test/__init__.py rename backends/mlx/{model_ops => custom_kernel_ops/test}/test_gated_delta_rule.py (98%) rename backends/mlx/{model_ops => custom_kernel_ops/test}/test_gguf_embedding.py (90%) rename backends/mlx/{model_ops => custom_kernel_ops/test}/test_gguf_linear.py (91%) rename backends/mlx/{model_ops => custom_kernel_ops/test}/test_tq4_compress.py (94%) rename backends/mlx/{model_ops => custom_kernel_ops/test}/test_tq_dequant.py (93%) rename backends/mlx/{model_ops => custom_kernel_ops/test}/test_tq_norm.py (93%) rename backends/mlx/{model_ops => custom_kernel_ops}/tq4_compress.py (98%) rename backends/mlx/{model_ops => custom_kernel_ops}/tq_dequant.py (98%) rename backends/mlx/{model_ops => custom_kernel_ops}/tq_norm.py (98%) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 69c767224fd..3b5eb80f176 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -77,6 +77,7 @@ jobs: backends/mlx/test/test_passes.py \ backends/mlx/test/test_pattern_utils.py \ backends/mlx/test/test_partitioner.py \ + backends/mlx/test/test_gguf_dequant.py \ examples/models/gemma4_31b/tests/test_mlx_pipeline.py \ examples/models/gemma4_31b/quant/tests/test_gguf.py \ -v @@ -90,12 +91,12 @@ jobs: ./cmake-out/backends/mlx/test/multi_thread_test_runner echo "::endgroup::" - echo "::group::Run model_ops op tests" - # Run every model_ops/test_*.py via its OpTestCase `run` CLI. Adding a + echo "::group::Run custom_kernel_ops op tests" + # Run every custom_kernel_ops/test/test_*.py via its OpTestCase `run` CLI. Adding a # new op test file requires no change here. set -e - for t in backends/mlx/model_ops/test_*.py; do - mod="executorch.backends.mlx.model_ops.$(basename "$t" .py)" + for t in backends/mlx/custom_kernel_ops/test/test_*.py; do + mod="executorch.backends.mlx.custom_kernel_ops.test.$(basename "$t" .py)" echo "--- ${mod} ---" ${CONDA_RUN} python -m "${mod}" run -v done diff --git a/backends/mlx/model_ops/__init__.py b/backends/mlx/custom_kernel_ops/__init__.py similarity index 100% rename from backends/mlx/model_ops/__init__.py rename to backends/mlx/custom_kernel_ops/__init__.py diff --git a/backends/mlx/model_ops/gated_delta_rule.py b/backends/mlx/custom_kernel_ops/gated_delta_rule.py similarity index 100% rename from backends/mlx/model_ops/gated_delta_rule.py rename to backends/mlx/custom_kernel_ops/gated_delta_rule.py diff --git a/backends/mlx/model_ops/gguf_embedding.py b/backends/mlx/custom_kernel_ops/gguf_embedding.py similarity index 97% rename from backends/mlx/model_ops/gguf_embedding.py rename to backends/mlx/custom_kernel_ops/gguf_embedding.py index 81795a11754..11a17c1b08a 100644 --- a/backends/mlx/model_ops/gguf_embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf_embedding.py @@ -21,7 +21,7 @@ Usage:: - import executorch.backends.mlx.model_ops.gguf_embedding # noqa: F401 + import executorch.backends.mlx.custom_kernel_ops.gguf_embedding # noqa: F401 out = torch.ops.mlx.gguf_embedding(weight, indices, "q6k") # weight: (vocab, (K/256)*210) uint8 GGUF q6_K blob @@ -35,7 +35,7 @@ from torch import Tensor from torch.fx.node import Node -from executorch.backends.mlx.model_ops.gguf_linear import ( +from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( _Q6K_HEADER, dequantize_q6_k, Q6K_BLOCK_BYTES, diff --git a/backends/mlx/model_ops/gguf_linear.py b/backends/mlx/custom_kernel_ops/gguf_linear.py similarity index 99% rename from backends/mlx/model_ops/gguf_linear.py rename to backends/mlx/custom_kernel_ops/gguf_linear.py index 4730683c330..0ac5a550fe4 100644 --- a/backends/mlx/model_ops/gguf_linear.py +++ b/backends/mlx/custom_kernel_ops/gguf_linear.py @@ -45,7 +45,7 @@ Usage:: - import executorch.backends.mlx.model_ops.gguf_linear # noqa: F401 + import executorch.backends.mlx.custom_kernel_ops.gguf_linear # noqa: F401 out = torch.ops.mlx.gguf_linear(x, weight, "q6k", bias) # x: (..., K) bf16 / fp16 / fp32 diff --git a/backends/mlx/custom_kernel_ops/test/__init__.py b/backends/mlx/custom_kernel_ops/test/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/model_ops/test_gated_delta_rule.py b/backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py similarity index 98% rename from backends/mlx/model_ops/test_gated_delta_rule.py rename to backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py index 10dceef14b1..0a7e6a687f9 100644 --- a/backends/mlx/model_ops/test_gated_delta_rule.py +++ b/backends/mlx/custom_kernel_ops/test/test_gated_delta_rule.py @@ -10,18 +10,18 @@ Usage: # Run all configs: - python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gated_delta_rule run # Run with verbose output: - python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run -v + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gated_delta_rule run -v # Rebuild C++ runner first: - python -m executorch.backends.mlx.model_ops.test_gated_delta_rule run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gated_delta_rule run --rebuild """ from typing import List, Tuple -import executorch.backends.mlx.model_ops.gated_delta_rule # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.gated_delta_rule # noqa: F401 import torch import torch.nn as nn diff --git a/backends/mlx/model_ops/test_gguf_embedding.py b/backends/mlx/custom_kernel_ops/test/test_gguf_embedding.py similarity index 90% rename from backends/mlx/model_ops/test_gguf_embedding.py rename to backends/mlx/custom_kernel_ops/test/test_gguf_embedding.py index 8c8cc97b41e..19464b8f131 100644 --- a/backends/mlx/model_ops/test_gguf_embedding.py +++ b/backends/mlx/custom_kernel_ops/test/test_gguf_embedding.py @@ -14,18 +14,18 @@ Usage:: - python -m executorch.backends.mlx.model_ops.test_gguf_embedding run - python -m executorch.backends.mlx.model_ops.test_gguf_embedding run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_embedding run + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_embedding run --rebuild """ from typing import List, Tuple -import executorch.backends.mlx.model_ops.gguf_embedding # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.gguf_embedding # noqa: F401 import torch import torch.nn as nn -from executorch.backends.mlx.model_ops.test_gguf_linear import make_q6_k_blob +from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import make_q6_k_blob from executorch.backends.mlx.test.test_utils import OpTestCase @@ -61,6 +61,7 @@ def get_test_configs(cls) -> List["GGUFEmbeddingTest"]: cls(vocab=512, K=1024, idx_shape=(4,)), cls(vocab=300, K=256, idx_shape=(16,)), # vocab not tile-aligned cls(vocab=512, K=256, idx_shape=(2, 3)), # multi-dim indices + cls(vocab=262144, K=5376, idx_shape=(8,)), # real Gemma-4-31B embed ] def create_model(self) -> nn.Module: diff --git a/backends/mlx/model_ops/test_gguf_linear.py b/backends/mlx/custom_kernel_ops/test/test_gguf_linear.py similarity index 91% rename from backends/mlx/model_ops/test_gguf_linear.py rename to backends/mlx/custom_kernel_ops/test/test_gguf_linear.py index 11cc6b2c4ce..d4b5db2b197 100644 --- a/backends/mlx/model_ops/test_gguf_linear.py +++ b/backends/mlx/custom_kernel_ops/test/test_gguf_linear.py @@ -19,20 +19,20 @@ Usage:: - python -m executorch.backends.mlx.model_ops.test_gguf_linear run - python -m executorch.backends.mlx.model_ops.test_gguf_linear run -v - python -m executorch.backends.mlx.model_ops.test_gguf_linear run --rebuild - python -m executorch.backends.mlx.model_ops.test_gguf_linear eager + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear run + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear run -v + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear eager """ from typing import List, Tuple -import executorch.backends.mlx.model_ops.gguf_linear # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.gguf_linear # noqa: F401 import torch import torch.nn as nn -from executorch.backends.mlx.model_ops.gguf_linear import ( +from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( dequantize_q6_k, Q6K_BLOCK_BYTES, QK_K, @@ -133,6 +133,12 @@ def get_test_configs(cls) -> List["GGUFLinearTest"]: # Ragged shapes (M and N not multiples of the 32-wide tile / row group). cfgs.append(cls(M=40, N=300, K=256, dtype=torch.bfloat16)) cfgs.append(cls(M=1, N=300, K=256, dtype=torch.bfloat16)) + # Real Gemma-4-31B shapes (hidden=5376, ffn=21504, vocab=262144) to + # exercise the kernels at production N/K (decode + prefill). + cfgs.append(cls(M=1, N=4096, K=5376, dtype=torch.bfloat16)) # attn_v + cfgs.append(cls(M=1, N=5376, K=21504, dtype=torch.bfloat16)) # ffn_down + cfgs.append(cls(M=8, N=5376, K=21504, dtype=torch.bfloat16)) # ffn_down prefill + cfgs.append(cls(M=1, N=262144, K=5376, dtype=torch.bfloat16)) # lm_head return cfgs def create_model(self) -> nn.Module: diff --git a/backends/mlx/model_ops/test_tq4_compress.py b/backends/mlx/custom_kernel_ops/test/test_tq4_compress.py similarity index 94% rename from backends/mlx/model_ops/test_tq4_compress.py rename to backends/mlx/custom_kernel_ops/test/test_tq4_compress.py index c2aaa13afa7..ba114e67b23 100644 --- a/backends/mlx/model_ops/test_tq4_compress.py +++ b/backends/mlx/custom_kernel_ops/test/test_tq4_compress.py @@ -13,14 +13,14 @@ Usage:: - python -m executorch.backends.mlx.model_ops.test_tq4_compress run - python -m executorch.backends.mlx.model_ops.test_tq4_compress run -v - python -m executorch.backends.mlx.model_ops.test_tq4_compress run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq4_compress run + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq4_compress run -v + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq4_compress run --rebuild """ from typing import List, Tuple -import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.tq4_compress # noqa: F401 import torch import torch.nn as nn diff --git a/backends/mlx/model_ops/test_tq_dequant.py b/backends/mlx/custom_kernel_ops/test/test_tq_dequant.py similarity index 93% rename from backends/mlx/model_ops/test_tq_dequant.py rename to backends/mlx/custom_kernel_ops/test/test_tq_dequant.py index 07d9deb895a..f50fad9b651 100644 --- a/backends/mlx/model_ops/test_tq_dequant.py +++ b/backends/mlx/custom_kernel_ops/test/test_tq_dequant.py @@ -15,14 +15,14 @@ Usage:: - python -m executorch.backends.mlx.model_ops.test_tq_dequant run - python -m executorch.backends.mlx.model_ops.test_tq_dequant run -v - python -m executorch.backends.mlx.model_ops.test_tq_dequant run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_dequant run + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_dequant run -v + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_dequant run --rebuild """ from typing import List, Tuple -import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.tq_dequant # noqa: F401 import torch import torch.nn as nn diff --git a/backends/mlx/model_ops/test_tq_norm.py b/backends/mlx/custom_kernel_ops/test/test_tq_norm.py similarity index 93% rename from backends/mlx/model_ops/test_tq_norm.py rename to backends/mlx/custom_kernel_ops/test/test_tq_norm.py index 35c4491d8ae..4f3b93a945f 100644 --- a/backends/mlx/model_ops/test_tq_norm.py +++ b/backends/mlx/custom_kernel_ops/test/test_tq_norm.py @@ -13,14 +13,14 @@ Usage:: - python -m executorch.backends.mlx.model_ops.test_tq_norm run - python -m executorch.backends.mlx.model_ops.test_tq_norm run -v - python -m executorch.backends.mlx.model_ops.test_tq_norm run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_norm run + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_norm run -v + python -m executorch.backends.mlx.custom_kernel_ops.test.test_tq_norm run --rebuild """ from typing import List, Tuple -import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.tq_norm # noqa: F401 import torch import torch.nn as nn diff --git a/backends/mlx/model_ops/tq4_compress.py b/backends/mlx/custom_kernel_ops/tq4_compress.py similarity index 98% rename from backends/mlx/model_ops/tq4_compress.py rename to backends/mlx/custom_kernel_ops/tq4_compress.py index f08d47b9a11..f957be379c0 100644 --- a/backends/mlx/model_ops/tq4_compress.py +++ b/backends/mlx/custom_kernel_ops/tq4_compress.py @@ -20,7 +20,7 @@ Usage:: - import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 + import executorch.backends.mlx.custom_kernel_ops.tq4_compress # noqa: F401 packed = torch.ops.mlx.tq4_compress(rotated, boundaries) # rotated: (..., D) float diff --git a/backends/mlx/model_ops/tq_dequant.py b/backends/mlx/custom_kernel_ops/tq_dequant.py similarity index 98% rename from backends/mlx/model_ops/tq_dequant.py rename to backends/mlx/custom_kernel_ops/tq_dequant.py index 28a168e9be0..0c1842712e4 100644 --- a/backends/mlx/model_ops/tq_dequant.py +++ b/backends/mlx/custom_kernel_ops/tq_dequant.py @@ -23,7 +23,7 @@ Usage:: - import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 + import executorch.backends.mlx.custom_kernel_ops.tq_dequant # noqa: F401 out = torch.ops.mlx.tq_dequant(packed, norms, centroids) # packed: (..., D/2) uint8 diff --git a/backends/mlx/model_ops/tq_norm.py b/backends/mlx/custom_kernel_ops/tq_norm.py similarity index 98% rename from backends/mlx/model_ops/tq_norm.py rename to backends/mlx/custom_kernel_ops/tq_norm.py index 7e6a4d657f3..e456c2f6aa4 100644 --- a/backends/mlx/model_ops/tq_norm.py +++ b/backends/mlx/custom_kernel_ops/tq_norm.py @@ -20,7 +20,7 @@ Usage:: - import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 + import executorch.backends.mlx.custom_kernel_ops.tq_norm # noqa: F401 norms = torch.ops.mlx.tq_norm(x) # x: (..., D) bf16 diff --git a/backends/mlx/llm/turboquant_cache.py b/backends/mlx/llm/turboquant_cache.py index 7f2109ba074..f9dc311cada 100644 --- a/backends/mlx/llm/turboquant_cache.py +++ b/backends/mlx/llm/turboquant_cache.py @@ -27,9 +27,9 @@ # Register the MLX custom ops used by this cache. import executorch.backends.mlx.custom_ops # noqa: F401 mlx::custom_sdpa, mlx::kv_cache_update -import executorch.backends.mlx.model_ops.tq4_compress # noqa: F401 mlx::tq4_compress -import executorch.backends.mlx.model_ops.tq_dequant # noqa: F401 mlx::tq_dequant -import executorch.backends.mlx.model_ops.tq_norm # noqa: F401 mlx::tq_norm +import executorch.backends.mlx.custom_kernel_ops.tq4_compress # noqa: F401 mlx::tq4_compress +import executorch.backends.mlx.custom_kernel_ops.tq_dequant # noqa: F401 mlx::tq_dequant +import executorch.backends.mlx.custom_kernel_ops.tq_norm # noqa: F401 mlx::tq_norm import torch diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index b3ad89fc2e3..6f53436808c 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -187,7 +187,7 @@ def load_gguf_model( continue # Any other Q6_K module: fall back to a quantized tensor. - from executorch.backends.mlx.model_ops.gguf_linear import ( + from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( Q6K_BLOCK_BYTES, QK_K, ) diff --git a/examples/models/gemma4_31b/mlx_gguf_linear.py b/examples/models/gemma4_31b/mlx_gguf_linear.py index f7555f6150d..150c2773fd6 100644 --- a/examples/models/gemma4_31b/mlx_gguf_linear.py +++ b/examples/models/gemma4_31b/mlx_gguf_linear.py @@ -15,12 +15,12 @@ from __future__ import annotations # Importing the op modules registers the custom ops. -import executorch.backends.mlx.model_ops.gguf_embedding # noqa: F401 -import executorch.backends.mlx.model_ops.gguf_linear # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.gguf_embedding # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.gguf_linear # noqa: F401 import torch import torch.nn as nn -from executorch.backends.mlx.model_ops.gguf_linear import Q6K_BLOCK_BYTES, QK_K +from executorch.backends.mlx.custom_kernel_ops.gguf_linear import Q6K_BLOCK_BYTES, QK_K class GGUFLinear(nn.Module): diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py index dd426aab73b..5ebc93279bb 100644 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -302,7 +302,7 @@ def test_raw_blob_preserves_bytes(self): ) def test_raw_blob_dequant_matches_gguf_reference(self): - from executorch.backends.mlx.model_ops.gguf_linear import dequantize_q6_k + from executorch.backends.mlx.custom_kernel_ops.gguf_linear import dequantize_q6_k from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k d, scales_16, qvals, block = self._block() @@ -320,7 +320,7 @@ def test_raw_blob_dequant_matches_gguf_lib(self): # dequantize Q6_K even though it cannot quantize to it). import gguf - from executorch.backends.mlx.model_ops.gguf_linear import dequantize_q6_k + from executorch.backends.mlx.custom_kernel_ops.gguf_linear import dequantize_q6_k from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k _, _, _, block = self._block() diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index 63eeea157f5..2bdce6f60e2 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -327,7 +327,7 @@ class TestGgufLinearMlx(unittest.TestCase): """Q6_K weights route to the fused mlx::gguf_linear op (raw-blob path).""" def _make_blob(self, N: int, K: int) -> torch.Tensor: - from executorch.backends.mlx.model_ops.test_gguf_linear import make_q6_k_blob + from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import make_q6_k_blob return make_q6_k_blob(N, K) @@ -396,7 +396,7 @@ class TestGgufEmbeddingMlx(unittest.TestCase): """Q6_K token embedding routes to the fused mlx::gguf_embedding op.""" def _make_blob(self, vocab: int, K: int) -> torch.Tensor: - from executorch.backends.mlx.model_ops.test_gguf_linear import make_q6_k_blob + from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import make_q6_k_blob return make_q6_k_blob(vocab, K) diff --git a/examples/models/qwen3_5_moe/mlx_source_transformations.py b/examples/models/qwen3_5_moe/mlx_source_transformations.py index 25605fb6342..9a49f8a84f6 100644 --- a/examples/models/qwen3_5_moe/mlx_source_transformations.py +++ b/examples/models/qwen3_5_moe/mlx_source_transformations.py @@ -194,7 +194,7 @@ def _exportable_gated_delta_net_forward(self, x, input_pos): x = a + self.dt_bias g = (-self.A_log.exp() * torch.logaddexp(x, torch.zeros_like(x))).exp() - import executorch.backends.mlx.model_ops.gated_delta_rule as _ # noqa: ensure op registered + import executorch.backends.mlx.custom_kernel_ops.gated_delta_rule as _ # noqa: ensure op registered output = torch.ops.mlx.gated_delta_rule( q, From d8fd5ffd5a00d50a30a441a544ca94dc711db8ef Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 4 Jun 2026 17:10:42 -0700 Subject: [PATCH 05/18] up --- .../mlx/custom_kernel_ops/gguf_embedding.py | 4 +- .../test/test_gguf_embedding.py | 10 +++- backends/mlx/llm/turboquant_cache.py | 5 +- examples/models/gemma4_31b/gguf_loader.py | 48 +++++++++++++------ examples/models/gemma4_31b/mlx_gguf_linear.py | 8 +--- .../gemma4_31b/quant/tests/test_gguf.py | 8 +++- .../gemma4_31b/tests/test_mlx_pipeline.py | 12 +++-- 7 files changed, 61 insertions(+), 34 deletions(-) diff --git a/backends/mlx/custom_kernel_ops/gguf_embedding.py b/backends/mlx/custom_kernel_ops/gguf_embedding.py index 11a17c1b08a..3930ab75c90 100644 --- a/backends/mlx/custom_kernel_ops/gguf_embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf_embedding.py @@ -32,8 +32,6 @@ from __future__ import annotations import torch -from torch import Tensor -from torch.fx.node import Node from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( _Q6K_HEADER, @@ -41,6 +39,8 @@ Q6K_BLOCK_BYTES, QK_K, ) +from torch import Tensor +from torch.fx.node import Node # --------------------------------------------------------------------------- diff --git a/backends/mlx/custom_kernel_ops/test/test_gguf_embedding.py b/backends/mlx/custom_kernel_ops/test/test_gguf_embedding.py index 19464b8f131..6548e9d2785 100644 --- a/backends/mlx/custom_kernel_ops/test/test_gguf_embedding.py +++ b/backends/mlx/custom_kernel_ops/test/test_gguf_embedding.py @@ -25,7 +25,9 @@ import torch import torch.nn as nn -from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import make_q6_k_blob +from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import ( + make_q6_k_blob, +) from executorch.backends.mlx.test.test_utils import OpTestCase @@ -74,7 +76,7 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: return (weight, indices) -if __name__ == "__main__": +def _main() -> None: # noqa: C901 import argparse import sys @@ -129,3 +131,7 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: if failed_names: print(f"Failed: {', '.join(failed_names)}") sys.exit(0 if failed == 0 else 1) + + +if __name__ == "__main__": + _main() diff --git a/backends/mlx/llm/turboquant_cache.py b/backends/mlx/llm/turboquant_cache.py index f9dc311cada..b262876c481 100644 --- a/backends/mlx/llm/turboquant_cache.py +++ b/backends/mlx/llm/turboquant_cache.py @@ -25,12 +25,13 @@ from typing import Optional, Tuple -# Register the MLX custom ops used by this cache. -import executorch.backends.mlx.custom_ops # noqa: F401 mlx::custom_sdpa, mlx::kv_cache_update import executorch.backends.mlx.custom_kernel_ops.tq4_compress # noqa: F401 mlx::tq4_compress import executorch.backends.mlx.custom_kernel_ops.tq_dequant # noqa: F401 mlx::tq_dequant import executorch.backends.mlx.custom_kernel_ops.tq_norm # noqa: F401 mlx::tq_norm +# Register the MLX custom ops used by this cache. +import executorch.backends.mlx.custom_ops # noqa: F401 mlx::custom_sdpa, mlx::kv_cache_update + import torch from executorch.extension.llm.modules.turboquant.kv_cache import ( diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 6f53436808c..2954a7c7644 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -103,6 +103,34 @@ def _validate_no_meta(model): p.requires_grad_(False) +def _handle_mlx_q6k(model, model_key, result): + """Handle Q6_K tensors for the MLX backend. + + Returns ``(processed, embed_q6k_raw)`` where ``processed`` is True + if the tensor was consumed by the fused gguf_linear/gguf_embedding + path (caller should ``continue``), and ``embed_q6k_raw`` is the raw + blob when the tensor is the embedding (for tied lm_head reuse). + """ + from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( + replace_with_gguf_embedding, + replace_with_gguf_linear, + ) + + embed_q6k_raw = None + parent = model.get_submodule(model_key.rsplit(".", 1)[0]) + if isinstance(parent, torch.nn.Linear): + replace_with_gguf_linear(model, model_key, result, format="q6k") + return True, None + if isinstance(parent, torch.nn.Embedding): + if model_key == "embed_tokens.weight": + embed_q6k_raw = result + replace_with_gguf_embedding(model, model_key, result, format="q6k") + return True, embed_q6k_raw + + # Any other Q6_K module: fall back to a quantized tensor. + return False, None + + def load_gguf_model( gguf_path: str, max_seq_len: int = 4096, @@ -169,24 +197,14 @@ def load_gguf_model( # GGUF blob (group_size=16 has no MLX affine kernel). The embedding's # raw blob is also kept so a tied lm_head can use gguf_linear. if backend == "mlx" and gguf_type == GGMLQuantizationType.Q6_K: - from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( - replace_with_gguf_embedding, - replace_with_gguf_linear, - ) - - parent = model.get_submodule(model_key.rsplit(".", 1)[0]) - if isinstance(parent, torch.nn.Linear): - replace_with_gguf_linear(model, model_key, result, format="q6k") - n_processed += 1 - continue - if isinstance(parent, torch.nn.Embedding): - if model_key == "embed_tokens.weight": - embed_q6k_raw = result - replace_with_gguf_embedding(model, model_key, result, format="q6k") + processed, raw = _handle_mlx_q6k(model, model_key, result) + if raw is not None: + embed_q6k_raw = raw + if processed: n_processed += 1 continue - # Any other Q6_K module: fall back to a quantized tensor. + # Fallback: unpack Q6_K to a quantized tensor. from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( Q6K_BLOCK_BYTES, QK_K, diff --git a/examples/models/gemma4_31b/mlx_gguf_linear.py b/examples/models/gemma4_31b/mlx_gguf_linear.py index 150c2773fd6..b8cb18ad47e 100644 --- a/examples/models/gemma4_31b/mlx_gguf_linear.py +++ b/examples/models/gemma4_31b/mlx_gguf_linear.py @@ -84,9 +84,7 @@ def replace_with_gguf_linear( ) linear_fqn = parts[0] grandparent_fqn, _, child_name = linear_fqn.rpartition(".") - grandparent = ( - model.get_submodule(grandparent_fqn) if grandparent_fqn else model - ) + grandparent = model.get_submodule(grandparent_fqn) if grandparent_fqn else model setattr(grandparent, child_name, GGUFLinear(weight_blob, format=format)) @@ -145,7 +143,5 @@ def replace_with_gguf_embedding( ) module_fqn = parts[0] grandparent_fqn, _, child_name = module_fqn.rpartition(".") - grandparent = ( - model.get_submodule(grandparent_fqn) if grandparent_fqn else model - ) + grandparent = model.get_submodule(grandparent_fqn) if grandparent_fqn else model setattr(grandparent, child_name, GGUFEmbedding(weight_blob, format=format)) diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py index 5ebc93279bb..f42f2e7aeb6 100644 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -302,7 +302,9 @@ def test_raw_blob_preserves_bytes(self): ) def test_raw_blob_dequant_matches_gguf_reference(self): - from executorch.backends.mlx.custom_kernel_ops.gguf_linear import dequantize_q6_k + from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( + dequantize_q6_k, + ) from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k d, scales_16, qvals, block = self._block() @@ -320,7 +322,9 @@ def test_raw_blob_dequant_matches_gguf_lib(self): # dequantize Q6_K even though it cannot quantize to it). import gguf - from executorch.backends.mlx.custom_kernel_ops.gguf_linear import dequantize_q6_k + from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( + dequantize_q6_k, + ) from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k _, _, _, block = self._block() diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index 2bdce6f60e2..3f9d3cc96e9 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -327,7 +327,9 @@ class TestGgufLinearMlx(unittest.TestCase): """Q6_K weights route to the fused mlx::gguf_linear op (raw-blob path).""" def _make_blob(self, N: int, K: int) -> torch.Tensor: - from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import make_q6_k_blob + from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import ( + make_q6_k_blob, + ) return make_q6_k_blob(N, K) @@ -396,7 +398,9 @@ class TestGgufEmbeddingMlx(unittest.TestCase): """Q6_K token embedding routes to the fused mlx::gguf_embedding op.""" def _make_blob(self, vocab: int, K: int) -> torch.Tensor: - from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import make_q6_k_blob + from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import ( + make_q6_k_blob, + ) return make_q6_k_blob(vocab, K) @@ -428,9 +432,7 @@ def test_replace_with_gguf_embedding_swaps_module(self): def test_gguf_embedding_delegates_to_mlx(self): from executorch.backends.mlx import MLXPartitioner - from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( - GGUFEmbedding, - ) + from executorch.examples.models.gemma4_31b.mlx_gguf_linear import GGUFEmbedding from executorch.exir import to_edge_transform_and_lower from torch.export import Dim, export From e9eb22a371cb90b7df69db7bc5fe2e076c2d911a Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 4 Jun 2026 17:39:50 -0700 Subject: [PATCH 06/18] up --- .github/workflows/mlx.yml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 3b5eb80f176..5175c565075 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -77,9 +77,7 @@ jobs: backends/mlx/test/test_passes.py \ backends/mlx/test/test_pattern_utils.py \ backends/mlx/test/test_partitioner.py \ - backends/mlx/test/test_gguf_dequant.py \ examples/models/gemma4_31b/tests/test_mlx_pipeline.py \ - examples/models/gemma4_31b/quant/tests/test_gguf.py \ -v echo "::endgroup::" From acb917fca2c9cdc1e918e949c6006feb8073d5af Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Thu, 4 Jun 2026 18:12:27 -0700 Subject: [PATCH 07/18] perf improvements --- backends/mlx/custom_kernel_ops/gguf_linear.py | 268 ++++++++++++------ 1 file changed, 180 insertions(+), 88 deletions(-) diff --git a/backends/mlx/custom_kernel_ops/gguf_linear.py b/backends/mlx/custom_kernel_ops/gguf_linear.py index 0ac5a550fe4..484d12dbc24 100644 --- a/backends/mlx/custom_kernel_ops/gguf_linear.py +++ b/backends/mlx/custom_kernel_ops/gguf_linear.py @@ -226,13 +226,14 @@ def gguf_linear_fake( IfNode, IntOrVid, MetalKernelNode, + MultiplyIntNode, SubtractIntNode, ) # Shared Metal header: the GGUF block_q6_K struct (matches llama.cpp -# ggml-common.h; sizeof == 210, no padding since max align is 2) plus a -# per-element dequant helper used by the mat-mat kernel. +# ggml-common.h; sizeof == 210, no padding since max align is 2) plus +# dequant helpers for both per-element (embedding) and vectorized (matmul). _Q6K_HEADER = """ #include #include @@ -248,7 +249,7 @@ def gguf_linear_fake( } block_q6_K; // Dequantize a single element at within-block position p (0..255) of a -// block_q6_K. Matches the canonical ggml dequantize_row_q6_K layout. +// block_q6_K. Used by the embedding kernel. inline float dequant_q6k_elem(device const block_q6_K * blk, int p) { const int h = p >> 7; // which 128-element half (0/1) const int pp = p & 127; // position within half (0..127) @@ -267,6 +268,43 @@ def gguf_linear_fake( const float scale = (float) sc[is + 2 * g]; return (float) blk->d * scale * (float)(q - 32); } + +// Vectorized Q6_K dequantize: decodes 16 values per call into a 4x4 half +// register. Ported from llama.cpp dequantize_q6_K. `il` ranges 0..15 and +// selects which 16-element slice of the 256-element block to decode. +inline void dequantize_q6_K_16(device const block_q6_K * xb, short il, + thread half4x4 & reg) { + const half d_all = xb->d; + device const uint16_t * ql = (device const uint16_t *)xb->ql; + device const uint16_t * qh = (device const uint16_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + + ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1); + qh = qh + 16*(il/8) + 8*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303); + const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F; + const float coeff = d_all * sc; + const float ml = coeff * 32.f; + const float dl0 = coeff; + const float dl1 = dl0 / 256.f; + const float dl2 = dl0 / (256.f * 256.f); + const float dl3 = dl0 / (256.f * 256.f * 256.f); + const uint8_t shr_h = il>2 ? 2 : 0; + const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4); + const uint8_t shr_l = il>1 ? 4 : 0; + for (int i = 0; i < 4; ++i) { + const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2; + const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1; + const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l); + reg[i][0] = (half)(dl0 * ((half)(q & 0xFF)) - ml); + reg[i][1] = (half)(dl1 * ((float)(q & 0xFF00)) - ml); + reg[i][2] = (half)(dl2 * ((float)(q & 0xFF0000)) - ml); + reg[i][3] = (half)(dl3 * ((float)(q & 0xFF000000)) - ml); + } +} """ @@ -344,105 +382,147 @@ def _q6k_matvec_source(has_bias: bool) -> str: """ -# Prefill mat-mat kernel (32x32x32 tiles, 4 simdgroups / 128 threads). -# Each threadgroup computes a BM x BN output tile; weights are dequantized -# on the fly into threadgroup memory and reused across the BM activation rows. +# Prefill mat-mat kernel, ported from llama.cpp kernel_mul_mm (Q6_K variant). +# 64x32 output tiles, 4 simdgroups / 128 threads per threadgroup. +# Uses vectorized dequantize_q6_K_16 to decode 16 weight values per thread +# into threadgroup memory, then runs simdgroup_multiply_accumulate on 8x8 +# tiles. NL=16 for Q6_K (QK_K / 16 = 16 dequant steps per super-block). # C[m, n] = sum_k x[m, k] * dequant(weight)[n, k] (+ bias[n]). def _q6k_matmul_source(has_bias: bool) -> str: - bias_line = " v += (float) bias[gn];\n" if has_bias else "" + bias_add = "+ (float) bias[r0 + i]" if has_bias else "" return f""" - constexpr short BM = 32; // activation rows per tile (M) - constexpr short BN = 32; // output features per tile (N) - constexpr short BK = 32; // K-chunk per iteration - - threadgroup half As[BM * BK]; - threadgroup half Bs[BN * BK]; - threadgroup float Cs[BM * BN]; - - const ushort tid = thread_index_in_threadgroup; // 0..127 - const ushort sgitg = simdgroup_index_in_threadgroup; // 0..3 - const short sg_row = sgitg / 2; // 0/1 - const short sg_col = sgitg % 2; // 0/1 - - const uint tile_n_idx = thread_position_in_grid.x / 128u; - const uint tile_m_idx = thread_position_in_grid.y; - const int tile_m0 = (int)tile_m_idx * BM; - const int tile_n0 = (int)tile_n_idx * BN; - - // M (number of activation rows) read at runtime from the injected x_shape, - // so this kernel works for both static and symbolic seqlen. + constexpr short NR0 = 64; // weight/output rows per tile (N dim) + constexpr short NR1 = 32; // activation rows per tile (M dim) + constexpr short NK = 32; // K-chunk per iteration + constexpr short NL = 16; // Q6_K: QK_K / 16 + constexpr short NL0 = NK / 16; // = 2 — dequant iterations per thread for weight + constexpr short NL1 = NK / 8; // = 4 — load iterations per thread for activation + + threadgroup half sa[4096]; // NR0 * NK storage (strided by 64) + threadgroup half sb[4096]; // NR1 * NK storage (strided by 64) + + const ushort tid = thread_index_in_threadgroup; // 0..127 + const ushort sgitg = simdgroup_index_in_threadgroup; // 0..3 + + const uint r0 = thread_position_in_grid.y * NR0; // first weight row + const uint r1 = (thread_position_in_grid.x / 128u) * NR1; // first activation row + + // M (number of activation rows) read at runtime. int M = 1; for (uint d = 0; d + 1 < x_ndim; ++d) {{ M *= (int) x_shape[d]; }} const int nb = K / QK_K; - device const block_q6_K * wrows = (device const block_q6_K *) weight; - simdgroup_float8x8 mc[2][2]; - for (short a = 0; a < 2; ++a) {{ - for (short b = 0; b < 2; ++b) {{ - mc[a][b] = make_filled_simdgroup_matrix(0.f); - }} + // Clamp tile edges. + const short nr0 = (N - (int)r0 < NR0) ? (N - (int)r0) : NR0; + const short nr1 = (M - (int)r1 < NR1) ? (M - (int)r1) : NR1; + + // Thread → element mapping for cooperative loads. + const short lr0 = ((short)(tid / NL0) < nr0) ? (short)(tid / NL0) : (nr0 - 1); // 0..63 + const short lr1 = ((short)(tid / NL1) < nr1) ? (short)(tid / NL1) : (nr1 - 1); // 0..31 + + short il0 = tid % NL0; + short il = il0; // current dequant sub-block index within Q6_K block + + const short offset1 = il0 / NL; // always 0 for NL=16, NL0=2 + + // Pointer to weight block for this thread's assigned row. + device const block_q6_K * wblk = (device const block_q6_K *) weight + + (uint)(r0 + lr0) * nb + offset1; + + // Pointer to activation row for this thread. + const short iy = 8 * (tid % NL1); + device const InT * yp = x + (uint)(r1 + lr1) * (uint)K + iy; + + // Accumulator: 8 simdgroup 8x8 matrices (4 sgitg configs x 2 sub-tiles). + simdgroup_half8x8 ma[4]; + simdgroup_half8x8 mb[2]; + simdgroup_float8x8 mc[8]; + for (short i = 0; i < 8; ++i) {{ + mc[i] = make_filled_simdgroup_matrix(0.f); }} - for (int k0 = 0; k0 < K; k0 += BK) {{ - // Cooperative load: activation tile (BM x BK), row-major in As. - for (short i = 0; i < (BM * BK) / 128; ++i) {{ - const short idx = tid + i * 128; - const short mm = idx / BK; - const short kk = idx % BK; - const int gm = tile_m0 + mm; - As[mm * BK + kk] = (gm < M) ? (half) x[(uint)gm * (uint)K + (k0 + kk)] : (half)0; + for (int loop_k = 0; loop_k < K; loop_k += NK) {{ + // --- Cooperative load: dequantized weight tile (NR0 x NK) into sa --- + half4x4 temp_a; + dequantize_q6_K_16(wblk, il, temp_a); + + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (short i = 0; i < 16; ++i) {{ + const short sx = 2 * il0 + i / 8; + const short sy = (tid / NL0) / 8; + const short lx = (tid / NL0) % 8; + const short ly = i % 8; + const short ib = 8 * sx + sy; + *(sa + 64 * ib + 8 * ly + lx) = temp_a[i / 4][i % 4]; }} - // Cooperative load: dequantized weight tile (BN x BK), row-major in Bs. - for (short i = 0; i < (BN * BK) / 128; ++i) {{ - const short idx = tid + i * 128; - const short nn = idx / BK; - const short kk = idx % BK; - const int gn = tile_n0 + nn; - const int gk = k0 + kk; - half val = (half)0; - if (gn < N) {{ - device const block_q6_K * blk = wrows + (uint)gn * nb + (gk / QK_K); - val = (half) dequant_q6k_elem(blk, gk % QK_K); - }} - Bs[nn * BK + kk] = val; + + // --- Cooperative load: activation tile (NR1 x NK) into sb --- + const short sx_b = tid % NL1; + const short sy_b = (tid / NL1) / 8; + const short ly_b = (tid / NL1) % 8; + const short ib_b = 4 * sx_b + sy_b; + + for (short i = 0; i < 8; ++i) {{ + *(sb + 64 * ib_b + 8 * ly_b + i) = (half) *(yp + i); }} + + // Advance weight pointer through Q6_K sub-blocks. + il = (il + 2 < NL) ? il + 2 : il % 2; + wblk = (il < 2) ? wblk + (2 + NL - 1) / NL : wblk; + + yp += NK; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (short kk = 0; kk < BK / 8; ++kk) {{ - simdgroup_half8x8 a[2], b[2]; - for (short sr = 0; sr < 2; ++sr) {{ - simdgroup_load(a[sr], As + (16 * sg_row + 8 * sr) * BK + 8 * kk, BK, ulong2(0, 0), false); + // --- Simdgroup matmul on loaded tiles --- + threadgroup const half * lsma = sa + 4 * 64 * (sgitg % 2); + threadgroup const half * lsmb = sb + 2 * 64 * (sgitg / 2); + + for (short ik = 0; ik < NK / 8; ++ik) {{ + simdgroup_barrier(mem_flags::mem_none); + for (short i = 0; i < 4; ++i) {{ + simdgroup_load(ma[i], lsma + 64 * i, 8, ulong2(0, 0), false); }} - for (short sc = 0; sc < 2; ++sc) {{ - // transpose=true yields b[k][n] = Bs[n][k] for C = A @ B^T. - simdgroup_load(b[sc], Bs + (16 * sg_col + 8 * sc) * BK + 8 * kk, BK, ulong2(0, 0), true); + simdgroup_barrier(mem_flags::mem_none); + for (short i = 0; i < 2; ++i) {{ + simdgroup_load(mb[i], lsmb + 64 * i, 8, ulong2(0, 0), false); }} - for (short sr = 0; sr < 2; ++sr) {{ - for (short sc = 0; sc < 2; ++sc) {{ - simdgroup_multiply_accumulate(mc[sr][sc], a[sr], b[sc], mc[sr][sc]); - }} + simdgroup_barrier(mem_flags::mem_none); + for (short i = 0; i < 8; ++i) {{ + simdgroup_multiply_accumulate(mc[i], mb[i / 4], ma[i % 4], mc[i]); }} + lsma += 8 * 64; + lsmb += 4 * 64; }} - threadgroup_barrier(mem_flags::mem_threadgroup); }} - for (short sr = 0; sr < 2; ++sr) {{ - for (short sc = 0; sc < 2; ++sc) {{ - simdgroup_store(mc[sr][sc], Cs + (16 * sg_row + 8 * sr) * BN + (16 * sg_col + 8 * sc), BN, ulong2(0, 0), false); - }} - }} + // --- Write results: always via threadgroup memory for float→InT cast --- + // Barrier needed: sa was used for weight tiles during the K-loop and is now + // reused as float staging for the output. Without this barrier, a fast + // simdgroup could start writing mc[] into sa while a slower one is still + // reading the last weight tile via simdgroup_load(ma[]). + // (Matches ggml-metal.metal:9546 in llama.cpp's bounds-checked write path.) threadgroup_barrier(mem_flags::mem_threadgroup); + {{ + threadgroup float * temp_str = ((threadgroup float *) sa) + + 32 * (sgitg & 1) + (16 * (sgitg >> 1)) * NR0; + for (short i = 0; i < 8; ++i) {{ + simdgroup_store(mc[i], temp_str + 8 * (i % 4) + 8 * NR0 * (i / 4), + NR0, ulong2(0, 0), false); + }} + threadgroup_barrier(mem_flags::mem_threadgroup); - for (short i = 0; i < (BM * BN) / 128; ++i) {{ - const short idx = tid + i * 128; - const short mm = idx / BN; - const short nn = idx % BN; - const int gm = tile_m0 + mm; - const int gn = tile_n0 + nn; - if (gm < M && gn < N) {{ - float v = Cs[mm * BN + nn]; -{bias_line} out[(uint)gm * (uint)N + gn] = (InT) v; + if (sgitg == 0) {{ + for (int j = tid; j < nr1; j += NR1) {{ + device InT * D = out + (uint)(r1 + j) * (uint)N + r0; + threadgroup float * Cp = ((threadgroup float *) sa) + j * NR0; + for (int i = 0; i < nr0; ++i) {{ + float v = Cp[i]; + D[i] = (InT)(v {bias_add}); + }} + }} }} }} """ @@ -450,8 +530,9 @@ def _q6k_matmul_source(has_bias: bool) -> str: # Number of simdgroups per threadgroup for the mat-vec kernel. _Q6K_MV_NSG = 4 -# Tile size for the mat-mat kernel (BM == BN); threadgroup handles a tile. -_Q6K_MM_TILE = 32 +# Tile sizes for the mat-mat kernel (from llama.cpp kernel_mul_mm). +_Q6K_MM_NR0 = 64 # weight/output rows (N dim) per threadgroup +_Q6K_MM_NR1 = 32 # activation rows (M dim) per threadgroup def _emit_q6k_matvec( @@ -529,9 +610,9 @@ def _emit_q6k_matmul( leading = emit_shape(P, x_node, x_slot, end_dim=-1) out_shape_flat = leading + [IntOrVid.from_literal(N)] - tile = _Q6K_MM_TILE - blocks_n = (N + tile - 1) // tile - grid_x = blocks_n * 128 # 128 threads (4 simdgroups) per threadgroup + # grid.x = ceil(M / NR1) * 128 threads (activation tiles) + # grid.y = ceil(N / NR0) (weight tiles) + blocks_n = (N + _Q6K_MM_NR0 - 1) // _Q6K_MM_NR0 has_bias = bias_slot is not None inputs = [P.slot_to_tid(x_slot), P.slot_to_tid(weight_slot)] @@ -540,6 +621,17 @@ def _emit_q6k_matmul( inputs.append(P.slot_to_tid(bias_slot)) input_names.append("bias") + # blocks_m_iov = ceil(M / NR1); multiply by 128 for grid.x + _, grid_x_slot = P.make_tmp_value_slot() + P.emit( + MultiplyIntNode( + a=blocks_m_iov, + b=IntOrVid.from_literal(128), + out=P.slot_to_vid(grid_x_slot), + ) + ) + grid_x_iov = IntOrVid.from_vid(P.slot_to_vid(grid_x_slot)) + P.emit( MetalKernelNode( name="gguf_q6k_matmul", @@ -548,8 +640,8 @@ def _emit_q6k_matmul( inputs=inputs, outputs=[P.slot_to_tid(out)], grid=[ - IntOrVid.from_literal(grid_x), - blocks_m_iov, + grid_x_iov, + IntOrVid.from_literal(blocks_n), IntOrVid.from_literal(1), ], threadgroup=[ @@ -623,7 +715,7 @@ def _gguf_linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: break out = P.make_or_get_slot(n) - tile = _Q6K_MM_TILE + tile = _Q6K_MM_NR1 # M-dimension tile (activation rows per threadgroup) if M == 1: # Static decode -> mat-vec. _emit_q6k_matvec(P, n, x_node, x_slot, weight_slot, bias_slot, N, K, out) From 450cf99f8d2c9361c9d003a1c062476c304220f8 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 5 Jun 2026 12:22:34 -0700 Subject: [PATCH 08/18] up --- .github/workflows/mlx.yml | 8 +- .../mlx/custom_kernel_ops/gguf/__init__.py | 23 ++ .../mlx/custom_kernel_ops/gguf/embedding.py | 97 ++++++ backends/mlx/custom_kernel_ops/gguf/linear.py | 113 +++++++ .../custom_kernel_ops/gguf/q6k/__init__.py | 28 ++ .../mlx/custom_kernel_ops/gguf/q6k/common.py | 205 +++++++++++++ .../q6k/embedding.py} | 96 ++---- .../{gguf_linear.py => gguf/q6k/linear.py} | 288 +++--------------- .../custom_kernel_ops/gguf/test/__init__.py | 5 + .../test/test_embedding.py} | 8 +- .../test/test_linear.py} | 12 +- backends/mlx/runtime/MLXInterpreter.h | 4 +- backends/mlx/serialization/MLXLoader.cpp.tmpl | 9 +- backends/mlx/serialization/MLXLoader.h.tmpl | 23 +- backends/mlx/serialization/generate.py | 42 ++- .../mlx/serialization/mlx_graph_serialize.py | 7 +- backends/mlx/test/test_serialization_dedup.py | 84 +++++ examples/models/gemma4_31b/gguf_loader.py | 2 +- examples/models/gemma4_31b/mlx_gguf_linear.py | 6 +- .../gemma4_31b/quant/tests/test_gguf.py | 8 +- .../gemma4_31b/tests/test_mlx_pipeline.py | 4 +- 21 files changed, 713 insertions(+), 359 deletions(-) create mode 100644 backends/mlx/custom_kernel_ops/gguf/__init__.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/embedding.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/linear.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/q6k/common.py rename backends/mlx/custom_kernel_ops/{gguf_embedding.py => gguf/q6k/embedding.py} (65%) rename backends/mlx/custom_kernel_ops/{gguf_linear.py => gguf/q6k/linear.py} (68%) create mode 100644 backends/mlx/custom_kernel_ops/gguf/test/__init__.py rename backends/mlx/custom_kernel_ops/{test/test_gguf_embedding.py => gguf/test/test_embedding.py} (93%) rename backends/mlx/custom_kernel_ops/{test/test_gguf_linear.py => gguf/test/test_linear.py} (95%) create mode 100644 backends/mlx/test/test_serialization_dedup.py diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index 5175c565075..c1c7af82a06 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -77,6 +77,7 @@ jobs: backends/mlx/test/test_passes.py \ backends/mlx/test/test_pattern_utils.py \ backends/mlx/test/test_partitioner.py \ + backends/mlx/test/test_serialization_dedup.py \ examples/models/gemma4_31b/tests/test_mlx_pipeline.py \ -v echo "::endgroup::" @@ -90,11 +91,12 @@ jobs: echo "::endgroup::" echo "::group::Run custom_kernel_ops op tests" - # Run every custom_kernel_ops/test/test_*.py via its OpTestCase `run` CLI. Adding a + # Run every custom_kernel_ops/**/test/test_*.py via its OpTestCase `run` + # CLI. Recurses into per-format subpackages (e.g. gguf/test), so adding a # new op test file requires no change here. set -e - for t in backends/mlx/custom_kernel_ops/test/test_*.py; do - mod="executorch.backends.mlx.custom_kernel_ops.test.$(basename "$t" .py)" + for t in $(find backends/mlx/custom_kernel_ops -path '*/test/test_*.py' | sort); do + mod="executorch.$(echo "${t%.py}" | tr '/' '.')" echo "--- ${mod} ---" ${CONDA_RUN} python -m "${mod}" run -v done diff --git a/backends/mlx/custom_kernel_ops/gguf/__init__.py b/backends/mlx/custom_kernel_ops/gguf/__init__.py new file mode 100644 index 00000000000..2baaf73217a --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/__init__.py @@ -0,0 +1,23 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF-quantized custom ops for the MLX backend. + +Submodules: + +* :mod:`.q6k` -- shared Q6_K primitives (constants, pure-torch dequant, + Metal header). Import this for symbols; it does not register any op. +* :mod:`.linear` -- registers ``mlx::gguf_linear``. +* :mod:`.embedding` -- registers ``mlx::gguf_embedding``. + +To register an op, import the corresponding submodule for its side effect, e.g. +``import executorch.backends.mlx.custom_kernel_ops.gguf.linear # noqa: F401``. + +This package ``__init__`` is intentionally side-effect free so importing +``.q6k`` for the pure-torch dequant does not pull in the MLX builder/registry. +""" diff --git a/backends/mlx/custom_kernel_ops/gguf/embedding.py b/backends/mlx/custom_kernel_ops/gguf/embedding.py new file mode 100644 index 00000000000..0b63792c63f --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/embedding.py @@ -0,0 +1,97 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""``mlx::gguf_embedding``: embedding gather against a GGUF-quantized table. + + out = dequant(weight[indices]) + +This module is a thin **format router**: it owns the ``mlx::gguf_embedding`` op +identity (custom op, fake, and lowering registration) and dispatches on +``format`` to a per-format implementation. Only ``"q6k"`` is supported today +(see :mod:`.q6k.embedding`); other formats raise ``NotImplementedError``. + +Usage:: + + import executorch.backends.mlx.custom_kernel_ops.gguf.embedding # noqa: F401 + + out = torch.ops.mlx.gguf_embedding(weight, indices, "q6k") + # weight: (vocab, (K/256)*210) uint8 GGUF q6_K blob + # indices: (...) int + # out: (..., K) bfloat16 +""" + +from __future__ import annotations + +import torch +from torch import Tensor +from torch.fx.node import Node + + +# --------------------------------------------------------------------------- +# Custom op + eager fallback (format-agnostic shell; dispatches by format) +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("mlx::gguf_embedding", mutates_args=()) +def gguf_embedding(weight: Tensor, indices: Tensor, format: str) -> Tensor: + """Gather + dequantize rows of a GGUF-quantized embedding table. + + Args: + weight: ``(vocab, row_bytes)`` uint8 GGUF quant blob. + indices: integer token ids of any shape. + format: GGUF quant type; only ``"q6k"`` supported. + + Returns: + ``(*indices.shape, K)`` bfloat16. + """ + if format == "q6k": + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.embedding import ( + eager_embedding, + ) + + return eager_embedding(weight, indices) + raise NotImplementedError( + f"mlx::gguf_embedding: unsupported format {format!r}; only 'q6k' is supported" + ) + + +@torch.library.register_fake("mlx::gguf_embedding") +def gguf_embedding_fake(weight: Tensor, indices: Tensor, format: str) -> Tensor: + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( + Q6K_BLOCK_BYTES, + QK_K, + ) + + row_bytes = weight.shape[1] + K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K + return indices.new_empty((*indices.shape, K), dtype=torch.bfloat16) + + +# --------------------------------------------------------------------------- +# MLX handler (format router) +# --------------------------------------------------------------------------- + +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot + + +@REGISTRY.register(target=[torch.ops.mlx.gguf_embedding.default]) +def _gguf_embedding_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Route ``mlx::gguf_embedding`` lowering to the per-format implementation.""" + args = P.args(n) + fmt = args[2] if len(args) >= 3 else None + if fmt == "q6k": + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.embedding import ( + emit_embedding, + ) + + return emit_embedding(P, n) + raise NotImplementedError( + f"mlx::gguf_embedding: unsupported format {fmt!r}; only 'q6k' is supported" + ) diff --git a/backends/mlx/custom_kernel_ops/gguf/linear.py b/backends/mlx/custom_kernel_ops/gguf/linear.py new file mode 100644 index 00000000000..42999c589e4 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/linear.py @@ -0,0 +1,113 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""``mlx::gguf_linear``: linear layer against a GGUF-quantized weight. + + out = x @ dequant(weight)^T (+ bias) + +The weight is stored in the **exact GGUF packed block layout** (no repacking), +so weights converted by llama.cpp / gguf-py can be consumed directly. The +``format`` argument selects the GGUF quantization type. + +This module is a thin **format router**: it owns the ``mlx::gguf_linear`` op +identity (custom op, fake, and lowering registration) and dispatches on +``format`` to a per-format implementation. Only ``"q6k"`` is supported today +(see :mod:`.q6k.linear`); other formats raise ``NotImplementedError``. To add a +format, implement ``eager_linear`` / ``emit_linear`` in a sibling package (e.g. +``q4k``) and add a branch below. + +Usage:: + + import executorch.backends.mlx.custom_kernel_ops.gguf.linear # noqa: F401 + + out = torch.ops.mlx.gguf_linear(x, weight, "q6k", bias) + # x: (..., K) bf16 / fp16 / fp32 + # weight: (N, (K/256)*210) uint8 GGUF q6_K blob + # bias: (N,) or None activation dtype + # out: (..., N) activation dtype +""" + +from __future__ import annotations + +from typing import Optional + +import torch +from torch import Tensor +from torch.fx.node import Node + + +# --------------------------------------------------------------------------- +# Custom op + eager fallback (format-agnostic shell; dispatches by format) +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("mlx::gguf_linear", mutates_args=()) +def gguf_linear( + x: Tensor, + weight: Tensor, + format: str, + bias: Optional[Tensor] = None, +) -> Tensor: + """Linear against a GGUF-quantized weight. + + Args: + x: ``(..., K)`` activations (bf16 / fp16 / fp32). + weight: ``(N, row_bytes)`` uint8 GGUF quant blob. + format: GGUF quant type; only ``"q6k"`` supported. + bias: optional ``(N,)`` of activation dtype. + + Returns: + ``(..., N)`` of activation dtype. + """ + if format == "q6k": + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.linear import ( + eager_linear, + ) + + return eager_linear(x, weight, bias) + raise NotImplementedError( + f"mlx::gguf_linear: unsupported format {format!r}; only 'q6k' is supported" + ) + + +@torch.library.register_fake("mlx::gguf_linear") +def gguf_linear_fake( + x: Tensor, + weight: Tensor, + format: str, + bias: Optional[Tensor] = None, +) -> Tensor: + N = weight.shape[0] + out_shape = list(x.shape) + out_shape[-1] = N + return x.new_empty(out_shape, dtype=x.dtype) + + +# --------------------------------------------------------------------------- +# MLX handler (format router) +# --------------------------------------------------------------------------- + +from executorch.backends.mlx.builder.op_registry import REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot + + +@REGISTRY.register(target=[torch.ops.mlx.gguf_linear.default]) +def _gguf_linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: + """Route ``mlx::gguf_linear`` lowering to the per-format implementation.""" + args = P.args(n) + fmt = args[2] if len(args) >= 3 else None + if fmt == "q6k": + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.linear import ( + emit_linear, + ) + + return emit_linear(P, n) + raise NotImplementedError( + f"mlx::gguf_linear: unsupported format {fmt!r}; only 'q6k' is supported" + ) diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py b/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py new file mode 100644 index 00000000000..0da3e6d3303 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py @@ -0,0 +1,28 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF **Q6_K** format implementation. + +* :mod:`.common` -- shared primitives (constants, pure-torch dequant, Metal + header). Re-exported here so ``from ...gguf.q6k import dequantize_q6_k`` stays + lightweight (no MLX builder import). +* :mod:`.linear` -- Q6_K mat-vec/mat-mat kernels + eager compute + lowering. +* :mod:`.embedding` -- Q6_K gather kernel + eager compute + lowering. + +The format-agnostic op routers live one level up in +``custom_kernel_ops.gguf.linear`` / ``.embedding`` and dispatch here on +``format``. ``.linear`` / ``.embedding`` are intentionally NOT imported here so +importing :mod:`.common` for the pure-torch dequant does not pull in the builder. +""" + +from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( # noqa: F401 + _Q6K_HEADER, + dequantize_q6_k, + Q6K_BLOCK_BYTES, + QK_K, +) diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/common.py b/backends/mlx/custom_kernel_ops/gguf/q6k/common.py new file mode 100644 index 00000000000..0b870d73a4a --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/common.py @@ -0,0 +1,205 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""Shared GGUF **Q6_K** primitives for the MLX custom ops. + +This module holds the pieces common to every Q6_K kernel (linear matmul/matvec +and the embedding gather), so format-specific op modules import from here rather +than from each other: + +* ``QK_K`` / ``Q6K_BLOCK_BYTES`` and the per-super-block byte layout constants. +* ``dequantize_q6_k`` -- the pure-torch dequant oracle (eager fallback + tests). +* ``_Q6K_HEADER`` -- the Metal header (the ``block_q6_K`` struct plus the + per-element and vectorized dequant helpers) shared by all Q6_K Metal kernels. + +Adding another GGUF format (e.g. Q4_K) should mirror this module (``q4k.py``) +and the op handlers in :mod:`.linear` / :mod:`.embedding` dispatch on ``format``. + +Q6_K layout (per 256-element super-block, 210 bytes, see llama.cpp +``block_q6_K`` in ``ggml-common.h``):: + + uint8 ql[128] # quants, lower 4 bits + uint8 qh[64] # quants, upper 2 bits + int8 scales[16] # per-16-element sub-block scales (8-bit) + half d # super-block scale + +The dequantized value for a 6-bit code ``q`` (0..63) in sub-block ``s`` is +``d * scales[s] * (q - 32)``. +""" + +from __future__ import annotations + +import torch +from torch import Tensor + + +# --------------------------------------------------------------------------- +# Q6_K constants +# --------------------------------------------------------------------------- + +QK_K = 256 +# Per-super-block byte counts. +_Q6K_QL_BYTES = QK_K // 2 # 128 +_Q6K_QH_BYTES = QK_K // 4 # 64 +_Q6K_SCALES = QK_K // 16 # 16 +_Q6K_D_BYTES = 2 # one fp16 +Q6K_BLOCK_BYTES = _Q6K_QL_BYTES + _Q6K_QH_BYTES + _Q6K_SCALES + _Q6K_D_BYTES # 210 + + +# --------------------------------------------------------------------------- +# Pure-torch dequant reference +# --------------------------------------------------------------------------- + + +def dequantize_q6_k(weight: Tensor, K: int) -> Tensor: + """Dequantize a GGUF Q6_K blob to float32. + + Args: + weight: ``(N, n_blocks * 210)`` uint8, GGUF ``block_q6_K`` layout. + K: number of logical input features (``n_blocks * 256``). + + Returns: + ``(N, K)`` float32 dequantized weight. + """ + if weight.dtype != torch.uint8: + raise ValueError(f"gguf_linear: weight must be uint8; got {weight.dtype}") + N = weight.shape[0] + nb = K // QK_K + if weight.shape[-1] != nb * Q6K_BLOCK_BYTES: + raise ValueError( + f"gguf_linear: weight row bytes {weight.shape[-1]} != " + f"{nb} blocks * {Q6K_BLOCK_BYTES}" + ) + + blocks = weight.view(N, nb, Q6K_BLOCK_BYTES) + ql = blocks[..., 0:_Q6K_QL_BYTES].to(torch.int32) + qh = blocks[..., _Q6K_QL_BYTES : _Q6K_QL_BYTES + _Q6K_QH_BYTES].to(torch.int32) + sc_off = _Q6K_QL_BYTES + _Q6K_QH_BYTES + scales = ( + blocks[..., sc_off : sc_off + _Q6K_SCALES] + .contiguous() + .view(torch.int8) + .to(torch.float32) + ) + d = ( + blocks[..., sc_off + _Q6K_SCALES : sc_off + _Q6K_SCALES + _Q6K_D_BYTES] + .contiguous() + .view(torch.float16) + .to(torch.float32) + ) # (N, nb, 1) + + y = torch.empty(N, nb, QK_K, dtype=torch.float32, device=weight.device) + # is = l // 16 over l in 0..31 -> selects which of the 8 half-scales. + is_idx = (torch.arange(32, device=weight.device) // 16).long() # (32,) + + for h in range(2): # two 128-element halves + ql_h = ql[..., h * 64 : h * 64 + 64] # (N, nb, 64) + qh_h = qh[..., h * 32 : h * 32 + 32] # (N, nb, 32) + sc_h = scales[..., h * 8 : h * 8 + 8] # (N, nb, 8) + + ql_lo = ql_h[..., 0:32] + ql_hi = ql_h[..., 32:64] + + q1 = (ql_lo & 0xF) | ((qh_h & 0x3) << 4) + q2 = (ql_hi & 0xF) | (((qh_h >> 2) & 0x3) << 4) + q3 = (ql_lo >> 4) | (((qh_h >> 4) & 0x3) << 4) + q4 = (ql_hi >> 4) | (((qh_h >> 6) & 0x3) << 4) + + sc0 = sc_h[..., is_idx + 0] + sc2 = sc_h[..., is_idx + 2] + sc4 = sc_h[..., is_idx + 4] + sc6 = sc_h[..., is_idx + 6] + + base = h * 128 + y[..., base + 0 : base + 32] = d * sc0 * (q1 - 32).to(torch.float32) + y[..., base + 32 : base + 64] = d * sc2 * (q2 - 32).to(torch.float32) + y[..., base + 64 : base + 96] = d * sc4 * (q3 - 32).to(torch.float32) + y[..., base + 96 : base + 128] = d * sc6 * (q4 - 32).to(torch.float32) + + return y.reshape(N, K) + + +# --------------------------------------------------------------------------- +# Shared Metal header +# --------------------------------------------------------------------------- + +# The GGUF block_q6_K struct (matches llama.cpp ggml-common.h; sizeof == 210, no +# padding since max align is 2) plus dequant helpers for both per-element +# (embedding) and vectorized (matmul) use. +_Q6K_HEADER = """ +#include +#include +using namespace metal; + +#define QK_K 256 + +typedef struct { + uint8_t ql[QK_K/2]; // lower 4 bits + uint8_t qh[QK_K/4]; // upper 2 bits + int8_t scales[QK_K/16]; // per-16-element sub-block scales + half d; // super-block scale +} block_q6_K; + +// Dequantize a single element at within-block position p (0..255) of a +// block_q6_K. Used by the embedding kernel. +inline float dequant_q6k_elem(device const block_q6_K * blk, int p) { + const int h = p >> 7; // which 128-element half (0/1) + const int pp = p & 127; // position within half (0..127) + const int g = pp >> 5; // group: 0=q1, 1=q2, 2=q3, 3=q4 + const int l = pp & 31; // 0..31 + device const uint8_t * ql = blk->ql + h * 64; + device const uint8_t * qh = blk->qh + h * 32; + device const int8_t * sc = blk->scales + h * 8; + const int is = l >> 4; // 0/1 + const uint8_t qhb = qh[l]; + int q; + if (g == 0) { q = (ql[l] & 0xF) | ((qhb & 0x03) << 4); } + else if (g == 1) { q = (ql[l + 32] & 0xF) | ((qhb & 0x0C) << 2); } + else if (g == 2) { q = (ql[l] >> 4) | ((qhb & 0x30) << 0); } + else { q = (ql[l + 32] >> 4) | ((qhb & 0xC0) >> 2); } + const float scale = (float) sc[is + 2 * g]; + return (float) blk->d * scale * (float)(q - 32); +} + +// Vectorized Q6_K dequantize: decodes 16 values per call into a 4x4 half +// register. Ported from llama.cpp dequantize_q6_K. `il` ranges 0..15 and +// selects which 16-element slice of the 256-element block to decode. +inline void dequantize_q6_K_16(device const block_q6_K * xb, short il, + thread half4x4 & reg) { + const half d_all = xb->d; + device const uint16_t * ql = (device const uint16_t *)xb->ql; + device const uint16_t * qh = (device const uint16_t *)xb->qh; + device const int8_t * scales = (device const int8_t *)xb->scales; + + ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1); + qh = qh + 16*(il/8) + 8*(il&1); + float sc = scales[(il%2) + 2 * ((il/2))]; + il = (il/2) & 3; + + const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303); + const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F; + const float coeff = d_all * sc; + const float ml = coeff * 32.f; + const float dl0 = coeff; + const float dl1 = dl0 / 256.f; + const float dl2 = dl0 / (256.f * 256.f); + const float dl3 = dl0 / (256.f * 256.f * 256.f); + const uint8_t shr_h = il>2 ? 2 : 0; + const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4); + const uint8_t shr_l = il>1 ? 4 : 0; + for (int i = 0; i < 4; ++i) { + const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2; + const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1; + const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l); + reg[i][0] = (half)(dl0 * ((half)(q & 0xFF)) - ml); + reg[i][1] = (half)(dl1 * ((float)(q & 0xFF00)) - ml); + reg[i][2] = (half)(dl2 * ((float)(q & 0xFF0000)) - ml); + reg[i][3] = (half)(dl3 * ((float)(q & 0xFF000000)) - ml); + } +} +""" diff --git a/backends/mlx/custom_kernel_ops/gguf_embedding.py b/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py similarity index 65% rename from backends/mlx/custom_kernel_ops/gguf_embedding.py rename to backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py index 3930ab75c90..d70cf821176 100644 --- a/backends/mlx/custom_kernel_ops/gguf_embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py @@ -6,65 +6,50 @@ # LICENSE file in the root directory of this source tree. # -""" -``mlx::gguf_embedding``: embedding gather against a GGUF-quantized table. - - out = dequant(weight[indices]) - -The embedding table is the raw GGUF ``block_q6_K`` blob (one quantized row per -vocab entry). This is the gather counterpart to ``mlx::gguf_linear`` and exists -because MLX's affine dequantize has no group_size=16 Metal kernel, so a Q6_K -embedding (group_size 16) cannot use the generic quantized-embedding path. +"""GGUF **Q6_K** embedding implementation. -``format`` selects the GGUF quant type; only ``"q6k"`` is supported. Output is -bfloat16. +Provides the two pieces the ``mlx::gguf_embedding`` router dispatches to for the +``"q6k"`` format: -Usage:: +* :func:`eager_embedding` -- pure-torch reference (the custom-op eager body). +* :func:`emit_embedding` -- lowers the op to a fused Q6_K gather Metal kernel. - import executorch.backends.mlx.custom_kernel_ops.gguf_embedding # noqa: F401 - - out = torch.ops.mlx.gguf_embedding(weight, indices, "q6k") - # weight: (vocab, (K/256)*210) uint8 GGUF q6_K blob - # indices: (...) int - # out: (..., K) bfloat16 +This is the gather counterpart to :mod:`.linear` and exists because MLX's affine +dequantize has no group_size=16 Metal kernel, so a Q6_K embedding (group_size 16) +cannot use the generic quantized-embedding path. Output is bfloat16. """ from __future__ import annotations import torch - -from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( _Q6K_HEADER, dequantize_q6_k, Q6K_BLOCK_BYTES, QK_K, ) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + IntOrVid, + MetalKernelNode, +) from torch import Tensor from torch.fx.node import Node # --------------------------------------------------------------------------- -# Custom op + eager fallback +# Eager reference # --------------------------------------------------------------------------- -@torch.library.custom_op("mlx::gguf_embedding", mutates_args=()) -def gguf_embedding(weight: Tensor, indices: Tensor, format: str) -> Tensor: - """Gather + dequantize rows of a GGUF-quantized embedding table. - - Args: - weight: ``(vocab, (K/256)*210)`` uint8 GGUF ``q6_K`` blob. - indices: integer token ids of any shape. - format: GGUF quant type; only ``"q6k"`` supported. - - Returns: - ``(*indices.shape, K)`` bfloat16. - """ - if format != "q6k": - raise NotImplementedError( - f"mlx::gguf_embedding: unsupported format {format!r}; only 'q6k' " - f"is supported" - ) +def eager_embedding(weight: Tensor, indices: Tensor) -> Tensor: + """Eager gather + dequantize of Q6_K embedding rows; returns bfloat16.""" if weight.dim() != 2: raise ValueError( f"mlx::gguf_embedding: weight must be 2-D (vocab, row_bytes); got " @@ -83,30 +68,10 @@ def gguf_embedding(weight: Tensor, indices: Tensor, format: str) -> Tensor: return deq.reshape(*indices.shape, K).to(torch.bfloat16) -@torch.library.register_fake("mlx::gguf_embedding") -def gguf_embedding_fake(weight: Tensor, indices: Tensor, format: str) -> Tensor: - row_bytes = weight.shape[1] - K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K - return indices.new_empty((*indices.shape, K), dtype=torch.bfloat16) - - # --------------------------------------------------------------------------- -# MLX handler +# Metal kernel source # --------------------------------------------------------------------------- -from executorch.backends.mlx.builder.op_helpers import ( - emit_product, - emit_shape, - torch_dtype_to_scalar_type, -) -from executorch.backends.mlx.builder.op_registry import REGISTRY -from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder -from executorch.backends.mlx.builder.slot_manager import Slot -from executorch.backends.mlx.serialization.mlx_graph_schema import ( - IntOrVid, - MetalKernelNode, -) - # One thread per output element. grid = (K, num_idx, 1): x picks the feature j, # y picks the gathered row; each thread dequantizes a single Q6_K element. @@ -121,25 +86,18 @@ def gguf_embedding_fake(weight: Tensor, indices: Tensor, format: str) -> Tensor: """ -@REGISTRY.register(target=[torch.ops.mlx.gguf_embedding.default]) -def _gguf_embedding_handler(P: MLXProgramBuilder, n: Node) -> Slot: - """Lower ``mlx::gguf_embedding`` to a fused Q6_K gather Metal kernel.""" +def emit_embedding(P: MLXProgramBuilder, n: Node) -> Slot: + """Lower ``mlx::gguf_embedding`` (q6k) to a fused Q6_K gather Metal kernel.""" args = P.args(n) if len(args) != 3: raise ValueError( f"mlx::gguf_embedding: expected 3 args (weight, indices, format); " f"got {len(args)}" ) - weight_slot, indices_slot, fmt = args + weight_slot, indices_slot, _fmt = args weight_node = n.args[0] indices_node = n.args[1] - if fmt != "q6k": - raise NotImplementedError( - f"mlx::gguf_embedding: unsupported format {fmt!r}; only 'q6k' " - f"is supported" - ) - weight_meta = weight_node.meta["val"] if weight_meta.dim() != 2: raise NotImplementedError( diff --git a/backends/mlx/custom_kernel_ops/gguf_linear.py b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py similarity index 68% rename from backends/mlx/custom_kernel_ops/gguf_linear.py rename to backends/mlx/custom_kernel_ops/gguf/q6k/linear.py index 484d12dbc24..e82707d8bfe 100644 --- a/backends/mlx/custom_kernel_ops/gguf_linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py @@ -6,26 +6,13 @@ # LICENSE file in the root directory of this source tree. # -""" -``mlx::gguf_linear``: linear layer against a GGUF-quantized weight. - - out = x @ dequant(weight)^T (+ bias) - -The weight is stored in the **exact GGUF packed block layout** (no repacking), -so weights converted by llama.cpp / gguf-py can be consumed directly. The -``format`` argument selects the GGUF quantization type; only ``"q6k"`` is -supported and anything else raises ``NotImplementedError``. +"""GGUF **Q6_K** linear implementation. -Q6_K layout (per 256-element super-block, 210 bytes, see llama.cpp -``block_q6_K`` in ``ggml-common.h``):: +Provides the two pieces the ``mlx::gguf_linear`` router dispatches to for the +``"q6k"`` format: - uint8 ql[128] # quants, lower 4 bits - uint8 qh[64] # quants, upper 2 bits - int8 scales[16] # per-16-element sub-block scales (8-bit) - half d # super-block scale - -The dequantized value for a 6-bit code ``q`` (0..63) in sub-block ``s`` is -``d * scales[s] * (q - 32)``. +* :func:`eager_linear` -- pure-torch reference (the custom-op eager body). +* :func:`emit_linear` -- lowers the op to fused Q6_K Metal kernels. Compute is keyed on the activation dtype (matching GGUF/llama.cpp): the Metal kernels are templated on ``InT``, accumulate in ``float32``, read ``d`` as @@ -42,16 +29,6 @@ both kernels are emitted into separate instruction chains and selected at runtime via an ``IfNode`` on ``M`` (``M > 1`` -> mat-mat, ``M == 1`` -> mat-vec). - -Usage:: - - import executorch.backends.mlx.custom_kernel_ops.gguf_linear # noqa: F401 - - out = torch.ops.mlx.gguf_linear(x, weight, "q6k", bias) - # x: (..., K) bf16 / fp16 / fp32 - # weight: (N, (K/256)*210) uint8 GGUF q6_K blob - # bias: (N,) or None activation dtype - # out: (..., N) activation dtype """ from __future__ import annotations @@ -59,124 +36,45 @@ from typing import Optional import torch +from executorch.backends.mlx.builder.op_helpers import ( + emit_product, + emit_shape, + torch_dtype_to_scalar_type, +) +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( + _Q6K_HEADER, + dequantize_q6_k, + Q6K_BLOCK_BYTES, + QK_K, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddIntNode, + FloorDivideIntNode, + IfNode, + IntOrVid, + MetalKernelNode, + MultiplyIntNode, + SubtractIntNode, +) from torch import Tensor from torch.fx.node import Node # --------------------------------------------------------------------------- -# Q6_K constants and pure-torch dequant reference +# Eager reference # --------------------------------------------------------------------------- -QK_K = 256 -# Per-super-block byte counts. -_Q6K_QL_BYTES = QK_K // 2 # 128 -_Q6K_QH_BYTES = QK_K // 4 # 64 -_Q6K_SCALES = QK_K // 16 # 16 -_Q6K_D_BYTES = 2 # one fp16 -Q6K_BLOCK_BYTES = _Q6K_QL_BYTES + _Q6K_QH_BYTES + _Q6K_SCALES + _Q6K_D_BYTES # 210 - - -def dequantize_q6_k(weight: Tensor, K: int) -> Tensor: - """Dequantize a GGUF Q6_K blob to float32. - - Args: - weight: ``(N, n_blocks * 210)`` uint8, GGUF ``block_q6_K`` layout. - K: number of logical input features (``n_blocks * 256``). - - Returns: - ``(N, K)`` float32 dequantized weight. - """ - if weight.dtype != torch.uint8: - raise ValueError(f"gguf_linear: weight must be uint8; got {weight.dtype}") - N = weight.shape[0] - nb = K // QK_K - if weight.shape[-1] != nb * Q6K_BLOCK_BYTES: - raise ValueError( - f"gguf_linear: weight row bytes {weight.shape[-1]} != " - f"{nb} blocks * {Q6K_BLOCK_BYTES}" - ) - - blocks = weight.view(N, nb, Q6K_BLOCK_BYTES) - ql = blocks[..., 0:_Q6K_QL_BYTES].to(torch.int32) - qh = blocks[..., _Q6K_QL_BYTES : _Q6K_QL_BYTES + _Q6K_QH_BYTES].to(torch.int32) - sc_off = _Q6K_QL_BYTES + _Q6K_QH_BYTES - scales = ( - blocks[..., sc_off : sc_off + _Q6K_SCALES] - .contiguous() - .view(torch.int8) - .to(torch.float32) - ) - d = ( - blocks[..., sc_off + _Q6K_SCALES : sc_off + _Q6K_SCALES + _Q6K_D_BYTES] - .contiguous() - .view(torch.float16) - .to(torch.float32) - ) # (N, nb, 1) - - y = torch.empty(N, nb, QK_K, dtype=torch.float32, device=weight.device) - # is = l // 16 over l in 0..31 -> selects which of the 8 half-scales. - is_idx = (torch.arange(32, device=weight.device) // 16).long() # (32,) - - for h in range(2): # two 128-element halves - ql_h = ql[..., h * 64 : h * 64 + 64] # (N, nb, 64) - qh_h = qh[..., h * 32 : h * 32 + 32] # (N, nb, 32) - sc_h = scales[..., h * 8 : h * 8 + 8] # (N, nb, 8) - - ql_lo = ql_h[..., 0:32] - ql_hi = ql_h[..., 32:64] - - q1 = (ql_lo & 0xF) | ((qh_h & 0x3) << 4) - q2 = (ql_hi & 0xF) | (((qh_h >> 2) & 0x3) << 4) - q3 = (ql_lo >> 4) | (((qh_h >> 4) & 0x3) << 4) - q4 = (ql_hi >> 4) | (((qh_h >> 6) & 0x3) << 4) - - sc0 = sc_h[..., is_idx + 0] - sc2 = sc_h[..., is_idx + 2] - sc4 = sc_h[..., is_idx + 4] - sc6 = sc_h[..., is_idx + 6] - - base = h * 128 - y[..., base + 0 : base + 32] = d * sc0 * (q1 - 32).to(torch.float32) - y[..., base + 32 : base + 64] = d * sc2 * (q2 - 32).to(torch.float32) - y[..., base + 64 : base + 96] = d * sc4 * (q3 - 32).to(torch.float32) - y[..., base + 96 : base + 128] = d * sc6 * (q4 - 32).to(torch.float32) - return y.reshape(N, K) - - -# --------------------------------------------------------------------------- -# Custom op + eager fallback -# --------------------------------------------------------------------------- - - -@torch.library.custom_op("mlx::gguf_linear", mutates_args=()) -def gguf_linear( - x: Tensor, - weight: Tensor, - format: str, - bias: Optional[Tensor] = None, -) -> Tensor: - """Linear against a GGUF-quantized weight. - - Args: - x: ``(..., K)`` activations (bf16 / fp16 / fp32). - weight: ``(N, (K/256)*210)`` uint8 GGUF ``q6_K`` blob. - format: GGUF quant type; only ``"q6k"`` supported. - bias: optional ``(N,)`` of activation dtype. - - Returns: - ``(..., N)`` of activation dtype. - """ - if format != "q6k": - raise NotImplementedError( - f"mlx::gguf_linear: unsupported format {format!r}; only 'q6k' is supported" - ) +def eager_linear(x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: + """Eager ``x @ dequant(weight)^T (+ bias)`` for a Q6_K weight blob.""" if weight.dim() != 2: raise ValueError( f"mlx::gguf_linear: weight must be 2-D (N, row_bytes); got " f"shape {tuple(weight.shape)}" ) - N, row_bytes = weight.shape + _N, row_bytes = weight.shape if row_bytes % Q6K_BLOCK_BYTES != 0: raise ValueError( f"mlx::gguf_linear: weight row bytes {row_bytes} must be a multiple of " @@ -195,118 +93,10 @@ def gguf_linear( return out.to(x.dtype) -@torch.library.register_fake("mlx::gguf_linear") -def gguf_linear_fake( - x: Tensor, - weight: Tensor, - format: str, - bias: Optional[Tensor] = None, -) -> Tensor: - N = weight.shape[0] - out_shape = list(x.shape) - out_shape[-1] = N - return x.new_empty(out_shape, dtype=x.dtype) - - # --------------------------------------------------------------------------- -# MLX handler +# Metal kernel sources # --------------------------------------------------------------------------- -from executorch.backends.mlx.builder.op_helpers import ( - emit_product, - emit_shape, - torch_dtype_to_scalar_type, -) -from executorch.backends.mlx.builder.op_registry import REGISTRY -from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder -from executorch.backends.mlx.builder.slot_manager import Slot -from executorch.backends.mlx.serialization.mlx_graph_schema import ( - AddIntNode, - FloorDivideIntNode, - IfNode, - IntOrVid, - MetalKernelNode, - MultiplyIntNode, - SubtractIntNode, -) - - -# Shared Metal header: the GGUF block_q6_K struct (matches llama.cpp -# ggml-common.h; sizeof == 210, no padding since max align is 2) plus -# dequant helpers for both per-element (embedding) and vectorized (matmul). -_Q6K_HEADER = """ -#include -#include -using namespace metal; - -#define QK_K 256 - -typedef struct { - uint8_t ql[QK_K/2]; // lower 4 bits - uint8_t qh[QK_K/4]; // upper 2 bits - int8_t scales[QK_K/16]; // per-16-element sub-block scales - half d; // super-block scale -} block_q6_K; - -// Dequantize a single element at within-block position p (0..255) of a -// block_q6_K. Used by the embedding kernel. -inline float dequant_q6k_elem(device const block_q6_K * blk, int p) { - const int h = p >> 7; // which 128-element half (0/1) - const int pp = p & 127; // position within half (0..127) - const int g = pp >> 5; // group: 0=q1, 1=q2, 2=q3, 3=q4 - const int l = pp & 31; // 0..31 - device const uint8_t * ql = blk->ql + h * 64; - device const uint8_t * qh = blk->qh + h * 32; - device const int8_t * sc = blk->scales + h * 8; - const int is = l >> 4; // 0/1 - const uint8_t qhb = qh[l]; - int q; - if (g == 0) { q = (ql[l] & 0xF) | ((qhb & 0x03) << 4); } - else if (g == 1) { q = (ql[l + 32] & 0xF) | ((qhb & 0x0C) << 2); } - else if (g == 2) { q = (ql[l] >> 4) | ((qhb & 0x30) << 0); } - else { q = (ql[l + 32] >> 4) | ((qhb & 0xC0) >> 2); } - const float scale = (float) sc[is + 2 * g]; - return (float) blk->d * scale * (float)(q - 32); -} - -// Vectorized Q6_K dequantize: decodes 16 values per call into a 4x4 half -// register. Ported from llama.cpp dequantize_q6_K. `il` ranges 0..15 and -// selects which 16-element slice of the 256-element block to decode. -inline void dequantize_q6_K_16(device const block_q6_K * xb, short il, - thread half4x4 & reg) { - const half d_all = xb->d; - device const uint16_t * ql = (device const uint16_t *)xb->ql; - device const uint16_t * qh = (device const uint16_t *)xb->qh; - device const int8_t * scales = (device const int8_t *)xb->scales; - - ql = ql + 32*(il/8) + 16*((il/2)&1) + 8*(il&1); - qh = qh + 16*(il/8) + 8*(il&1); - float sc = scales[(il%2) + 2 * ((il/2))]; - il = (il/2) & 3; - - const uint32_t kmask1 = il>1 ? (il>2 ? 0xC0C0C0C0 : 0x30303030) : (il>0 ? 0x0C0C0C0C : 0x03030303); - const uint32_t kmask2 = il>1 ? 0xF0F0F0F0 : 0x0F0F0F0F; - const float coeff = d_all * sc; - const float ml = coeff * 32.f; - const float dl0 = coeff; - const float dl1 = dl0 / 256.f; - const float dl2 = dl0 / (256.f * 256.f); - const float dl3 = dl0 / (256.f * 256.f * 256.f); - const uint8_t shr_h = il>2 ? 2 : 0; - const uint8_t shl_h = il>1 ? 0 : (il>0 ? 2 : 4); - const uint8_t shr_l = il>1 ? 4 : 0; - for (int i = 0; i < 4; ++i) { - const uint32_t low = (ql[2*i] | (uint32_t)(ql[2*i+1] << 16)) & kmask2; - const uint32_t high = (qh[2*i] | (uint32_t)(qh[2*i+1] << 16)) & kmask1; - const uint32_t q = ((high << shl_h) >> shr_h) | (low >> shr_l); - reg[i][0] = (half)(dl0 * ((half)(q & 0xFF)) - ml); - reg[i][1] = (half)(dl1 * ((float)(q & 0xFF00)) - ml); - reg[i][2] = (half)(dl2 * ((float)(q & 0xFF0000)) - ml); - reg[i][3] = (half)(dl3 * ((float)(q & 0xFF000000)) - ml); - } -} -""" - # Decode mat-vec kernel, ported from llama.cpp kernel_mul_mv_q6_K_f32_impl. # Threadgroup = (32 * NSG, 1, 1): NSG simdgroups, each computing N_R0 output @@ -661,14 +451,13 @@ def _emit_q6k_matmul( ) -@REGISTRY.register(target=[torch.ops.mlx.gguf_linear.default]) -def _gguf_linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: - """Lower ``mlx::gguf_linear`` to fused Q6_K Metal kernels.""" +def emit_linear(P: MLXProgramBuilder, n: Node) -> Slot: + """Lower ``mlx::gguf_linear`` (q6k) to fused Q6_K Metal kernels.""" args = P.args(n) if len(args) == 4: - x_slot, weight_slot, fmt, bias_slot = args + x_slot, weight_slot, _fmt, bias_slot = args elif len(args) == 3: - x_slot, weight_slot, fmt = args + x_slot, weight_slot, _fmt = args bias_slot = None else: raise ValueError( @@ -678,11 +467,6 @@ def _gguf_linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: x_node = n.args[0] weight_node = n.args[1] - if fmt != "q6k": - raise NotImplementedError( - f"mlx::gguf_linear: unsupported format {fmt!r}; only 'q6k' is supported" - ) - weight_meta = weight_node.meta["val"] if weight_meta.dim() != 2: raise NotImplementedError( diff --git a/backends/mlx/custom_kernel_ops/gguf/test/__init__.py b/backends/mlx/custom_kernel_ops/gguf/test/__init__.py new file mode 100644 index 00000000000..2e41cd717f6 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/test/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/backends/mlx/custom_kernel_ops/test/test_gguf_embedding.py b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py similarity index 93% rename from backends/mlx/custom_kernel_ops/test/test_gguf_embedding.py rename to backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py index 6548e9d2785..d595b6e8c11 100644 --- a/backends/mlx/custom_kernel_ops/test/test_gguf_embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py @@ -14,18 +14,18 @@ Usage:: - python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_embedding run - python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_embedding run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding run + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding run --rebuild """ from typing import List, Tuple -import executorch.backends.mlx.custom_kernel_ops.gguf_embedding # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.gguf.embedding # noqa: F401 import torch import torch.nn as nn -from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import ( +from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( make_q6_k_blob, ) from executorch.backends.mlx.test.test_utils import OpTestCase diff --git a/backends/mlx/custom_kernel_ops/test/test_gguf_linear.py b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py similarity index 95% rename from backends/mlx/custom_kernel_ops/test/test_gguf_linear.py rename to backends/mlx/custom_kernel_ops/gguf/test/test_linear.py index d4b5db2b197..09e968c9858 100644 --- a/backends/mlx/custom_kernel_ops/test/test_gguf_linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py @@ -19,20 +19,20 @@ Usage:: - python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear run - python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear run -v - python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear run --rebuild - python -m executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear eager + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run -v + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear eager """ from typing import List, Tuple -import executorch.backends.mlx.custom_kernel_ops.gguf_linear # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.gguf.linear # noqa: F401 import torch import torch.nn as nn -from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( +from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import ( dequantize_q6_k, Q6K_BLOCK_BYTES, QK_K, diff --git a/backends/mlx/runtime/MLXInterpreter.h b/backends/mlx/runtime/MLXInterpreter.h index 44cb3d8056a..8563ff339a7 100644 --- a/backends/mlx/runtime/MLXInterpreter.h +++ b/backends/mlx/runtime/MLXInterpreter.h @@ -990,8 +990,8 @@ inline void exec_metal_kernel( n.name, n.input_names, n.output_names, - n.source, - n.header, + n.source ? *n.source : std::string{}, + n.header ? *n.header : std::string{}, n.ensure_row_contiguous, n.atomic_outputs); diff --git a/backends/mlx/serialization/MLXLoader.cpp.tmpl b/backends/mlx/serialization/MLXLoader.cpp.tmpl index aa4716d7a4a..7017988d271 100644 --- a/backends/mlx/serialization/MLXLoader.cpp.tmpl +++ b/backends/mlx/serialization/MLXLoader.cpp.tmpl @@ -62,7 +62,8 @@ std::vector to_vector(const flatbuffers::Vector* fb_vec) { // load_instruction - AUTO-GENERATED switch statement // ============================================================================= -Instruction load_instruction(const mlx_delegate::Instruction* fb_instr) { +Instruction load_instruction( + const mlx_delegate::Instruction* fb_instr, StringPool& strpool) { Instruction instr; if (!fb_instr || !fb_instr->op()) { @@ -142,6 +143,10 @@ MLXProgram load_program(const void* data, size_t size) { check_collection_size(program.num_tensors(), "num_tensors()"); check_collection_size(program.num_values, "num_values"); + // Pool shared across all chains so identical kernel source/header blobs are + // interned once for the whole program. + StringPool strpool; + if (fb_graph->instruction_chains()) { check_collection_size(fb_graph->instruction_chains()->size(), "instruction_chains"); program.instruction_chains.reserve(fb_graph->instruction_chains()->size()); @@ -152,7 +157,7 @@ MLXProgram load_program(const void* data, size_t size) { check_collection_size(fb_chain->instructions()->size(), "instructions in chain"); chain.reserve(fb_chain->instructions()->size()); for (size_t i = 0; i < fb_chain->instructions()->size(); ++i) { - chain.push_back(load_instruction(fb_chain->instructions()->Get(static_cast(i)))); + chain.push_back(load_instruction(fb_chain->instructions()->Get(static_cast(i)), strpool)); } } program.instruction_chains.push_back(std::move(chain)); diff --git a/backends/mlx/serialization/MLXLoader.h.tmpl b/backends/mlx/serialization/MLXLoader.h.tmpl index 0930d5e00e1..8bee2c23bc8 100644 --- a/backends/mlx/serialization/MLXLoader.h.tmpl +++ b/backends/mlx/serialization/MLXLoader.h.tmpl @@ -4,9 +4,11 @@ #include #include +#include #include #include #include +#include #include #include @@ -330,8 +332,27 @@ inline SlotVariant convert_slot_variant(const mlx_delegate::SlotVariant* fb) { return SlotVariant{fb->idx(), convert_slot_type(fb->slot_type())}; } +// Interns FlatBuffer strings by pointer so identical kernel source/header +// blobs (deduplicated to a single offset by the serializer) share one +// std::string in memory. Buffers written without string sharing simply get +// one entry per node — correct, just not deduplicated. +struct StringPool { + std::unordered_map> map; + std::shared_ptr intern(const flatbuffers::String* s) { + if (!s) { + return nullptr; + } + auto& slot = map[static_cast(s)]; + if (!slot) { + slot = std::make_shared(s->str()); + } + return slot; + } +}; + // Load an instruction from FlatBuffer -Instruction load_instruction(const mlx_delegate::Instruction* fb_instr); +Instruction load_instruction( + const mlx_delegate::Instruction* fb_instr, StringPool& strpool); // Load the full MLXProgram from FlatBuffer data MLXProgram load_program(const void* data, size_t size); diff --git a/backends/mlx/serialization/generate.py b/backends/mlx/serialization/generate.py index db3d4cd2d49..fd0b5b672b0 100755 --- a/backends/mlx/serialization/generate.py +++ b/backends/mlx/serialization/generate.py @@ -627,6 +627,16 @@ def generate_python_serializers(schema: FBSSchema) -> str: " return builder.EndVector()", "", "", + "def _shared_string(builder: flatbuffers.Builder, s):", + ' """CreateString with per-buffer dedup so identical strings share one offset."""', + " if s is None:", + " return None", + " # flatbuffers' Builder dedups identical strings via its built-in", + " # sharedStrings cache; fall back to CreateString on old flatbuffers.", + ' create = getattr(builder, "CreateSharedString", None) or builder.CreateString', + " return create(s)", + "", + "", "class GeneratedOpBuilders:", ' """Mixin class with auto-generated op builder methods."""', "", @@ -714,7 +724,7 @@ def generate_python_serializers(schema: FBSSchema) -> str: " self, builder: flatbuffers.Builder, vec: List[str]", " ) -> int:", ' """Pre-build a vector of strings (offsets must be created before table Start)."""', - " offsets = [builder.CreateString(s) for s in vec]", + " offsets = [_shared_string(builder, s) for s in vec]", " builder.StartVector(4, len(offsets), 4)", " for off in reversed(offsets):", " builder.PrependUOffsetTRelative(off)", @@ -800,12 +810,12 @@ def _generate_op_builder_method(table: FBSTable) -> str: } _PY_PREBUILD_OFFSET = { - "str": "builder.CreateString(op.{name})", + "str": "_shared_string(builder, op.{name})", "int_or_vid": "self._build_int_or_vid(builder, op.{name})", "float_or_vid": "self._build_float_or_vid(builder, op.{name})", "vid_or_tid": "self._build_vid_or_tid(builder, op.{name})", "int_or_vid_or_tid": "self._build_int_or_vid_or_tid(builder, op.{name})", - "optional_str": "builder.CreateString(op.{name}) if op.{name} is not None else None", + "optional_str": "_shared_string(builder, op.{name})", } @@ -996,6 +1006,19 @@ def generate_cpp_loader_h(schema: FBSSchema) -> str: return header + result +def _is_interned_str(table, field_name) -> bool: + """Whether a string field should be loaded as an interned shared_ptr. + + Only large, frequently-duplicated kernel blobs (MetalKernelNode source/ + header) are interned so identical text shares one std::string at runtime. + """ + return ( + table is not None + and getattr(table, "name", None) == "MetalKernelNode" + and field_name in ("source", "header") + ) + + def _fbs_type_to_cpp( fbs_type: str, required: bool, @@ -1023,6 +1046,10 @@ def _fbs_type_to_cpp( cpp_type = FBS_TO_CPP.get(fbs_type, fbs_type) + # Interned strings (deduped + shared at load time) use a shared_ptr handle. + if _is_interned_str(table, fld.name if fld is not None else None): + return "std::shared_ptr" + # Handle optional types if not required: if fbs_type == "Tid": @@ -1113,7 +1140,7 @@ def _generate_loader_case(table: FBSTable) -> List[str]: fb_field_name = fld.name kind = _get_field_kind(fld, table) - load_lines = _emit_cpp_load(kind, fld.name, fb_field_name) + load_lines = _emit_cpp_load(kind, fld.name, fb_field_name, table) if load_lines is None: raise ValueError( f"Unhandled field kind '{kind}' for field '{fld.name}' in table '{table.name}'. " @@ -1145,8 +1172,13 @@ def _generate_loader_case(table: FBSTable) -> List[str]: } -def _emit_cpp_load(kind: str, name: str, fb_name: str) -> "List[str] | None": +def _emit_cpp_load( + kind: str, name: str, fb_name: str, table=None +) -> "List[str] | None": """Emit C++ load lines for a field kind, or None if kind is unrecognized.""" + # Interned string fields share one std::string via the load-time pool. + if _is_interned_str(table, name) and kind in ("str", "optional_str"): + return [f" node.{name} = strpool.intern(fb->{fb_name}());"] # Required struct / compound via converter if kind in _CPP_CONVERTER: conv = _CPP_CONVERTER[kind] diff --git a/backends/mlx/serialization/mlx_graph_serialize.py b/backends/mlx/serialization/mlx_graph_serialize.py index db5acc9048f..26c562dd7e8 100644 --- a/backends/mlx/serialization/mlx_graph_serialize.py +++ b/backends/mlx/serialization/mlx_graph_serialize.py @@ -31,6 +31,7 @@ # Import auto-generated serializers from executorch.backends.mlx.serialization._generated_serializers import ( + _shared_string, GeneratedOpBuilders, ) from executorch.backends.mlx.serialization.mlx_graph_schema import ( # noqa: F401 @@ -85,7 +86,7 @@ def _build_int_or_vid(builder: flatbuffers.Builder, iov: IntOrVid) -> int: def _build_string(builder: flatbuffers.Builder, s: str) -> int: - return builder.CreateString(s) + return _shared_string(builder, s) def _build_int_vector(builder: flatbuffers.Builder, vec: List[int]) -> int: @@ -188,7 +189,7 @@ def _build_flatbuffer(self) -> bytes: tensor_meta_vec = self._build_offset_vector(builder, tensor_meta_offsets) # 5. Build version string (must be created before the table that uses it) - version_off = builder.CreateString(self.graph.version) + version_off = _shared_string(builder, self.graph.version) # 6. Build the root MLXGraph table from executorch.backends.mlx.serialization._generated.mlx_delegate import ( @@ -280,7 +281,7 @@ def _build_slot_variant( return FBSlotVariantModule.End(builder) def _build_named_slot(self, builder: flatbuffers.Builder, ns: NamedSlot) -> int: - name_off = builder.CreateString(ns.name) + name_off = _shared_string(builder, ns.name) slot_off = self._build_slot_variant(builder, ns.slot) from executorch.backends.mlx.serialization._generated.mlx_delegate import ( diff --git a/backends/mlx/test/test_serialization_dedup.py b/backends/mlx/test/test_serialization_dedup.py new file mode 100644 index 00000000000..e28e4613384 --- /dev/null +++ b/backends/mlx/test/test_serialization_dedup.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Serializer string-dedup regression test. + +MetalKernelNode ``source``/``header`` blobs are large and repeated once per +layer. The serializer routes every string through ``_shared_string`` so +identical text is written into the FlatBuffer exactly once (multiple fields +share a single offset). The loader then interns those shared offsets into one +``std::shared_ptr`` per unique blob, so this dedup also +shrinks runtime memory for newly-produced ``.pte`` files. + +This test pins the serializer half of that behavior. +""" + +import unittest + +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + Instruction, + InstructionChain, + IntOrVid, + MetalKernelNode, + MLXGraph, + Tid, +) +from executorch.backends.mlx.serialization.mlx_graph_serialize import ( + serialize_mlx_graph, +) + + +def _graph(nodes): + chain = InstructionChain(instructions=[Instruction(op=n) for n in nodes]) + return MLXGraph( + instruction_chains=[chain], + version="test", + input_map=[], + output_map=[], + mutable_buffer_map=[], + named_slots=[], + tensor_meta=[], + ) + + +def _kernel(source, header=None): + return MetalKernelNode( + name="gguf_q6k_matmul", + source=source, + inputs=[Tid(0)], + outputs=[Tid(1)], + grid=[IntOrVid(literal=1)], + threadgroup=[IntOrVid(literal=1)], + header=header, + input_names=["x"], + output_names=["out"], + ) + + +class TestSerializationStringDedup(unittest.TestCase): + def test_identical_source_header_written_once(self): + source = "KERNEL_SOURCE_MARKER_" + "x" * 2000 + header = "KERNEL_HEADER_MARKER_" + "y" * 2000 + + nodes = [_kernel(source, header) for _ in range(5)] + buf = serialize_mlx_graph(_graph(nodes)) + + self.assertEqual(buf.count(source.encode()), 1) + self.assertEqual(buf.count(header.encode()), 1) + + def test_distinct_sources_not_merged(self): + base = "KERNEL_SOURCE_MARKER_" + "x" * 2000 + nodes = [_kernel(base + str(i)) for i in range(3)] + buf = serialize_mlx_graph(_graph(nodes)) + + # Each distinct source must still appear (the common prefix appears once + # per distinct string since the suffixes differ). + self.assertEqual(buf.count(base.encode()), 3) + + +if __name__ == "__main__": + unittest.main() diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 2954a7c7644..9b1466af63f 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -205,7 +205,7 @@ def load_gguf_model( continue # Fallback: unpack Q6_K to a quantized tensor. - from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import ( Q6K_BLOCK_BYTES, QK_K, ) diff --git a/examples/models/gemma4_31b/mlx_gguf_linear.py b/examples/models/gemma4_31b/mlx_gguf_linear.py index b8cb18ad47e..a18e6cddf60 100644 --- a/examples/models/gemma4_31b/mlx_gguf_linear.py +++ b/examples/models/gemma4_31b/mlx_gguf_linear.py @@ -15,12 +15,12 @@ from __future__ import annotations # Importing the op modules registers the custom ops. -import executorch.backends.mlx.custom_kernel_ops.gguf_embedding # noqa: F401 -import executorch.backends.mlx.custom_kernel_ops.gguf_linear # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.gguf.embedding # noqa: F401 +import executorch.backends.mlx.custom_kernel_ops.gguf.linear # noqa: F401 import torch import torch.nn as nn -from executorch.backends.mlx.custom_kernel_ops.gguf_linear import Q6K_BLOCK_BYTES, QK_K +from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import Q6K_BLOCK_BYTES, QK_K class GGUFLinear(nn.Module): diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py index f42f2e7aeb6..b2324c9eadb 100644 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -302,9 +302,7 @@ def test_raw_blob_preserves_bytes(self): ) def test_raw_blob_dequant_matches_gguf_reference(self): - from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( - dequantize_q6_k, - ) + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import dequantize_q6_k from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k d, scales_16, qvals, block = self._block() @@ -322,9 +320,7 @@ def test_raw_blob_dequant_matches_gguf_lib(self): # dequantize Q6_K even though it cannot quantize to it). import gguf - from executorch.backends.mlx.custom_kernel_ops.gguf_linear import ( - dequantize_q6_k, - ) + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import dequantize_q6_k from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k _, _, _, block = self._block() diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index 3f9d3cc96e9..604e48f4923 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -327,7 +327,7 @@ class TestGgufLinearMlx(unittest.TestCase): """Q6_K weights route to the fused mlx::gguf_linear op (raw-blob path).""" def _make_blob(self, N: int, K: int) -> torch.Tensor: - from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import ( + from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( make_q6_k_blob, ) @@ -398,7 +398,7 @@ class TestGgufEmbeddingMlx(unittest.TestCase): """Q6_K token embedding routes to the fused mlx::gguf_embedding op.""" def _make_blob(self, vocab: int, K: int) -> torch.Tensor: - from executorch.backends.mlx.custom_kernel_ops.test.test_gguf_linear import ( + from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( make_q6_k_blob, ) From 9ec2e0e94410f16333c31637f76579ddd1257d50 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 5 Jun 2026 12:34:17 -0700 Subject: [PATCH 09/18] up --- backends/mlx/custom_kernel_ops/gguf/q6k/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py index e82707d8bfe..a8bfed41f5b 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py @@ -293,7 +293,7 @@ def _q6k_matmul_source(has_bias: bool) -> str: // reused as float staging for the output. Without this barrier, a fast // simdgroup could start writing mc[] into sa while a slower one is still // reading the last weight tile via simdgroup_load(ma[]). - // (Matches ggml-metal.metal:9546 in llama.cpp's bounds-checked write path.) + // (Mirrors the barrier in llama.cpp kernel_mul_mm's bounds-checked write path.) threadgroup_barrier(mem_flags::mem_threadgroup); {{ threadgroup float * temp_str = ((threadgroup float *) sa) From 5d197f5b1f3eb870cd318bf482bee3d7a49e678c Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 5 Jun 2026 13:01:53 -0700 Subject: [PATCH 10/18] up --- backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py index d595b6e8c11..f02f3692ffc 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py @@ -63,7 +63,10 @@ def get_test_configs(cls) -> List["GGUFEmbeddingTest"]: cls(vocab=512, K=1024, idx_shape=(4,)), cls(vocab=300, K=256, idx_shape=(16,)), # vocab not tile-aligned cls(vocab=512, K=256, idx_shape=(2, 3)), # multi-dim indices - cls(vocab=262144, K=5376, idx_shape=(8,)), # real Gemma-4-31B embed + # Real Gemma-4-31B embed width (K=5376, 21 Q6_K blocks/row). Vocab is + # kept small so the packed weight fits CI-runner GPU buffer limits; the + # gather + per-row dequant path is identical regardless of vocab. + cls(vocab=2048, K=5376, idx_shape=(8,)), ] def create_model(self) -> nn.Module: From 1ad2e0f57d7e3921a079f76435256ffff1eed221 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Fri, 5 Jun 2026 16:05:46 -0700 Subject: [PATCH 11/18] up --- backends/mlx/custom_kernel_ops/gguf/q6k/common.py | 7 +++++++ backends/mlx/custom_kernel_ops/gguf/q6k/linear.py | 8 ++++++++ .../mlx/custom_kernel_ops/gguf/test/test_embedding.py | 2 -- backends/mlx/custom_kernel_ops/gguf/test/test_linear.py | 8 ++++---- examples/models/gemma4_31b/mlx_gguf_linear.py | 1 - examples/models/gemma4_31b/quant/tests/test_gguf.py | 1 - examples/models/gemma4_31b/tests/test_mlx_pipeline.py | 3 --- 7 files changed, 19 insertions(+), 11 deletions(-) diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/common.py b/backends/mlx/custom_kernel_ops/gguf/q6k/common.py index 0b870d73a4a..405cc8aba64 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/common.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/common.py @@ -30,6 +30,13 @@ The dequantized value for a 6-bit code ``q`` (0..63) in sub-block ``s`` is ``d * scales[s] * (q - 32)``. + +Attribution +----------- +The Q6_K block layout, the Metal dequant helpers in ``_Q6K_HEADER``, and the +pure-torch ``dequantize_q6_k`` reference follow llama.cpp +(``ggml-common.h`` / ``ggml-metal.metal``: ``block_q6_K``, ``dequantize_q6_K``), +which is MIT-licensed (Copyright (c) 2023-2024 The ggml authors). """ from __future__ import annotations diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py index a8bfed41f5b..581612588f1 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py @@ -29,6 +29,14 @@ both kernels are emitted into separate instruction chains and selected at runtime via an ``IfNode`` on ``M`` (``M > 1`` -> mat-mat, ``M == 1`` -> mat-vec). + +Attribution +----------- +The Q6_K Metal kernels and dequant routines here are ported from llama.cpp +(``ggml/src/ggml-metal/ggml-metal.metal`` -- ``kernel_mul_mv_q6_K_f32_impl``, +``kernel_mul_mm``, ``dequantize_q6_K``), which is MIT-licensed +(Copyright (c) 2023-2024 The ggml authors). Inline ``ported from ...`` notes +point at the specific upstream function for each kernel. """ from __future__ import annotations diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py index f02f3692ffc..1c5c7695bf2 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py @@ -21,10 +21,8 @@ from typing import List, Tuple import executorch.backends.mlx.custom_kernel_ops.gguf.embedding # noqa: F401 - import torch import torch.nn as nn - from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( make_q6_k_blob, ) diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py index 09e968c9858..f91f133547c 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py @@ -28,16 +28,13 @@ from typing import List, Tuple import executorch.backends.mlx.custom_kernel_ops.gguf.linear # noqa: F401 - import torch import torch.nn as nn - from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import ( dequantize_q6_k, Q6K_BLOCK_BYTES, QK_K, ) - from executorch.backends.mlx.test.test_utils import OpTestCase @@ -138,7 +135,10 @@ def get_test_configs(cls) -> List["GGUFLinearTest"]: cfgs.append(cls(M=1, N=4096, K=5376, dtype=torch.bfloat16)) # attn_v cfgs.append(cls(M=1, N=5376, K=21504, dtype=torch.bfloat16)) # ffn_down cfgs.append(cls(M=8, N=5376, K=21504, dtype=torch.bfloat16)) # ffn_down prefill - cfgs.append(cls(M=1, N=262144, K=5376, dtype=torch.bfloat16)) # lm_head + # lm_head: real vocab is 262144, but N is capped so the packed weight + # fits CI-runner GPU buffer limits; the mat-vec N-tiling path is the + # same at any N. + cfgs.append(cls(M=1, N=16384, K=5376, dtype=torch.bfloat16)) # lm_head return cfgs def create_model(self) -> nn.Module: diff --git a/examples/models/gemma4_31b/mlx_gguf_linear.py b/examples/models/gemma4_31b/mlx_gguf_linear.py index a18e6cddf60..23146eaf3d3 100644 --- a/examples/models/gemma4_31b/mlx_gguf_linear.py +++ b/examples/models/gemma4_31b/mlx_gguf_linear.py @@ -17,7 +17,6 @@ # Importing the op modules registers the custom ops. import executorch.backends.mlx.custom_kernel_ops.gguf.embedding # noqa: F401 import executorch.backends.mlx.custom_kernel_ops.gguf.linear # noqa: F401 - import torch import torch.nn as nn from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import Q6K_BLOCK_BYTES, QK_K diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py index b2324c9eadb..136fcb9ec38 100644 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ b/examples/models/gemma4_31b/quant/tests/test_gguf.py @@ -319,7 +319,6 @@ def test_raw_blob_dequant_matches_gguf_lib(self): # Cross-check our dequant against gguf's own Q6_K dequantizer (gguf can # dequantize Q6_K even though it cannot quantize to it). import gguf - from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import dequantize_q6_k from executorch.examples.models.gemma4_31b.quant.gguf import _raw_q6_k diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index 604e48f4923..7bec12a76ce 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -20,7 +20,6 @@ import torch import torch.nn as nn - from executorch.examples.models.gemma4_31b.model import Gemma4_31B from executorch.examples.models.gemma4_31b.quant import ( DEFAULT_MLX_PACKERS, @@ -368,7 +367,6 @@ def test_replace_with_gguf_linear_swaps_module(self): def test_gguf_linear_delegates_to_mlx(self): from executorch.backends.mlx import MLXPartitioner - from executorch.examples.models.gemma4_31b.mlx_gguf_linear import GGUFLinear from executorch.exir import to_edge_transform_and_lower from torch.export import Dim, export @@ -431,7 +429,6 @@ def test_replace_with_gguf_embedding_swaps_module(self): def test_gguf_embedding_delegates_to_mlx(self): from executorch.backends.mlx import MLXPartitioner - from executorch.examples.models.gemma4_31b.mlx_gguf_linear import GGUFEmbedding from executorch.exir import to_edge_transform_and_lower from torch.export import Dim, export From 23058487cef8ae3bc47f6dd8c04424dd9620e146 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Sat, 6 Jun 2026 12:57:19 -0700 Subject: [PATCH 12/18] up --- extension/llm/export/gguf.py | 413 +++++++++++++++++++++++++ extension/llm/export/test/test_gguf.py | 199 ++++++++++++ 2 files changed, 612 insertions(+) create mode 100644 extension/llm/export/gguf.py create mode 100644 extension/llm/export/test/test_gguf.py diff --git a/extension/llm/export/gguf.py b/extension/llm/export/gguf.py new file mode 100644 index 00000000000..2d922472641 --- /dev/null +++ b/extension/llm/export/gguf.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Export-time GGUF quantized weights. + +``ExportableGGUFTensor`` is the canonical **loading representation** for a +GGUF-quantized weight: it wraps the *raw* GGUF block bytes for one tensor and +defers all unpacking. The intended flow: + +1. **Load**: ``load_gguf(path)`` -> ``dict[name -> ExportableGGUFTensor | Tensor]`` + (quantized tensors become ``ExportableGGUFTensor``; F32/F16 become plain + tensors). No unpacking happens at load. +2. **Lower (dequantize)**: used as a weight, the subclass dequantizes via the + ``torchao::gguf_dequantize`` custom op (gguf-package eager body) and runs the + plain torch ``linear`` / ``embedding`` (NVFP4-style). A backend can + pattern-match ``gguf_dequantize`` -> linear/embedding to fuse. +3. **Convert**: ``.to_int4_tensor()`` / ``.to_intx_unpacked_to_int8_tensor()`` + unpack into torchao tensor subclasses (``Int4Tensor`` for Q4_K, + ``IntxUnpackedToInt8Tensor`` for Q4_K or Q6_K) to take the non-fused + (affine-dequant) path instead. + +The GGUF quant type is identified by a **string** (``"q4_k"``, ``"q6_k"``) +everywhere user-facing (subclass attribute + ``gguf_dequantize`` op argument); the +``gguf`` package's integer ``GGMLQuantizationType`` ids are an internal lookup +detail. + +Backend-agnostic; depends on ``torch``, ``torchao``, ``numpy``, and the ``gguf`` +package. The *policy* of which tensors to convert is left to the caller. + +Attribution: the Q4_K / Q6_K block layouts follow llama.cpp / gguf-py +(``ggml-common.h``), MIT-licensed (Copyright (c) 2023-2024 The ggml authors). +""" + +from __future__ import annotations + +from typing import Dict, Iterator, Optional, Tuple, Union + +import numpy as np +import torch +from torch import Tensor +from torchao.utils import TorchAOBaseTensor + +aten = torch.ops.aten + +# --------------------------------------------------------------------------- +# GGUF k-quant constants +# --------------------------------------------------------------------------- + +QK_K = 256 # super-block size for k-quants + +Q4_K_GROUP_SIZE = QK_K // 8 # 32 (8 sub-blocks per super-block) +Q6_K_GROUP_SIZE = QK_K // 16 # 16 (16 sub-blocks per super-block) + +_Q4_K_BLOCK_BYTES = 2 + 2 + 12 + QK_K // 2 # 144 +_Q6_K_BLOCK_BYTES = 2 + QK_K // 2 + QK_K // 4 + QK_K // 16 # 210 + +# ``gguf.GGMLQuantizationType`` integer ids. +GGML_F32 = 0 +GGML_F16 = 1 +GGML_Q4_K = 12 +GGML_Q6_K = 14 + +# String quant-type names are the user-facing identifier (op arg + subclass attr). +# These dicts map names to the internal ids / block sizes. +_GGML_ID_BY_TYPE = {"q4_k": GGML_Q4_K, "q6_k": GGML_Q6_K} +_TYPE_BY_GGML_ID = {v: k for k, v in _GGML_ID_BY_TYPE.items()} +_BLOCK_BYTES_BY_TYPE = {"q4_k": _Q4_K_BLOCK_BYTES, "q6_k": _Q6_K_BLOCK_BYTES} + + +def _read_f16(raw: Tensor, col_start: int, col_end: int) -> Tensor: + """Read an fp16 field from per-block bytes, return float32.""" + return raw[:, col_start:col_end].contiguous().view(torch.float16).float() + + +def _gguf_dequantize(raw: Tensor, ggml_type: str, output_dtype: torch.dtype) -> Tensor: + """Dequantize a raw GGUF block blob to a float tensor via the ``gguf`` package. + + ``raw`` is ``(N, row_bytes)`` uint8; the result is ``(N, K)`` in + ``output_dtype``. + """ + import gguf + + if ggml_type not in _GGML_ID_BY_TYPE: + raise NotImplementedError(f"unsupported GGUF quant type {ggml_type!r}") + qtype = gguf.GGMLQuantizationType(_GGML_ID_BY_TYPE[ggml_type]) + np_raw = raw.detach().cpu().contiguous().numpy() + deq = gguf.dequantize(np_raw, qtype) + return torch.from_numpy(np.ascontiguousarray(deq)).to( + device=raw.device, dtype=output_dtype + ) + + +# --------------------------------------------------------------------------- +# Fused ops (eager = gguf.dequantize + torch op; a backend may lower to kernels) +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("torchao::gguf_dequantize", mutates_args=()) +def gguf_dequantize( + weight: Tensor, + ggml_type: str, + output_dtype: torch.dtype = torch.bfloat16, +) -> Tensor: + """Dequantize a raw GGUF block blob (``(N, row_bytes)`` uint8) to ``(N, K)``.""" + return _gguf_dequantize(weight, ggml_type, output_dtype) + + +@gguf_dequantize.register_fake +def _(weight, ggml_type, output_dtype=torch.bfloat16): + K = (weight.shape[1] // _BLOCK_BYTES_BY_TYPE[ggml_type]) * QK_K + return torch.empty((weight.shape[0], K), dtype=output_dtype, device=weight.device) + + +# --------------------------------------------------------------------------- +# Per-type field extraction (used by the to_*_tensor conversions) +# --------------------------------------------------------------------------- + + +def _q4_k_fields(raw: Tensor, N: int, K: int) -> Tuple[Tensor, Tensor, Tensor]: + """Decode Q4_K blocks for conversion to ``Int4Tensor``. + + Returns ``(q, eff_scale, eff_min)`` where ``q`` is ``(N, K)`` uint8 in + [0, 15], and ``eff_scale`` / ``eff_min`` are ``(N, K // 32)`` float32. + """ + n_blocks = N * (K // QK_K) + blk = raw.reshape(n_blocks, _Q4_K_BLOCK_BYTES) + + d = _read_f16(blk, 0, 2) + dmin = _read_f16(blk, 2, 4) + s = blk[:, 4:16] + qs = blk[:, 16:144] + + sc = torch.empty(n_blocks, 8, dtype=torch.float32) + mn = torch.empty(n_blocks, 8, dtype=torch.float32) + sc[:, :4] = (s[:, :4] & 0x3F).float() + mn[:, :4] = (s[:, 4:8] & 0x3F).float() + sc[:, 4:] = ((s[:, 8:12] & 0xF) | ((s[:, :4] >> 6) << 4)).float() + mn[:, 4:] = ((s[:, 8:12] >> 4) | ((s[:, 4:8] >> 6) << 4)).float() + + eff_scale = (d * sc).reshape(N, -1) + eff_min = (dmin * mn).reshape(N, -1) + + # GGUF Q4_K nibble order: 32 lows then 32 highs per sub-block pair. + low = (qs & 0x0F).to(torch.uint8) + high = ((qs >> 4) & 0x0F).to(torch.uint8) + q = torch.cat( + [ + low[:, :32], + high[:, :32], + low[:, 32:64], + high[:, 32:64], + low[:, 64:96], + high[:, 64:96], + low[:, 96:128], + high[:, 96:128], + ], + dim=-1, + ).reshape(N, K) + return q, eff_scale, eff_min + + +def _q6_k_fields(raw: Tensor, N: int, K: int) -> Tuple[Tensor, Tensor]: + """Decode Q6_K blocks for conversion to ``IntxUnpackedToInt8Tensor``. + + Returns ``(q, eff_scale)`` where ``q`` is ``(N, K)`` int8 in [-32, 31] and + ``eff_scale`` is ``(N, K // 16)`` float32. + """ + n_blocks = N * (K // QK_K) + blk = raw.reshape(n_blocks, _Q6_K_BLOCK_BYTES) + + ql = blk[:, 0:128] + qh = blk[:, 128:192] + sc = blk[:, 192:208] + d = _read_f16(blk, 208, 210) + + qh0 = qh[:, :32] + qh1 = qh[:, 32:64] + q = torch.empty(n_blocks, QK_K, dtype=torch.int16) + q[:, 0:32] = (ql[:, :32] & 0x0F) | ((qh0 & 0x03) << 4) + q[:, 32:64] = (ql[:, 32:64] & 0x0F) | (((qh0 >> 2) & 0x03) << 4) + q[:, 64:96] = ((ql[:, :32] >> 4) & 0x0F) | (((qh0 >> 4) & 0x03) << 4) + q[:, 96:128] = ((ql[:, 32:64] >> 4) & 0x0F) | (((qh0 >> 6) & 0x03) << 4) + q[:, 128:160] = (ql[:, 64:96] & 0x0F) | ((qh1 & 0x03) << 4) + q[:, 160:192] = (ql[:, 96:128] & 0x0F) | (((qh1 >> 2) & 0x03) << 4) + q[:, 192:224] = ((ql[:, 64:96] >> 4) & 0x0F) | (((qh1 >> 4) & 0x03) << 4) + q[:, 224:256] = ((ql[:, 96:128] >> 4) & 0x0F) | (((qh1 >> 6) & 0x03) << 4) + q -= 32 + + # ``sc`` bytes are signed int8 sub-block scales. + eff_scale = (d * sc.to(torch.int8).float()).reshape(N, -1) + return q.reshape(N, K).to(torch.int8), eff_scale + + +# --------------------------------------------------------------------------- +# Tensor subclass +# --------------------------------------------------------------------------- + + +class ExportableGGUFTensor(TorchAOBaseTensor): + """Wraps the raw GGUF block bytes for one quantized weight. + + Stores the exact GGUF ``block_q*_K`` byte layout (no repacking) plus the + quant type string (``"q4_k"`` / ``"q6_k"``). ``aten.linear`` / ``aten.embedding`` + dequantize via the ``torchao::gguf_dequantize`` op (then a plain + linear/embedding); :meth:`to_int4_tensor` / :meth:`to_intx_unpacked_to_int8_tensor` + convert to torchao subclasses instead. + """ + + tensor_data_names = ["raw"] + tensor_attribute_names = ["ggml_type", "orig_dtype"] + + def __new__(cls, raw: Tensor, ggml_type: str, orig_dtype: torch.dtype): + if raw.dim() != 2 or raw.dtype != torch.uint8: + raise ValueError( + f"ExportableGGUFTensor: raw must be 2-D uint8 (N, row_bytes); got " + f"shape {tuple(raw.shape)} dtype {raw.dtype}" + ) + if ggml_type not in _BLOCK_BYTES_BY_TYPE: + raise NotImplementedError( + f"ExportableGGUFTensor: unsupported quant type {ggml_type!r}; " + f"supported: {sorted(_BLOCK_BYTES_BY_TYPE)}" + ) + n, row_bytes = int(raw.shape[0]), int(raw.shape[1]) + block_bytes = _BLOCK_BYTES_BY_TYPE[ggml_type] + if row_bytes % block_bytes != 0: + raise ValueError( + f"ExportableGGUFTensor: row bytes {row_bytes} not a multiple of " + f"block bytes {block_bytes} for quant type {ggml_type!r}" + ) + K = (row_bytes // block_bytes) * QK_K + self = torch.Tensor._make_wrapper_subclass( + cls, (n, K), dtype=orig_dtype, device=raw.device, requires_grad=False + ) + self.raw = raw + self.ggml_type = ggml_type + self.orig_dtype = orig_dtype + return self + + # -- construction -------------------------------------------------------- + + @classmethod + def from_raw( + cls, + raw: Tensor, + ggml_type: str, + orig_dtype: torch.dtype = torch.bfloat16, + ) -> "ExportableGGUFTensor": + """Build from a ``(N, row_bytes)`` uint8 GGUF block blob.""" + return cls(raw.contiguous(), ggml_type, orig_dtype) + + # -- dequant (via gguf package) ------------------------------------------ + + def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> Tensor: + """Dequantize to a plain float tensor using the ``gguf`` package.""" + return torch.ops.torchao.gguf_dequantize( + self.raw, self.ggml_type, output_dtype or self.orig_dtype + ) + + # -- conversions (unpack lives here) ------------------------------------- + + def to_int4_tensor(self) -> Tensor: + """Convert a Q4_K tensor to a torchao ``Int4Tensor``.""" + from torchao.quantization.quantize_.workflows.int4.int4_tensor import ( + Int4Tensor, + ) + + if self.ggml_type != "q4_k": + raise NotImplementedError( + f"to_int4_tensor only supports q4_k; got {self.ggml_type!r}" + ) + N, K = int(self.shape[0]), int(self.shape[1]) + q, eff_scale, eff_min = _q4_k_fields(self.raw, N, K) + + zero = torch.where( + eff_scale != 0, eff_min / eff_scale, torch.zeros_like(eff_min) + ) + # Nibble-pack for Int4Tensor: even index -> low nibble, odd -> high. + packed = q[:, ::2] | (q[:, 1::2] << 4) + return Int4Tensor( + qdata=packed, + # Int4Tensor scale/zero layout is (K // gs, N) -- transposed. + scale=eff_scale.to(torch.bfloat16).t().contiguous(), + zero_point=zero.to(torch.bfloat16).t().contiguous(), + block_size=[1, Q4_K_GROUP_SIZE], + shape=torch.Size([N, K]), + ) + + def to_intx_unpacked_to_int8_tensor(self) -> Tensor: + """Convert to a torchao ``IntxUnpackedToInt8Tensor`` (Q4_K or Q6_K). + + Q6_K maps to a symmetric int8 tensor (values [-32, 31], zero-point 0). + Q4_K maps to a 4-bit tensor: values are centered to [-8, 7] and the + affine min is folded into a (float) zero-point, so the rewrite is exact. + """ + from torchao.quantization import IntxUnpackedToInt8Tensor + + N, K = int(self.shape[0]), int(self.shape[1]) + if self.ggml_type == "q6_k": + q, eff_scale = _q6_k_fields(self.raw, N, K) + return IntxUnpackedToInt8Tensor( + qdata=q, + scale=eff_scale.to(torch.bfloat16), + zero_point=torch.zeros_like(eff_scale, dtype=torch.int8), + target_dtype=torch.int8, + block_size=(1, Q6_K_GROUP_SIZE), + dtype=torch.bfloat16, + activation_quantization=None, + ) + if self.ggml_type == "q4_k": + q, eff_scale, eff_min = _q4_k_fields(self.raw, N, K) + zero = torch.where( + eff_scale != 0, eff_min / eff_scale, torch.zeros_like(eff_min) + ) + # Center quants [0, 15] -> [-8, 7] and shift the zero-point to match + # (dequant = scale * (q - zp) is preserved). + return IntxUnpackedToInt8Tensor( + qdata=q.to(torch.int8) - 8, + scale=eff_scale.to(torch.bfloat16), + zero_point=(zero - 8).to(torch.bfloat16), + target_dtype=torch.int4, + block_size=(1, Q4_K_GROUP_SIZE), + dtype=torch.bfloat16, + activation_quantization=None, + ) + raise NotImplementedError( + f"to_intx_unpacked_to_int8_tensor supports q4_k/q6_k; " + f"got {self.ggml_type!r}" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + +implements = ExportableGGUFTensor.implements + + +@implements([aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight = args[0], args[1] + bias = args[2] if len(args) > 2 else None + return torch.nn.functional.linear( + input_tensor, weight.dequantize(input_tensor.dtype), bias + ) + + +@implements([aten.embedding.default]) +def _(func, types, args, kwargs): + weight, indices = args[0], args[1] + return torch.nn.functional.embedding(indices, weight.dequantize()) + + +@implements([aten.t.default]) +def _(func, types, args, kwargs): + return args[0].dequantize().t() + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return args[0] + + +@implements([aten._to_copy.default]) +def _(func, types, args, kwargs): + return args[0].dequantize(output_dtype=kwargs.get("dtype", args[0].orig_dtype)) + + +# --------------------------------------------------------------------------- +# Loader +# --------------------------------------------------------------------------- + + +def iter_gguf( + path: str, +) -> Iterator[Tuple[str, Union[ExportableGGUFTensor, Tensor]]]: + """Stream ``(name, value)`` for every tensor in a GGUF file (low peak mem). + + Quantized tensors (Q4_K, Q6_K) are wrapped as ``ExportableGGUFTensor`` with + the raw block bytes; F32/F16 are returned as plain float tensors (bf16 for + F16). GGUF shapes are reversed to PyTorch ``(N, K)`` convention. + """ + from gguf import GGMLQuantizationType, GGUFReader + + reader = GGUFReader(path) + for tensor in reader.tensors: + shape = list(reversed(tensor.shape.tolist())) + ttype = int(tensor.tensor_type) + flat = torch.frombuffer(memoryview(tensor.data), dtype=torch.uint8) + if ttype in _TYPE_BY_GGML_ID: + N = shape[0] + row_bytes = flat.numel() // N + raw = flat.reshape(N, row_bytes).clone() + yield tensor.name, ExportableGGUFTensor.from_raw( + raw, _TYPE_BY_GGML_ID[ttype] + ) + elif tensor.tensor_type == GGMLQuantizationType.F32: + yield tensor.name, flat.view(torch.float32).reshape(shape).clone() + elif tensor.tensor_type == GGMLQuantizationType.F16: + yield tensor.name, flat.view(torch.float16).reshape(shape).to( + torch.bfloat16 + ) + else: + raise ValueError(f"Unsupported GGUF quant type: {tensor.tensor_type}") + + +def load_gguf(path: str) -> Dict[str, Union[ExportableGGUFTensor, Tensor]]: + """Load a GGUF file into ``{name -> ExportableGGUFTensor | Tensor}``. + + Holds all tensors at once; use :func:`iter_gguf` for low peak memory. + """ + return dict(iter_gguf(path)) diff --git a/extension/llm/export/test/test_gguf.py b/extension/llm/export/test/test_gguf.py new file mode 100644 index 00000000000..e30eb786483 --- /dev/null +++ b/extension/llm/export/test/test_gguf.py @@ -0,0 +1,199 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for ``extension/llm/export/gguf.py``. + +The reference oracle is the ``gguf`` package's own ``gguf.dequantize`` (which can +dequantize Q4_K / Q6_K). We validate that: + +* ``ExportableGGUFTensor.dequantize`` (and the fused ``torchao::gguf_*`` ops, + whose eager bodies use ``gguf``) reproduce ``gguf.dequantize``; +* our hand-written ``to_int4_tensor`` / ``to_intx_unpacked_to_int8_tensor`` + unpack matches ``gguf.dequantize`` (within bf16 storage tolerance); +* using the subclass as a weight dispatches linear/embedding to the fused ops. + +Blocks are crafted with a small fp16 super-block scale and fixed mid-range +sub-scales so dequantized magnitudes are O(1) and bf16 round-trip error is small +and deterministic (random sub-scales can produce near-zero effective scales, +which blow up the bf16 zero-point error for Q4_K). +""" + +import unittest + +import numpy as np +import torch + +try: + import gguf + from gguf import GGMLQuantizationType + + _HAS_GGUF = True +except ImportError: + _HAS_GGUF = False + +from executorch.extension.llm.export.gguf import ( + _Q4_K_BLOCK_BYTES, + _Q6_K_BLOCK_BYTES, + ExportableGGUFTensor, + Q4_K_GROUP_SIZE, +) + + +def _fp16_bytes(x: float) -> torch.Tensor: + return torch.tensor([x], dtype=torch.float16).view(torch.uint8) + + +def _make_q4k_raw(N: int, nb: int, seed: int = 0) -> torch.Tensor: + """A ``(N, nb*144)`` uint8 Q4_K blob with sane, deterministic magnitudes.""" + g = torch.Generator().manual_seed(seed) + blk = torch.randint(0, 256, (N * nb, _Q4_K_BLOCK_BYTES), dtype=torch.uint8, generator=g) + blk[:, 0:2] = _fp16_bytes(0.01) # d + blk[:, 2:4] = _fp16_bytes(0.01) # dmin + blk[:, 4:16] = 0x21 # fixed mid-range 6-bit sub-scales/mins (non-zero) + return blk.reshape(N, nb * _Q4_K_BLOCK_BYTES) + + +def _make_q6k_raw(N: int, nb: int, seed: int = 0) -> torch.Tensor: + """A ``(N, nb*210)`` uint8 Q6_K blob with sane, deterministic magnitudes.""" + g = torch.Generator().manual_seed(seed) + blk = torch.randint(0, 256, (N * nb, _Q6_K_BLOCK_BYTES), dtype=torch.uint8, generator=g) + blk[:, 192:208] = 0x10 # fixed int8 sub-scales (non-zero) + blk[:, 208:210] = _fp16_bytes(0.01) # d + return blk.reshape(N, nb * _Q6_K_BLOCK_BYTES) + + +def _gguf_ref(raw: torch.Tensor, qtype) -> torch.Tensor: + return torch.from_numpy(np.asarray(gguf.dequantize(raw.numpy(), qtype))).float() + + +def _int4_to_float(w) -> torch.Tensor: + """Dequantize an ``Int4Tensor`` from its stored fields (no fbgemm needed).""" + N, K = int(w.shape[0]), int(w.shape[1]) + gs = w.block_size[1] + q = torch.empty(N, K, dtype=torch.float32) + q[:, ::2] = (w.qdata & 0x0F).float() + q[:, 1::2] = (w.qdata >> 4).float() + scale = w.scale.t().float().repeat_interleave(gs, dim=1) + zero = w.zero_point.t().float().repeat_interleave(gs, dim=1) + return scale * (q - zero) + + +def _rel_max(a: torch.Tensor, b: torch.Tensor) -> float: + return (a - b).abs().max().item() / (b.abs().max().item() + 1e-9) + + +@unittest.skipUnless(_HAS_GGUF, "gguf package not installed") +class TestExportableGGUFTensor(unittest.TestCase): + def test_dequantize_matches_gguf(self): + for ggml_type, qtype, make in ( + ("q4_k", GGMLQuantizationType.Q4_K, _make_q4k_raw), + ("q6_k", GGMLQuantizationType.Q6_K, _make_q6k_raw), + ): + raw = make(N=3, nb=2) + t = ExportableGGUFTensor.from_raw(raw, ggml_type) + self.assertEqual(tuple(t.shape), (3, 2 * 256)) + mine = t.dequantize(torch.float32) + ref = _gguf_ref(raw, qtype) + # .dequantize() routes through gguf, so it should match exactly. + self.assertTrue(torch.equal(mine, ref), f"{qtype}") + + def test_to_intx_unpacked_matches_reference(self): + # Reference is the gguf-package dequant (ExportableGGUFTensor.dequantize); + # the Intx tensor's dequantize exercises our unpacking. Covers Q4_K & Q6_K. + for ggml_type, make in (("q4_k", _make_q4k_raw), ("q6_k", _make_q6k_raw)): + raw = make(N=3, nb=2) + t = ExportableGGUFTensor.from_raw(raw, ggml_type) + ix = t.to_intx_unpacked_to_int8_tensor() + self.assertEqual(tuple(ix.shape), (3, 512)) + rel = _rel_max(ix.dequantize().float(), t.dequantize(torch.float32)) + self.assertLess(rel, 1e-2, f"{ggml_type} to_intx rel err {rel}") + + def test_to_int4_tensor_matches_reference(self): + raw = _make_q4k_raw(N=3, nb=2) + t = ExportableGGUFTensor.from_raw(raw, "q4_k") + w = t.to_int4_tensor() + self.assertEqual(tuple(w.shape), (3, 512)) + self.assertEqual(list(w.block_size), [1, Q4_K_GROUP_SIZE]) + # Int4Tensor has no CPU dequantize(); reconstruct from its packed fields + # (this still exercises our nibble-packing) against the gguf reference. + rel = _rel_max(_int4_to_float(w), t.dequantize(torch.float32)) + self.assertLess(rel, 1e-2, f"Q4_K to_int4 rel err {rel}") + + def test_gguf_dequantize_op_matches_reference(self): + for ggml_type, make in (("q4_k", _make_q4k_raw), ("q6_k", _make_q6k_raw)): + raw = make(N=3, nb=2) + t = ExportableGGUFTensor.from_raw(raw, ggml_type) + out = torch.ops.torchao.gguf_dequantize(raw, ggml_type, torch.float32) + self.assertTrue(torch.equal(out, t.dequantize(torch.float32))) + + def test_subclass_linear_dispatches_to_dequant(self): + raw = _make_q6k_raw(N=4, nb=1) + t = ExportableGGUFTensor.from_raw(raw, "q6_k") + x = torch.randn(2, 256, dtype=torch.bfloat16) + out = torch.nn.functional.linear(x, t) + ref = torch.nn.functional.linear(x, t.dequantize(torch.bfloat16)) + self.assertTrue(torch.equal(out, ref)) + + def test_subclass_embedding_dispatches_to_dequant(self): + raw = _make_q6k_raw(N=8, nb=1) + t = ExportableGGUFTensor.from_raw(raw, "q6_k") + idx = torch.tensor([0, 3, 7, 1]) + out = torch.nn.functional.embedding(idx, t) + ref = torch.nn.functional.embedding(idx, t.dequantize(torch.bfloat16)) + self.assertTrue(torch.equal(out, ref)) + + def test_unsupported_type_raises(self): + raw = torch.zeros(1, _Q6_K_BLOCK_BYTES, dtype=torch.uint8) + with self.assertRaises(NotImplementedError): + ExportableGGUFTensor.from_raw(raw, "q5_k") + + +@unittest.skipUnless(_HAS_GGUF, "gguf package not installed") +class TestExportableGGUFTensorExport(unittest.TestCase): + """Exporting a module whose weight is an ``ExportableGGUFTensor`` should + lower linear/embedding through the ``torchao::gguf_dequantize`` op after + ``run_decompositions`` (the subclass dispatch fires during decomposition).""" + + @staticmethod + def _targets(ep): + return {str(n.target) for n in ep.graph.nodes if n.op == "call_function"} + + def test_linear_exports_with_gguf_dequantize(self): + t = ExportableGGUFTensor.from_raw(_make_q6k_raw(N=4, nb=1), "q6_k") + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(t, requires_grad=False) + + def forward(self, x): + return torch.nn.functional.linear(x, self.w) + + ep = torch.export.export( + M(), (torch.randn(2, 256, dtype=torch.bfloat16),) + ).run_decompositions({}) + self.assertIn("torchao.gguf_dequantize.default", self._targets(ep)) + + def test_embedding_exports_with_gguf_dequantize(self): + t = ExportableGGUFTensor.from_raw(_make_q6k_raw(N=8, nb=1), "q6_k") + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(t, requires_grad=False) + + def forward(self, idx): + return torch.nn.functional.embedding(idx, self.w) + + ep = torch.export.export( + M(), (torch.tensor([0, 1, 2, 3]),) + ).run_decompositions({}) + self.assertIn("torchao.gguf_dequantize.default", self._targets(ep)) + + +if __name__ == "__main__": + unittest.main() From 9dac8a370ac30d8762a4f6759c40ee9d1dc6f83d Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Sun, 7 Jun 2026 09:59:55 -0700 Subject: [PATCH 13/18] up --- .../mlx/custom_kernel_ops/gguf/__init__.py | 17 +- .../mlx/custom_kernel_ops/gguf/embedding.py | 97 ----- backends/mlx/custom_kernel_ops/gguf/linear.py | 113 ------ .../mlx/custom_kernel_ops/gguf/patterns.py | 165 +++++++++ backends/mlx/custom_kernel_ops/gguf/q4k.py | 0 .../custom_kernel_ops/gguf/q4k/__init__.py | 20 + .../mlx/custom_kernel_ops/gguf/q4k/common.py | 44 +++ .../custom_kernel_ops/gguf/q4k/embedding.py | 92 +++++ .../mlx/custom_kernel_ops/gguf/q4k/linear.py | 85 +++++ .../custom_kernel_ops/gguf/q6k/__init__.py | 20 +- .../mlx/custom_kernel_ops/gguf/q6k/common.py | 87 +---- .../custom_kernel_ops/gguf/q6k/embedding.py | 71 ++-- .../mlx/custom_kernel_ops/gguf/q6k/linear.py | 87 ++--- .../gguf/test/test_embedding.py | 49 ++- .../gguf/test/test_linear.py | 191 ++++++---- examples/models/gemma4_31b/export.py | 5 + examples/models/gemma4_31b/gguf_loader.py | 167 +++------ examples/models/gemma4_31b/mlx_gguf_linear.py | 146 -------- examples/models/gemma4_31b/quant/gguf.py | 240 ------------ .../gemma4_31b/quant/tests/test_gguf.py | 347 ------------------ .../gemma4_31b/tests/test_mlx_pipeline.py | 154 ++++---- extension/llm/export/gguf.py | 4 +- extension/llm/export/test/test_gguf.py | 47 ++- 23 files changed, 793 insertions(+), 1455 deletions(-) delete mode 100644 backends/mlx/custom_kernel_ops/gguf/embedding.py delete mode 100644 backends/mlx/custom_kernel_ops/gguf/linear.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/patterns.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/q4k.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/q4k/common.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py create mode 100644 backends/mlx/custom_kernel_ops/gguf/q4k/linear.py delete mode 100644 examples/models/gemma4_31b/mlx_gguf_linear.py delete mode 100644 examples/models/gemma4_31b/quant/gguf.py delete mode 100644 examples/models/gemma4_31b/quant/tests/test_gguf.py diff --git a/backends/mlx/custom_kernel_ops/gguf/__init__.py b/backends/mlx/custom_kernel_ops/gguf/__init__.py index 2baaf73217a..4f8268ae88c 100644 --- a/backends/mlx/custom_kernel_ops/gguf/__init__.py +++ b/backends/mlx/custom_kernel_ops/gguf/__init__.py @@ -6,17 +6,20 @@ # LICENSE file in the root directory of this source tree. # -"""GGUF-quantized custom ops for the MLX backend. +"""GGUF-quantized weight lowering for the MLX backend. Submodules: -* :mod:`.q6k` -- shared Q6_K primitives (constants, pure-torch dequant, - Metal header). Import this for symbols; it does not register any op. -* :mod:`.linear` -- registers ``mlx::gguf_linear``. -* :mod:`.embedding` -- registers ``mlx::gguf_embedding``. +* :mod:`.q6k` -- shared Q6_K primitives (constants, pure-torch dequant, + Metal header) and the fused mat-vec / mat-mat / gather kernels. Importing + ``.q6k`` (or ``.q6k.common``) is lightweight and does not touch the registry. +* :mod:`.patterns` -- registers MLX pattern handlers that match + ``torchao::gguf_dequantize -> linear/embedding`` (what ``ExportableGGUFTensor`` + exports) and lower them to the Q6_K kernels. -To register an op, import the corresponding submodule for its side effect, e.g. -``import executorch.backends.mlx.custom_kernel_ops.gguf.linear # noqa: F401``. +To enable GGUF lowering, import :mod:`.patterns` for its side effect:: + + import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 This package ``__init__`` is intentionally side-effect free so importing ``.q6k`` for the pure-torch dequant does not pull in the MLX builder/registry. diff --git a/backends/mlx/custom_kernel_ops/gguf/embedding.py b/backends/mlx/custom_kernel_ops/gguf/embedding.py deleted file mode 100644 index 0b63792c63f..00000000000 --- a/backends/mlx/custom_kernel_ops/gguf/embedding.py +++ /dev/null @@ -1,97 +0,0 @@ -# -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -# - -"""``mlx::gguf_embedding``: embedding gather against a GGUF-quantized table. - - out = dequant(weight[indices]) - -This module is a thin **format router**: it owns the ``mlx::gguf_embedding`` op -identity (custom op, fake, and lowering registration) and dispatches on -``format`` to a per-format implementation. Only ``"q6k"`` is supported today -(see :mod:`.q6k.embedding`); other formats raise ``NotImplementedError``. - -Usage:: - - import executorch.backends.mlx.custom_kernel_ops.gguf.embedding # noqa: F401 - - out = torch.ops.mlx.gguf_embedding(weight, indices, "q6k") - # weight: (vocab, (K/256)*210) uint8 GGUF q6_K blob - # indices: (...) int - # out: (..., K) bfloat16 -""" - -from __future__ import annotations - -import torch -from torch import Tensor -from torch.fx.node import Node - - -# --------------------------------------------------------------------------- -# Custom op + eager fallback (format-agnostic shell; dispatches by format) -# --------------------------------------------------------------------------- - - -@torch.library.custom_op("mlx::gguf_embedding", mutates_args=()) -def gguf_embedding(weight: Tensor, indices: Tensor, format: str) -> Tensor: - """Gather + dequantize rows of a GGUF-quantized embedding table. - - Args: - weight: ``(vocab, row_bytes)`` uint8 GGUF quant blob. - indices: integer token ids of any shape. - format: GGUF quant type; only ``"q6k"`` supported. - - Returns: - ``(*indices.shape, K)`` bfloat16. - """ - if format == "q6k": - from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.embedding import ( - eager_embedding, - ) - - return eager_embedding(weight, indices) - raise NotImplementedError( - f"mlx::gguf_embedding: unsupported format {format!r}; only 'q6k' is supported" - ) - - -@torch.library.register_fake("mlx::gguf_embedding") -def gguf_embedding_fake(weight: Tensor, indices: Tensor, format: str) -> Tensor: - from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( - Q6K_BLOCK_BYTES, - QK_K, - ) - - row_bytes = weight.shape[1] - K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K - return indices.new_empty((*indices.shape, K), dtype=torch.bfloat16) - - -# --------------------------------------------------------------------------- -# MLX handler (format router) -# --------------------------------------------------------------------------- - -from executorch.backends.mlx.builder.op_registry import REGISTRY -from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder -from executorch.backends.mlx.builder.slot_manager import Slot - - -@REGISTRY.register(target=[torch.ops.mlx.gguf_embedding.default]) -def _gguf_embedding_handler(P: MLXProgramBuilder, n: Node) -> Slot: - """Route ``mlx::gguf_embedding`` lowering to the per-format implementation.""" - args = P.args(n) - fmt = args[2] if len(args) >= 3 else None - if fmt == "q6k": - from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.embedding import ( - emit_embedding, - ) - - return emit_embedding(P, n) - raise NotImplementedError( - f"mlx::gguf_embedding: unsupported format {fmt!r}; only 'q6k' is supported" - ) diff --git a/backends/mlx/custom_kernel_ops/gguf/linear.py b/backends/mlx/custom_kernel_ops/gguf/linear.py deleted file mode 100644 index 42999c589e4..00000000000 --- a/backends/mlx/custom_kernel_ops/gguf/linear.py +++ /dev/null @@ -1,113 +0,0 @@ -# -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -# - -"""``mlx::gguf_linear``: linear layer against a GGUF-quantized weight. - - out = x @ dequant(weight)^T (+ bias) - -The weight is stored in the **exact GGUF packed block layout** (no repacking), -so weights converted by llama.cpp / gguf-py can be consumed directly. The -``format`` argument selects the GGUF quantization type. - -This module is a thin **format router**: it owns the ``mlx::gguf_linear`` op -identity (custom op, fake, and lowering registration) and dispatches on -``format`` to a per-format implementation. Only ``"q6k"`` is supported today -(see :mod:`.q6k.linear`); other formats raise ``NotImplementedError``. To add a -format, implement ``eager_linear`` / ``emit_linear`` in a sibling package (e.g. -``q4k``) and add a branch below. - -Usage:: - - import executorch.backends.mlx.custom_kernel_ops.gguf.linear # noqa: F401 - - out = torch.ops.mlx.gguf_linear(x, weight, "q6k", bias) - # x: (..., K) bf16 / fp16 / fp32 - # weight: (N, (K/256)*210) uint8 GGUF q6_K blob - # bias: (N,) or None activation dtype - # out: (..., N) activation dtype -""" - -from __future__ import annotations - -from typing import Optional - -import torch -from torch import Tensor -from torch.fx.node import Node - - -# --------------------------------------------------------------------------- -# Custom op + eager fallback (format-agnostic shell; dispatches by format) -# --------------------------------------------------------------------------- - - -@torch.library.custom_op("mlx::gguf_linear", mutates_args=()) -def gguf_linear( - x: Tensor, - weight: Tensor, - format: str, - bias: Optional[Tensor] = None, -) -> Tensor: - """Linear against a GGUF-quantized weight. - - Args: - x: ``(..., K)`` activations (bf16 / fp16 / fp32). - weight: ``(N, row_bytes)`` uint8 GGUF quant blob. - format: GGUF quant type; only ``"q6k"`` supported. - bias: optional ``(N,)`` of activation dtype. - - Returns: - ``(..., N)`` of activation dtype. - """ - if format == "q6k": - from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.linear import ( - eager_linear, - ) - - return eager_linear(x, weight, bias) - raise NotImplementedError( - f"mlx::gguf_linear: unsupported format {format!r}; only 'q6k' is supported" - ) - - -@torch.library.register_fake("mlx::gguf_linear") -def gguf_linear_fake( - x: Tensor, - weight: Tensor, - format: str, - bias: Optional[Tensor] = None, -) -> Tensor: - N = weight.shape[0] - out_shape = list(x.shape) - out_shape[-1] = N - return x.new_empty(out_shape, dtype=x.dtype) - - -# --------------------------------------------------------------------------- -# MLX handler (format router) -# --------------------------------------------------------------------------- - -from executorch.backends.mlx.builder.op_registry import REGISTRY -from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder -from executorch.backends.mlx.builder.slot_manager import Slot - - -@REGISTRY.register(target=[torch.ops.mlx.gguf_linear.default]) -def _gguf_linear_handler(P: MLXProgramBuilder, n: Node) -> Slot: - """Route ``mlx::gguf_linear`` lowering to the per-format implementation.""" - args = P.args(n) - fmt = args[2] if len(args) >= 3 else None - if fmt == "q6k": - from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.linear import ( - emit_linear, - ) - - return emit_linear(P, n) - raise NotImplementedError( - f"mlx::gguf_linear: unsupported format {fmt!r}; only 'q6k' is supported" - ) diff --git a/backends/mlx/custom_kernel_ops/gguf/patterns.py b/backends/mlx/custom_kernel_ops/gguf/patterns.py new file mode 100644 index 00000000000..10afef99e11 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/patterns.py @@ -0,0 +1,165 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""MLX pattern handlers for GGUF-quantized weights. + +``ExportableGGUFTensor`` (extension/llm/export/gguf.py) lowers a quantized +linear/embedding to:: + + linear(x, torchao::gguf_dequantize(weight, ggml_type, out_dtype), bias) + embedding(torchao::gguf_dequantize(weight, ggml_type, out_dtype), indices) + +These handlers match that ``gguf_dequantize -> linear/embedding`` subgraph and +lower it without materializing the dequantized weight: + +* **Q6_K** -> fused custom Metal kernels in :mod:`.q6k` (linear + embedding). +* **Q4_K** -> MLX's native 4-bit ``quantized_matmul`` via :mod:`.q4k` (linear); + the GGUF blocks are repacked into MLX qparams at export time. + +Other quant types are left unmatched (the caller is expected to convert them to a +torchao ``Int4Tensor`` / ``IntxUnpackedToInt8Tensor`` first). + +Importing this module registers the patterns as a side effect. +""" + +from __future__ import annotations + +from typing import Optional, Tuple + +import torch +from executorch.backends.mlx.builder.op_helpers import get_aten_target +from executorch.backends.mlx.builder.op_registry import PatternHandler, REGISTRY +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.pattern_utils import has_single_user, match_target +from torch.export.exported_program import ExportedProgram +from torch.fx.node import Node + +# Quant types each pattern can lower (linear has both a custom Q6_K kernel and an +# MLX-native Q4_K path; embedding only has the Q6_K gather kernel). +_LINEAR_TYPES = {"q4_k", "q6_k"} +_EMBEDDING_TYPES = {"q4_k", "q6_k"} + + +def parse_gguf_dequantize_node( + node: Node, +) -> Optional[Tuple[Node, str, torch.dtype]]: + """Parse a ``torchao::gguf_dequantize`` node. + + Returns ``(weight_node, ggml_type, output_dtype)`` or ``None`` if ``node`` is + not a ``gguf_dequantize`` node (or the op isn't registered). + """ + try: + import executorch.extension.llm.export.gguf # noqa: F401 registers the op + except ImportError: + return None + + if get_aten_target(node.target) is not torch.ops.torchao.gguf_dequantize.default: + return None + + weight = node.args[0] + ggml_type = node.args[1] + output_dtype = torch.bfloat16 + if len(node.args) > 2: + output_dtype = node.args[2] + elif "output_dtype" in node.kwargs: + output_dtype = node.kwargs["output_dtype"] + return weight, ggml_type, output_dtype + + +@REGISTRY.register_pattern(name="GGUF_QUANTIZED_LINEAR") +class GGUFQuantizedLinearHandler(PatternHandler): + """Lower ``gguf_dequantize + linear`` to a fused quantized matmul. + + Matches ``linear(x, gguf_dequantize(weight, ggml_type, out_dtype), bias)`` + and dispatches on ``ggml_type``: Q6_K -> custom Metal kernels, Q4_K -> MLX + 4-bit ``quantized_matmul``. + """ + + def __init__(self, head, body, weight, ggml_type, output_dtype): + super().__init__(head, body) + self.weight = weight + self.ggml_type = ggml_type + self.output_dtype = output_dtype + + @classmethod + def maybe_create(cls, ep: ExportedProgram, head: Node): + if not match_target(head, torch.ops.aten.linear.default): + return None + if len(head.args) < 2 or not isinstance(head.args[1], Node): + return None + dequant = head.args[1] + if not has_single_user(dequant): + return None + parsed = parse_gguf_dequantize_node(dequant) + if parsed is None: + return None + weight, ggml_type, output_dtype = parsed + if ggml_type not in _LINEAR_TYPES: + return None + return cls(head, [dequant], weight, ggml_type, output_dtype) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + x_node = n.args[0] + bias_node = n.args[2] if len(n.args) > 2 else None + if self.ggml_type == "q6_k": + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.linear import ( + emit_linear, + ) + else: # q4_k + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.linear import ( + emit_linear, + ) + return emit_linear(P, n, x_node, self.weight, bias_node) + + +@REGISTRY.register_pattern(name="GGUF_QUANTIZED_EMBEDDING") +class GGUFQuantizedEmbeddingHandler(PatternHandler): + """Fuse ``gguf_dequantize + embedding`` into the Q6_K gather kernel. + + Matches:: + + embedding(gguf_dequantize(weight, "q6_k", out_dtype), indices) + """ + + def __init__(self, head, body, weight, ggml_type, output_dtype): + super().__init__(head, body) + self.weight = weight + self.ggml_type = ggml_type + self.output_dtype = output_dtype + + @classmethod + def maybe_create(cls, ep: ExportedProgram, head: Node): + if not match_target(head, torch.ops.aten.embedding.default): + return None + if len(head.args) < 2 or not isinstance(head.args[0], Node): + return None + dequant = head.args[0] + if not has_single_user(dequant): + return None + parsed = parse_gguf_dequantize_node(dequant) + if parsed is None: + return None + weight, ggml_type, output_dtype = parsed + if ggml_type not in _EMBEDDING_TYPES: + return None + return cls(head, [dequant], weight, ggml_type, output_dtype) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + indices_node = n.args[1] + if self.ggml_type == "q6_k": + from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.embedding import ( + emit_embedding, + ) + else: # q4_k + from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.embedding import ( + emit_embedding, + ) + return emit_embedding(P, n, self.weight, indices_node, self.output_dtype) diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k.py b/backends/mlx/custom_kernel_ops/gguf/q4k.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py b/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py new file mode 100644 index 00000000000..661d49e39d0 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py @@ -0,0 +1,20 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF **Q4_K** format lowering for the MLX backend. + +Q4_K maps onto MLX's native affine 4-bit kernels (no custom Metal): + +* :mod:`.common` -- repack a raw Q4_K blob into MLX qparams. +* :mod:`.linear` -- ``emit_linear`` (``QuantizedMatmulNode``). +* :mod:`.embedding` -- ``emit_embedding`` (gather + ``DequantizeNode``). + +The pattern handlers in ``custom_kernel_ops.gguf.patterns`` call these ``emit_*`` +functions. ``.linear`` / ``.embedding`` are intentionally NOT imported here so +the package import stays light. +""" diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/common.py b/backends/mlx/custom_kernel_ops/gguf/q4k/common.py new file mode 100644 index 00000000000..c59df0e6183 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/common.py @@ -0,0 +1,44 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""Shared Q4_K -> MLX qparam repack for the Q4_K lowering. + +Q4_K maps cleanly onto MLX's affine 4-bit kernels (group_size 32): the GGUF +blocks are unpacked to the torchao ``IntxUnpackedToInt8Tensor`` layout and +repacked into MLX qparams (``S * Q + B``) at export time, so the weight is +stored MLX-ready and decoded by MLX itself. +""" + +from __future__ import annotations + +from typing import Tuple + +from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from torch.fx.node import Node + +_BITS = 4 + + +def _repack_mlx(P: MLXProgramBuilder, weight_node: Node) -> Tuple[Slot, Slot, Slot, int]: + """Unpack a raw Q4_K blob and repack into MLX qparam constants. + + Returns ``(packed_slot, scales_slot, biases_slot, group_size)``. + """ + from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + weight_target, raw = P.get_placeholder_target_and_tensor(weight_node) + intx = ExportableGGUFTensor.from_raw(raw, "q4_k").to_intx_unpacked_to_int8_tensor() + group_size = int(intx.block_size[-1]) + packed, biases = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, _BITS) + + packed_slot = P.make_or_get_constant(f"{weight_target}_q4k_packed", packed) + scales_slot = P.make_or_get_constant(f"{weight_target}_q4k_scales", intx.scale) + biases_slot = P.make_or_get_constant(f"{weight_target}_q4k_biases", biases) + return packed_slot, scales_slot, biases_slot, group_size diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py new file mode 100644 index 00000000000..0f3b1feddfa --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py @@ -0,0 +1,92 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF **Q4_K** embedding lowering via MLX's native 4-bit quantized gather. + +Lowers a ``gguf_dequantize -> embedding`` pattern to a quantized gather: gather +the packed quants / scales / biases by index (``TakeNode``), then dequantize the +gathered rows (``DequantizeNode``, mode "affine"). The GGUF blob is repacked into +MLX qparams at export time (see :mod:`.common`). +""" + +from __future__ import annotations + +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import ( + _BITS, + _repack_mlx, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + DequantizeNode, + IntOrVidOrTid, + TakeNode, +) +from torch.fx.node import Node + + +def emit_embedding( + P: MLXProgramBuilder, + head: Node, + weight_node: Node, + indices_node: Node, + output_dtype, +) -> Slot: + """Lower a Q4_K ``gguf_dequantize -> embedding`` pattern to a quantized gather. + + Gathers the packed quants / scales / biases by index, then dequantizes the + gathered rows (MLX affine 4-bit) -- the same shape as MLX's generic quantized + embedding. + """ + w_slot, scales_slot, biases_slot, group_size = _repack_mlx(P, weight_node) + (indices_slot,) = P.slot_map([indices_node]) + ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(indices_slot)) + + _, wq_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(w_slot), + index=ids_index, + out=P.slot_to_tid(wq_sel), + axis=0, + ) + ) + _, sc_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(scales_slot), + index=ids_index, + out=P.slot_to_tid(sc_sel), + axis=0, + ) + ) + _, b_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(biases_slot), + index=ids_index, + out=P.slot_to_tid(b_sel), + axis=0, + ) + ) + + out = P.make_or_get_slot(head) + P.emit( + DequantizeNode( + w=P.slot_to_tid(wq_sel), + scales=P.slot_to_tid(sc_sel), + out=P.slot_to_tid(out), + biases=P.slot_to_tid(b_sel), + group_size=group_size, + bits=_BITS, + mode="affine", + dtype=torch_dtype_to_scalar_type(output_dtype), + ) + ) + return out diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py new file mode 100644 index 00000000000..3db41cfc3d1 --- /dev/null +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py @@ -0,0 +1,85 @@ +# +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# + +"""GGUF **Q4_K** linear lowering via MLX's native 4-bit quantized matmul. + +Lowers a ``gguf_dequantize -> linear`` pattern to a ``QuantizedMatmulNode`` +(mode "affine", group_size 32); the GGUF blob is repacked into MLX qparams at +export time (see :mod:`.common`). +""" + +from __future__ import annotations + +from typing import Optional + +from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder +from executorch.backends.mlx.builder.slot_manager import Slot +from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import ( + _BITS, + _repack_mlx, +) +from executorch.backends.mlx.serialization.mlx_graph_schema import ( + AddNode, + AsTypeNode, + QuantizedMatmulNode, +) +from torch.fx.node import Node + + +def emit_linear( + P: MLXProgramBuilder, + head: Node, + x_node: Node, + weight_node: Node, + bias_node: Optional[Node], +) -> Slot: + """Lower a Q4_K ``gguf_dequantize -> linear`` pattern to MLX 4-bit matmul. + + ``weight_node`` is the raw GGUF blob constant; ``head`` is the ``aten.linear`` + node. The blob is repacked into MLX qparams at export time, so only the + MLX-format constants are serialized. + """ + w_slot, scales_slot, biases_slot, group_size = _repack_mlx(P, weight_node) + x_slot, bias_slot = P.slot_map([x_node, bias_node]) + + out = P.make_or_get_slot(head) + P.emit( + QuantizedMatmulNode( + x=P.slot_to_tid(x_slot), + w=P.slot_to_tid(w_slot), + scales=P.slot_to_tid(scales_slot), + biases=P.slot_to_tid(biases_slot), + out=P.slot_to_tid(out), + group_size=group_size, + bits=_BITS, + mode="affine", + transpose=True, + ) + ) + + if bias_node is not None: + P.emit( + AddNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(bias_slot), + out=P.slot_to_tid(out), + ) + ) + + out_dtype = head.meta["val"].dtype + if out_dtype != x_node.meta["val"].dtype: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(out_dtype), + ) + ) + + return out diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py b/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py index 0da3e6d3303..0362a946cc7 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py @@ -8,21 +8,21 @@ """GGUF **Q6_K** format implementation. -* :mod:`.common` -- shared primitives (constants, pure-torch dequant, Metal - header). Re-exported here so ``from ...gguf.q6k import dequantize_q6_k`` stays - lightweight (no MLX builder import). -* :mod:`.linear` -- Q6_K mat-vec/mat-mat kernels + eager compute + lowering. -* :mod:`.embedding` -- Q6_K gather kernel + eager compute + lowering. +* :mod:`.common` -- shared primitives (constants + Metal header). Re-exported + here so ``from ...gguf.q6k import Q6K_BLOCK_BYTES`` stays lightweight (no MLX + builder import). +* :mod:`.linear` -- Q6_K mat-vec/mat-mat kernels + ``emit_linear`` lowering. +* :mod:`.embedding` -- Q6_K gather kernel + ``emit_embedding`` lowering. -The format-agnostic op routers live one level up in -``custom_kernel_ops.gguf.linear`` / ``.embedding`` and dispatch here on -``format``. ``.linear`` / ``.embedding`` are intentionally NOT imported here so -importing :mod:`.common` for the pure-torch dequant does not pull in the builder. +The pattern handlers that match ``torchao::gguf_dequantize -> linear/embedding`` +and call these ``emit_*`` functions live one level up in +``custom_kernel_ops.gguf.patterns``. ``.linear`` / ``.embedding`` are +intentionally NOT imported here so importing :mod:`.common` for the pure-torch +dequant does not pull in the builder. """ from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( # noqa: F401 _Q6K_HEADER, - dequantize_q6_k, Q6K_BLOCK_BYTES, QK_K, ) diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/common.py b/backends/mlx/custom_kernel_ops/gguf/q6k/common.py index 405cc8aba64..9445fbc5b36 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/common.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/common.py @@ -6,19 +6,18 @@ # LICENSE file in the root directory of this source tree. # -"""Shared GGUF **Q6_K** primitives for the MLX custom ops. +"""Shared GGUF **Q6_K** primitives for the MLX backend. This module holds the pieces common to every Q6_K kernel (linear matmul/matvec and the embedding gather), so format-specific op modules import from here rather than from each other: * ``QK_K`` / ``Q6K_BLOCK_BYTES`` and the per-super-block byte layout constants. -* ``dequantize_q6_k`` -- the pure-torch dequant oracle (eager fallback + tests). * ``_Q6K_HEADER`` -- the Metal header (the ``block_q6_K`` struct plus the per-element and vectorized dequant helpers) shared by all Q6_K Metal kernels. -Adding another GGUF format (e.g. Q4_K) should mirror this module (``q4k.py``) -and the op handlers in :mod:`.linear` / :mod:`.embedding` dispatch on ``format``. +Adding another GGUF format (e.g. Q4_K) should mirror this module and the pattern +handlers in :mod:`..patterns` dispatch on the GGUF quant type. Q6_K layout (per 256-element super-block, 210 bytes, see llama.cpp ``block_q6_K`` in ``ggml-common.h``):: @@ -33,17 +32,14 @@ Attribution ----------- -The Q6_K block layout, the Metal dequant helpers in ``_Q6K_HEADER``, and the -pure-torch ``dequantize_q6_k`` reference follow llama.cpp +The Q6_K block layout and the Metal dequant helpers in ``_Q6K_HEADER`` follow +llama.cpp (``ggml-common.h`` / ``ggml-metal.metal``: ``block_q6_K``, ``dequantize_q6_K``), which is MIT-licensed (Copyright (c) 2023-2024 The ggml authors). """ from __future__ import annotations -import torch -from torch import Tensor - # --------------------------------------------------------------------------- # Q6_K constants @@ -58,79 +54,6 @@ Q6K_BLOCK_BYTES = _Q6K_QL_BYTES + _Q6K_QH_BYTES + _Q6K_SCALES + _Q6K_D_BYTES # 210 -# --------------------------------------------------------------------------- -# Pure-torch dequant reference -# --------------------------------------------------------------------------- - - -def dequantize_q6_k(weight: Tensor, K: int) -> Tensor: - """Dequantize a GGUF Q6_K blob to float32. - - Args: - weight: ``(N, n_blocks * 210)`` uint8, GGUF ``block_q6_K`` layout. - K: number of logical input features (``n_blocks * 256``). - - Returns: - ``(N, K)`` float32 dequantized weight. - """ - if weight.dtype != torch.uint8: - raise ValueError(f"gguf_linear: weight must be uint8; got {weight.dtype}") - N = weight.shape[0] - nb = K // QK_K - if weight.shape[-1] != nb * Q6K_BLOCK_BYTES: - raise ValueError( - f"gguf_linear: weight row bytes {weight.shape[-1]} != " - f"{nb} blocks * {Q6K_BLOCK_BYTES}" - ) - - blocks = weight.view(N, nb, Q6K_BLOCK_BYTES) - ql = blocks[..., 0:_Q6K_QL_BYTES].to(torch.int32) - qh = blocks[..., _Q6K_QL_BYTES : _Q6K_QL_BYTES + _Q6K_QH_BYTES].to(torch.int32) - sc_off = _Q6K_QL_BYTES + _Q6K_QH_BYTES - scales = ( - blocks[..., sc_off : sc_off + _Q6K_SCALES] - .contiguous() - .view(torch.int8) - .to(torch.float32) - ) - d = ( - blocks[..., sc_off + _Q6K_SCALES : sc_off + _Q6K_SCALES + _Q6K_D_BYTES] - .contiguous() - .view(torch.float16) - .to(torch.float32) - ) # (N, nb, 1) - - y = torch.empty(N, nb, QK_K, dtype=torch.float32, device=weight.device) - # is = l // 16 over l in 0..31 -> selects which of the 8 half-scales. - is_idx = (torch.arange(32, device=weight.device) // 16).long() # (32,) - - for h in range(2): # two 128-element halves - ql_h = ql[..., h * 64 : h * 64 + 64] # (N, nb, 64) - qh_h = qh[..., h * 32 : h * 32 + 32] # (N, nb, 32) - sc_h = scales[..., h * 8 : h * 8 + 8] # (N, nb, 8) - - ql_lo = ql_h[..., 0:32] - ql_hi = ql_h[..., 32:64] - - q1 = (ql_lo & 0xF) | ((qh_h & 0x3) << 4) - q2 = (ql_hi & 0xF) | (((qh_h >> 2) & 0x3) << 4) - q3 = (ql_lo >> 4) | (((qh_h >> 4) & 0x3) << 4) - q4 = (ql_hi >> 4) | (((qh_h >> 6) & 0x3) << 4) - - sc0 = sc_h[..., is_idx + 0] - sc2 = sc_h[..., is_idx + 2] - sc4 = sc_h[..., is_idx + 4] - sc6 = sc_h[..., is_idx + 6] - - base = h * 128 - y[..., base + 0 : base + 32] = d * sc0 * (q1 - 32).to(torch.float32) - y[..., base + 32 : base + 64] = d * sc2 * (q2 - 32).to(torch.float32) - y[..., base + 64 : base + 96] = d * sc4 * (q3 - 32).to(torch.float32) - y[..., base + 96 : base + 128] = d * sc6 * (q4 - 32).to(torch.float32) - - return y.reshape(N, K) - - # --------------------------------------------------------------------------- # Shared Metal header # --------------------------------------------------------------------------- diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py b/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py index d70cf821176..52e68aea427 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py @@ -8,15 +8,15 @@ """GGUF **Q6_K** embedding implementation. -Provides the two pieces the ``mlx::gguf_embedding`` router dispatches to for the -``"q6k"`` format: +Provides the Q6_K embedding lowering used by the MLX GGUF pattern handler +(:mod:`..patterns`): -* :func:`eager_embedding` -- pure-torch reference (the custom-op eager body). -* :func:`emit_embedding` -- lowers the op to a fused Q6_K gather Metal kernel. +* :func:`emit_embedding` -- lowers a ``gguf_dequantize -> embedding`` pattern to + a fused Q6_K gather Metal kernel. This is the gather counterpart to :mod:`.linear` and exists because MLX's affine dequantize has no group_size=16 Metal kernel, so a Q6_K embedding (group_size 16) -cannot use the generic quantized-embedding path. Output is bfloat16. +cannot use the generic quantized-embedding path. """ from __future__ import annotations @@ -31,7 +31,6 @@ from executorch.backends.mlx.builder.slot_manager import Slot from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( _Q6K_HEADER, - dequantize_q6_k, Q6K_BLOCK_BYTES, QK_K, ) @@ -39,35 +38,9 @@ IntOrVid, MetalKernelNode, ) -from torch import Tensor from torch.fx.node import Node -# --------------------------------------------------------------------------- -# Eager reference -# --------------------------------------------------------------------------- - - -def eager_embedding(weight: Tensor, indices: Tensor) -> Tensor: - """Eager gather + dequantize of Q6_K embedding rows; returns bfloat16.""" - if weight.dim() != 2: - raise ValueError( - f"mlx::gguf_embedding: weight must be 2-D (vocab, row_bytes); got " - f"shape {tuple(weight.shape)}" - ) - row_bytes = weight.shape[1] - if row_bytes % Q6K_BLOCK_BYTES != 0: - raise ValueError( - f"mlx::gguf_embedding: weight row bytes {row_bytes} must be a " - f"multiple of {Q6K_BLOCK_BYTES}" - ) - K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K - - rows = weight[indices.reshape(-1).long()] # (num, row_bytes) - deq = dequantize_q6_k(rows, K) # (num, K) float32 - return deq.reshape(*indices.shape, K).to(torch.bfloat16) - - # --------------------------------------------------------------------------- # Metal kernel source # --------------------------------------------------------------------------- @@ -86,39 +59,41 @@ def eager_embedding(weight: Tensor, indices: Tensor) -> Tensor: """ -def emit_embedding(P: MLXProgramBuilder, n: Node) -> Slot: - """Lower ``mlx::gguf_embedding`` (q6k) to a fused Q6_K gather Metal kernel.""" - args = P.args(n) - if len(args) != 3: - raise ValueError( - f"mlx::gguf_embedding: expected 3 args (weight, indices, format); " - f"got {len(args)}" - ) - weight_slot, indices_slot, _fmt = args - weight_node = n.args[0] - indices_node = n.args[1] +def emit_embedding( + P: MLXProgramBuilder, + head: Node, + weight_node: Node, + indices_node: Node, + output_dtype: torch.dtype, +) -> Slot: + """Lower a Q6_K ``gguf_dequantize`` -> ``embedding`` pattern to a fused gather. + + ``weight_node`` is the raw GGUF blob (the dequantize op's weight input) and + ``head`` is the ``aten.embedding`` node that owns the output slot. + """ + weight_slot, indices_slot = P.slot_map([weight_node, indices_node]) weight_meta = weight_node.meta["val"] if weight_meta.dim() != 2: raise NotImplementedError( - f"mlx::gguf_embedding: weight must be 2-D (vocab, row_bytes); got " + f"gguf q6k embedding: weight must be 2-D (vocab, row_bytes); got " f"shape {tuple(weight_meta.shape)}" ) row_bytes = weight_meta.shape[1] if not isinstance(row_bytes, int): raise NotImplementedError( - "mlx::gguf_embedding: weight shape must be statically known" + "gguf q6k embedding: weight shape must be statically known" ) if row_bytes % Q6K_BLOCK_BYTES != 0: raise ValueError( - f"mlx::gguf_embedding: weight row bytes {row_bytes} must be a " + f"gguf q6k embedding: weight row bytes {row_bytes} must be a " f"multiple of {Q6K_BLOCK_BYTES}" ) K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K - out_dtype_int = torch_dtype_to_scalar_type(torch.bfloat16) + out_dtype_int = torch_dtype_to_scalar_type(output_dtype) - out = P.make_or_get_slot(n) + out = P.make_or_get_slot(head) leading = emit_shape(P, indices_node, indices_slot, end_dim=None) num_idx_iov = emit_product(P, leading) out_shape_flat = leading + [IntOrVid.from_literal(K)] diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py index 581612588f1..c8cf817638f 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py @@ -8,11 +8,12 @@ """GGUF **Q6_K** linear implementation. -Provides the two pieces the ``mlx::gguf_linear`` router dispatches to for the -``"q6k"`` format: +Provides the Q6_K linear pieces used by the MLX GGUF pattern handler +(:mod:`..patterns`): -* :func:`eager_linear` -- pure-torch reference (the custom-op eager body). -* :func:`emit_linear` -- lowers the op to fused Q6_K Metal kernels. +* :func:`eager_linear` -- pure-torch reference (``x @ dequant(weight)^T``). +* :func:`emit_linear` -- lowers a ``gguf_dequantize -> linear`` pattern to fused + Q6_K Metal kernels. Compute is keyed on the activation dtype (matching GGUF/llama.cpp): the Metal kernels are templated on ``InT``, accumulate in ``float32``, read ``d`` as @@ -43,7 +44,6 @@ from typing import Optional -import torch from executorch.backends.mlx.builder.op_helpers import ( emit_product, emit_shape, @@ -53,7 +53,6 @@ from executorch.backends.mlx.builder.slot_manager import Slot from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( _Q6K_HEADER, - dequantize_q6_k, Q6K_BLOCK_BYTES, QK_K, ) @@ -66,41 +65,9 @@ MultiplyIntNode, SubtractIntNode, ) -from torch import Tensor from torch.fx.node import Node -# --------------------------------------------------------------------------- -# Eager reference -# --------------------------------------------------------------------------- - - -def eager_linear(x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: - """Eager ``x @ dequant(weight)^T (+ bias)`` for a Q6_K weight blob.""" - if weight.dim() != 2: - raise ValueError( - f"mlx::gguf_linear: weight must be 2-D (N, row_bytes); got " - f"shape {tuple(weight.shape)}" - ) - _N, row_bytes = weight.shape - if row_bytes % Q6K_BLOCK_BYTES != 0: - raise ValueError( - f"mlx::gguf_linear: weight row bytes {row_bytes} must be a multiple of " - f"{Q6K_BLOCK_BYTES} (one q6_K block per 256 features)" - ) - K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K - if x.shape[-1] != K: - raise ValueError( - f"mlx::gguf_linear: x last dim {x.shape[-1]} != K {K} implied by weight" - ) - - w_deq = dequantize_q6_k(weight, K) # (N, K) float32 - out = torch.matmul(x.to(torch.float32), w_deq.t()) # (..., N) float32 - if bias is not None: - out = out + bias.to(torch.float32) - return out.to(x.dtype) - - # --------------------------------------------------------------------------- # Metal kernel sources # --------------------------------------------------------------------------- @@ -335,7 +302,6 @@ def _q6k_matmul_source(has_bias: bool) -> str: def _emit_q6k_matvec( P: MLXProgramBuilder, - n: Node, x_node: Node, x_slot: Slot, weight_slot: Slot, @@ -393,7 +359,6 @@ def _emit_q6k_matvec( def _emit_q6k_matmul( P: MLXProgramBuilder, - n: Node, x_node: Node, x_slot: Slot, weight_slot: Slot, @@ -459,37 +424,35 @@ def _emit_q6k_matmul( ) -def emit_linear(P: MLXProgramBuilder, n: Node) -> Slot: - """Lower ``mlx::gguf_linear`` (q6k) to fused Q6_K Metal kernels.""" - args = P.args(n) - if len(args) == 4: - x_slot, weight_slot, _fmt, bias_slot = args - elif len(args) == 3: - x_slot, weight_slot, _fmt = args - bias_slot = None - else: - raise ValueError( - f"mlx::gguf_linear: expected 3 or 4 args (x, weight, format[, bias]); " - f"got {len(args)}" - ) - x_node = n.args[0] - weight_node = n.args[1] +def emit_linear( + P: MLXProgramBuilder, + head: Node, + x_node: Node, + weight_node: Node, + bias_node: Optional[Node], +) -> Slot: + """Lower a Q6_K ``gguf_dequantize`` -> ``linear`` pattern to fused kernels. + + ``weight_node`` is the raw GGUF blob (the dequantize op's weight input) and + ``head`` is the ``aten.linear`` node that owns the output slot. + """ + x_slot, weight_slot, bias_slot = P.slot_map([x_node, weight_node, bias_node]) weight_meta = weight_node.meta["val"] if weight_meta.dim() != 2: raise NotImplementedError( - f"mlx::gguf_linear: weight must be 2-D (N, row_bytes); got " + f"gguf q6k linear: weight must be 2-D (N, row_bytes); got " f"shape {tuple(weight_meta.shape)}" ) N = weight_meta.shape[0] row_bytes = weight_meta.shape[1] if not isinstance(N, int) or not isinstance(row_bytes, int): raise NotImplementedError( - "mlx::gguf_linear: weight shape must be statically known" + "gguf q6k linear: weight shape must be statically known" ) if row_bytes % Q6K_BLOCK_BYTES != 0: raise ValueError( - f"mlx::gguf_linear: weight row bytes {row_bytes} must be a multiple of " + f"gguf q6k linear: weight row bytes {row_bytes} must be a multiple of " f"{Q6K_BLOCK_BYTES}" ) K = (row_bytes // Q6K_BLOCK_BYTES) * QK_K @@ -506,17 +469,16 @@ def emit_linear(P: MLXProgramBuilder, n: Node) -> Slot: M = None # dynamic / symbolic break - out = P.make_or_get_slot(n) + out = P.make_or_get_slot(head) tile = _Q6K_MM_NR1 # M-dimension tile (activation rows per threadgroup) if M == 1: # Static decode -> mat-vec. - _emit_q6k_matvec(P, n, x_node, x_slot, weight_slot, bias_slot, N, K, out) + _emit_q6k_matvec(P, x_node, x_slot, weight_slot, bias_slot, N, K, out) elif M is not None: # Static prefill -> tiled simdgroup mat-mat (literal grid). blocks_m = (M + tile - 1) // tile _emit_q6k_matmul( P, - n, x_node, x_slot, weight_slot, @@ -565,7 +527,6 @@ def emit_linear(P: MLXProgramBuilder, n: Node) -> Slot: with P.new_chain() as then_idx: # prefill / mat-mat _emit_q6k_matmul( P, - n, x_node, x_slot, weight_slot, @@ -576,7 +537,7 @@ def emit_linear(P: MLXProgramBuilder, n: Node) -> Slot: out, ) with P.new_chain() as else_idx: # decode / mat-vec - _emit_q6k_matvec(P, n, x_node, x_slot, weight_slot, bias_slot, N, K, out) + _emit_q6k_matvec(P, x_node, x_slot, weight_slot, bias_slot, N, K, out) P.emit( IfNode( diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py index 1c5c7695bf2..a6586022a87 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py @@ -6,38 +6,50 @@ # LICENSE file in the root directory of this source tree. """ -Tests for ``mlx::gguf_embedding`` (GGUF Q6_K embedding gather). +Tests for the GGUF Q6_K embedding lowering. -Compares the fused gather Metal kernel against the eager reference on the same -packed Q6_K table. The kernel and reference run identical per-element float -dequant, so the bf16 outputs match exactly. +An ``nn.Embedding`` whose weight is an ``ExportableGGUFTensor`` exports to +``embedding(torchao::gguf_dequantize(weight, "q6_k", ...), indices)``. The MLX +``GGUF_QUANTIZED_EMBEDDING`` pattern matches that subgraph and lowers it to the +fused Q6_K gather Metal kernel. These tests compare the kernel against the eager +reference (``gguf``-package dequant + ``F.embedding``) on the same packed table. Usage:: python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding run - python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding run --rebuild + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_embedding list """ from typing import List, Tuple -import executorch.backends.mlx.custom_kernel_ops.gguf.embedding # noqa: F401 +# Importing the patterns module registers GGUF_QUANTIZED_LINEAR / _EMBEDDING. +import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 import torch import torch.nn as nn from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( make_q6_k_blob, ) from executorch.backends.mlx.test.test_utils import OpTestCase +from executorch.extension.llm.export.gguf import ExportableGGUFTensor -class GGUFEmbeddingModel(nn.Module): - def forward(self, weight: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: - return torch.ops.mlx.gguf_embedding(weight, indices, "q6k") +def _make_gguf_embedding_model(vocab: int, K: int, seed: int = 0) -> nn.Module: + """An ``nn.Embedding`` whose weight is a Q6_K ``ExportableGGUFTensor``.""" + emb = nn.Embedding(vocab, K) + blob = make_q6_k_blob(vocab, K, seed=seed) + emb.weight = nn.Parameter( + ExportableGGUFTensor.from_raw(blob, "q6_k", torch.bfloat16), + requires_grad=False, + ) + return emb class GGUFEmbeddingTest(OpTestCase): name = "gguf_embedding" - rtol = 0.0 - atol = 0.0 + # Reference dequant runs in fp32 (gguf) then casts to bf16; the kernel + # dequantizes per element to bf16, so allow bf16 tolerance. + rtol = 2e-2 + atol = 2e-2 def __init__( self, @@ -67,14 +79,19 @@ def get_test_configs(cls) -> List["GGUFEmbeddingTest"]: cls(vocab=2048, K=5376, idx_shape=(8,)), ] + def get_edge_compile_config(self): + from executorch.exir import EdgeCompileConfig + + # The gguf_dequantize custom op isn't a core ATen op; skip IR validity. + return EdgeCompileConfig(_check_ir_validity=False) + def create_model(self) -> nn.Module: - return GGUFEmbeddingModel() + return _make_gguf_embedding_model(self.vocab, self.K) def create_inputs(self) -> Tuple[torch.Tensor, ...]: torch.manual_seed(0) - weight = make_q6_k_blob(self.vocab, self.K) - indices = torch.randint(0, self.vocab, self.idx_shape, dtype=torch.int32) - return (weight, indices) + indices = torch.randint(0, self.vocab, self.idx_shape, dtype=torch.int64) + return (indices,) def _main() -> None: # noqa: C901 @@ -83,7 +100,7 @@ def _main() -> None: # noqa: C901 from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner - parser = argparse.ArgumentParser(description="Test mlx::gguf_embedding op") + parser = argparse.ArgumentParser(description="Test GGUF Q6_K embedding lowering") parser.add_argument("action", choices=["generate", "compare", "run", "list"]) parser.add_argument("--verbose", "-v", action="store_true") parser.add_argument("--rebuild", action="store_true") diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py index f91f133547c..a1663958a09 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py @@ -6,36 +6,37 @@ # LICENSE file in the root directory of this source tree. """ -Tests for ``mlx::gguf_linear`` (GGUF Q6_K linear). +Tests for the GGUF Q6_K linear lowering. -Compares the fused Metal kernels (mat-vec for decode, mat-mat for prefill) -against the eager pure-torch reference on the *same* packed Q6_K weight, so -quantization quality is irrelevant -- only the kernel-vs-reference numerics -are checked. Tolerances follow the activation dtype presets. +A linear whose weight is an ``ExportableGGUFTensor`` (extension/llm/export/gguf) +exports to ``linear(x, torchao::gguf_dequantize(weight, "q6_k", ...), bias)``. +The MLX ``GGUF_QUANTIZED_LINEAR`` pattern (custom_kernel_ops/gguf/patterns.py) +matches that subgraph and lowers it to the fused Q6_K Metal kernels (mat-vec for +decode, mat-mat for prefill). These tests compare the fused kernels against the +eager reference (``gguf``-package dequant + ``F.linear``) on the same packed +weight, so quantization quality is irrelevant -- only kernel-vs-reference +numerics are checked. -``GGUFLinearDynamicTest`` additionally exports once with a symbolic seqlen and -runs the same .pte with M=1 and M>1 to exercise both branches of the runtime -``IfNode`` (decode mat-vec vs prefill mat-mat). +``GGUFLinearDynamicTest`` exports once with a symbolic seqlen and runs the same +.pte with M=1 and M>1 to exercise both branches of the runtime ``IfNode`` +(decode mat-vec vs prefill mat-mat). Usage:: python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run -v - python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear run --rebuild - python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear eager + python -m executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear list """ from typing import List, Tuple -import executorch.backends.mlx.custom_kernel_ops.gguf.linear # noqa: F401 +# Importing the patterns module registers GGUF_QUANTIZED_LINEAR / _EMBEDDING. +import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 import torch import torch.nn as nn -from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import ( - dequantize_q6_k, - Q6K_BLOCK_BYTES, - QK_K, -) +from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import Q6K_BLOCK_BYTES, QK_K from executorch.backends.mlx.test.test_utils import OpTestCase +from executorch.extension.llm.export.gguf import ExportableGGUFTensor # --------------------------------------------------------------------------- @@ -43,9 +44,9 @@ # # The Python ``gguf`` package can dequantize Q6_K but does NOT implement Q6_K # quantization, so we build the packed weight here. Quantization quality is -# irrelevant: the tests only compare the kernel against the eager op on the -# *same* bytes, so we just emit valid random blocks (random ql/qh/scales plus a -# small finite fp16 ``d`` -- the one field that must be finite). +# irrelevant: the tests only compare the kernel against the eager reference on +# the *same* bytes, so we just emit valid random blocks (random ql/qh/scales +# plus a small finite fp16 ``d`` -- the one field that must be finite). # --------------------------------------------------------------------------- @@ -72,19 +73,54 @@ def make_q6_k_blob(N: int, K: int, seed: int = 0) -> torch.Tensor: return out -# --------------------------------------------------------------------------- -# Test cases -# --------------------------------------------------------------------------- +def make_q4_k_blob(N: int, K: int, seed: int = 0) -> torch.Tensor: + """Build a ``(N, (K/256)*144)`` uint8 tensor of valid GGUF Q4_K blocks.""" + assert K % QK_K == 0, f"K={K} must be a multiple of {QK_K}" + nb = K // QK_K + block_bytes = 144 # Q4_K: d(2) + dmin(2) + scales(12) + qs(128) + g = torch.Generator().manual_seed(seed) + out = torch.empty(N, nb * block_bytes, dtype=torch.uint8) + blocks = out.view(N, nb, block_bytes) + # d (0:2) / dmin (2:4): small finite fp16 super-block scale + min, so + # dequantized magnitudes stay O(0.1) like real Q4_K weights. + blocks[..., 0:2] = torch.tensor([7e-4], dtype=torch.float16).view(torch.uint8) + blocks[..., 2:4] = torch.tensor([7e-4], dtype=torch.float16).view(torch.uint8) + # scales+mins (4:16, 6-bit packed) and qs (16:144, 4-bit): any bytes valid. + blocks[..., 4:144] = torch.randint( + 0, 256, (N, nb, 140), dtype=torch.uint8, generator=g + ) + return out + + +_BLOB_MAKERS = {"q6_k": make_q6_k_blob, "q4_k": make_q4_k_blob} + + +def _make_gguf_linear_model( + N: int, + K: int, + dtype: torch.dtype, + bias: bool, + ggml_type: str = "q6_k", + seed: int = 0, +) -> nn.Module: + """An ``nn.Linear`` whose weight is a GGUF ``ExportableGGUFTensor``.""" + linear = nn.Linear(K, N, bias=bias).to(dtype) + blob = _BLOB_MAKERS[ggml_type](N, K, seed=seed) + linear.weight = nn.Parameter( + ExportableGGUFTensor.from_raw(blob, ggml_type, dtype), requires_grad=False + ) + return linear class GGUFLinearModel(nn.Module): - def forward( - self, - x: torch.Tensor, - weight: torch.Tensor, - bias: torch.Tensor, - ) -> torch.Tensor: - return torch.ops.mlx.gguf_linear(x, weight, "q6k", bias) + """Wrapper so the forward arg is named ``x`` (for dynamic-shape specs).""" + + def __init__(self, linear: nn.Module): + super().__init__() + self.linear = linear + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) _DTYPE_TOL = { @@ -97,6 +133,13 @@ def forward( _DTYPE_TAG = {torch.bfloat16: "bf16", torch.float16: "fp16", torch.float32: "fp32"} +def _edge_compile_config(): + from executorch.exir import EdgeCompileConfig + + # The gguf_dequantize custom op isn't a core ATen op; skip IR validity. + return EdgeCompileConfig(_check_ir_validity=False) + + class GGUFLinearTest(OpTestCase): name = "gguf_linear" @@ -106,13 +149,18 @@ def __init__( N: int = 256, K: int = 256, dtype: torch.dtype = torch.bfloat16, + bias: bool = True, + ggml_type: str = "q6_k", ): self.M = M self.N = N self.K = K self.dtype = dtype + self.bias = bias + self.ggml_type = ggml_type self.rtol, self.atol = _DTYPE_TOL[dtype] - self.name = f"gguf_linear_m{M}_n{N}_k{K}_{_DTYPE_TAG[dtype]}" + tag = f"gguf_linear_{ggml_type}_m{M}_n{N}_k{K}_{_DTYPE_TAG[dtype]}" + self.name = tag if bias else tag + "_nobias" @classmethod def get_test_configs(cls) -> List["GGUFLinearTest"]: @@ -123,6 +171,7 @@ def get_test_configs(cls) -> List["GGUFLinearTest"]: cfgs.append(cls(M=1, N=N, K=K, dtype=torch.bfloat16)) cfgs.append(cls(M=1, N=256, K=256, dtype=torch.float16)) cfgs.append(cls(M=1, N=256, K=256, dtype=torch.float32)) + cfgs.append(cls(M=1, N=256, K=256, dtype=torch.bfloat16, bias=False)) # Prefill (mat-mat). for M in (8, 64, 128): cfgs.append(cls(M=M, N=512, K=512, dtype=torch.bfloat16)) @@ -130,8 +179,7 @@ def get_test_configs(cls) -> List["GGUFLinearTest"]: # Ragged shapes (M and N not multiples of the 32-wide tile / row group). cfgs.append(cls(M=40, N=300, K=256, dtype=torch.bfloat16)) cfgs.append(cls(M=1, N=300, K=256, dtype=torch.bfloat16)) - # Real Gemma-4-31B shapes (hidden=5376, ffn=21504, vocab=262144) to - # exercise the kernels at production N/K (decode + prefill). + # Real Gemma-4-31B shapes (hidden=5376, ffn=21504) at production N/K. cfgs.append(cls(M=1, N=4096, K=5376, dtype=torch.bfloat16)) # attn_v cfgs.append(cls(M=1, N=5376, K=21504, dtype=torch.bfloat16)) # ffn_down cfgs.append(cls(M=8, N=5376, K=21504, dtype=torch.bfloat16)) # ffn_down prefill @@ -139,17 +187,28 @@ def get_test_configs(cls) -> List["GGUFLinearTest"]: # fits CI-runner GPU buffer limits; the mat-vec N-tiling path is the # same at any N. cfgs.append(cls(M=1, N=16384, K=5376, dtype=torch.bfloat16)) # lm_head + # Q4_K -> MLX native 4-bit quantized_matmul (group_size 32). + cfgs.append(cls(M=1, N=512, K=512, dtype=torch.bfloat16, ggml_type="q4_k")) + cfgs.append(cls(M=8, N=512, K=512, dtype=torch.bfloat16, ggml_type="q4_k")) + cfgs.append(cls(M=1, N=5376, K=5376, dtype=torch.bfloat16, ggml_type="q4_k")) + cfgs.append( + cls(M=1, N=512, K=512, dtype=torch.bfloat16, bias=False, ggml_type="q4_k") + ) return cfgs + def get_edge_compile_config(self): + return _edge_compile_config() + def create_model(self) -> nn.Module: - return GGUFLinearModel() + return GGUFLinearModel( + _make_gguf_linear_model( + self.N, self.K, self.dtype, self.bias, self.ggml_type + ) + ) def create_inputs(self) -> Tuple[torch.Tensor, ...]: torch.manual_seed(0) - x = torch.randn(self.M, self.K, dtype=self.dtype) - weight = make_q6_k_blob(self.N, self.K) - bias = torch.randn(self.N, dtype=self.dtype) - return (x, weight, bias) + return (torch.randn(self.M, self.K, dtype=self.dtype),) class GGUFLinearDynamicTest(OpTestCase): @@ -169,7 +228,6 @@ def __init__( ): self.export_M = export_M self.test_M = test_M - self.seq_len = export_M # used by create_inputs (export tracing) self.N = N self.K = K self.dtype = dtype @@ -191,47 +249,38 @@ def get_test_configs(cls) -> List["GGUFLinearDynamicTest"]: def get_dynamic_shapes(self): seq_dim = torch.export.Dim("seq_len", min=1, max=64) - return {"x": {0: seq_dim}, "weight": None, "bias": None} + return {"x": {0: seq_dim}} - def create_model(self) -> nn.Module: - return GGUFLinearModel() + def get_edge_compile_config(self): + return _edge_compile_config() - def _make_inputs(self, M: int) -> Tuple[torch.Tensor, ...]: - # Deterministic weight/bias so export-time and run-time (and the eager - # reference) all use the same quantized weight (it is a runtime input). - torch.manual_seed(0) - weight = make_q6_k_blob(self.N, self.K) - bias = torch.randn(self.N, dtype=self.dtype) - x = torch.randn(M, self.K, dtype=self.dtype) - return (x, weight, bias) + def create_model(self) -> nn.Module: + # Deterministic weight so export-time and run-time use the same model. + return GGUFLinearModel( + _make_gguf_linear_model(self.N, self.K, self.dtype, bias=True) + ) def create_inputs(self) -> Tuple[torch.Tensor, ...]: - return self._make_inputs(self.export_M) + torch.manual_seed(0) + return (torch.randn(self.export_M, self.K, dtype=self.dtype),) def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: - return self._make_inputs(self.test_M) + torch.manual_seed(0) + return (torch.randn(self.test_M, self.K, dtype=self.dtype),) def _eager_sanity() -> None: - """Quick CPU check: dequant + matmul matches the eager op on the same bytes.""" - torch.manual_seed(0) - N, K = 4, 512 - packed = make_q6_k_blob(N, K) - w_deq = dequantize_q6_k(packed, K) - print(f"dequant finite: {torch.isfinite(w_deq).all().item()}") - x = torch.randn(3, K) - ref = x @ w_deq.t() - out = torch.ops.mlx.gguf_linear(x, packed, "q6k", None) - err = (out - ref).abs().max() - print(f"eager op vs reference max abs err: {err.item():.6e}") - assert err < 1e-3, err - # Unsupported format raises. - try: - torch.ops.mlx.gguf_linear(x, packed, "q4k", None) - raise AssertionError("expected NotImplementedError for q4k") - except RuntimeError as e: - print(f"q4k correctly rejected: {type(e).__name__}") - print("eager sanity OK") + """Quick CPU check: the subclass linear exports to gguf_dequantize.""" + model = GGUFLinearModel(_make_gguf_linear_model(4, 512, torch.bfloat16, bias=True)) + x = torch.randn(3, 512, dtype=torch.bfloat16) + out = model(x) + print( + f"eager forward finite: {torch.isfinite(out).all().item()}, shape {tuple(out.shape)}" + ) + ep = torch.export.export(model, (x,)).run_decompositions({}) + targets = {str(n.target) for n in ep.graph.nodes if n.op == "call_function"} + assert "torchao.gguf_dequantize.default" in targets, targets + print("export contains torchao.gguf_dequantize: OK") if __name__ == "__main__": # noqa: C901 @@ -240,7 +289,7 @@ def _eager_sanity() -> None: from executorch.backends.mlx.test.test_utils import rebuild_op_test_runner - parser = argparse.ArgumentParser(description="Test mlx::gguf_linear op") + parser = argparse.ArgumentParser(description="Test GGUF Q6_K linear lowering") parser.add_argument( "action", choices=["generate", "compare", "run", "list", "eager"] ) diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index e46b51a8411..bcdd49a9a34 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -306,6 +306,11 @@ def _export_mlx( """ import gc + # Register the GGUF dequant op + MLX GGUF pattern handlers so quantized GGUF + # weights lower to the fused Q6_K kernels / Q4_K quantized matmul. + import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 + import executorch.extension.llm.export.gguf # noqa: F401 + from executorch.backends.mlx import MLXPartitioner from executorch.backends.mlx.passes import get_default_passes diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 9b1466af63f..a5452000124 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -6,9 +6,19 @@ """Load a GGUF file into a Gemma 4 31B model. -Streams tensors one at a time via ``iter_gguf_tensors`` for low peak -memory, remaps GGUF names to model FQNs, handles tied embed/lm_head, -and packs for the target backend. +Streams tensors one at a time via the shared loader in +``extension/llm/export/gguf.py`` (each quantized weight arrives as an +``ExportableGGUFTensor`` wrapping the raw GGUF blob), remaps GGUF names to model +FQNs, handles the tied embed/lm_head, and converts each weight for the target +backend: + +* **MLX**: linears keep the ``ExportableGGUFTensor`` (lowered by the MLX GGUF + pattern -- Q6_K custom kernels, Q4_K native 4-bit matmul); a Q6_K token + embedding keeps it too (fused gather), while a Q4_K embedding is converted to + ``IntxUnpackedToInt8Tensor`` (MLX quantized gather -- there is no Q4_K gather + kernel). +* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``IntxUnpackedToInt8Tensor``; the + token embedding is dequantized to bf16 (``Int4Tensor`` can't gather). Usage: model, config = load_gguf_model("model.gguf", backend="cuda") @@ -65,32 +75,6 @@ def gguf_to_model_key(gguf_key: str) -> Optional[str]: return None -def _resolve_tied_lm_head(model, embed_quant, embed_q6k_raw, packers): - """Handle tied embed/lm_head after streaming all tensors.""" - from executorch.examples.models.gemma4_31b.quant import pack_one - - lm_head = getattr(model.lm_head, "weight", None) - if lm_head is None or lm_head.device.type != "meta": - return - if embed_q6k_raw is not None: - # Tied Q6_K weights: lm_head is a matmul, so use the fused gguf_linear - # op (the embedding itself stays a quantized gather). - from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( - replace_with_gguf_linear, - ) - - replace_with_gguf_linear(model, "lm_head.weight", embed_q6k_raw, format="q6k") - elif embed_quant is not None: - pack_one(model, "lm_head.weight", embed_quant, packers) - else: - pack_one( - model, - "lm_head.weight", - model.embed_tokens.weight.data.clone(), - packers, - ) - - def _validate_no_meta(model): """Ensure all parameters have been loaded.""" for fqn, p in model.named_parameters(): @@ -103,32 +87,34 @@ def _validate_no_meta(model): p.requires_grad_(False) -def _handle_mlx_q6k(model, model_key, result): - """Handle Q6_K tensors for the MLX backend. +def _is_embedding(model, model_key: str) -> bool: + parent = model.get_submodule(model_key.rsplit(".", 1)[0]) + return isinstance(parent, torch.nn.Embedding) - Returns ``(processed, embed_q6k_raw)`` where ``processed`` is True - if the tensor was consumed by the fused gguf_linear/gguf_embedding - path (caller should ``continue``), and ``embed_q6k_raw`` is the raw - blob when the tensor is the embedding (for tied lm_head reuse). - """ - from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( - replace_with_gguf_embedding, - replace_with_gguf_linear, - ) - embed_q6k_raw = None - parent = model.get_submodule(model_key.rsplit(".", 1)[0]) - if isinstance(parent, torch.nn.Linear): - replace_with_gguf_linear(model, model_key, result, format="q6k") - return True, None - if isinstance(parent, torch.nn.Embedding): - if model_key == "embed_tokens.weight": - embed_q6k_raw = result - replace_with_gguf_embedding(model, model_key, result, format="q6k") - return True, embed_q6k_raw +def _convert_weight(model, model_key: str, gtensor, backend: str): + """Convert an ``ExportableGGUFTensor`` to the per-backend module weight.""" + if backend == "mlx": + return gtensor + # CUDA: native torchao quantized tensors. + if gtensor.ggml_type == "q4_k": + return gtensor.to_int4_tensor() + return gtensor.to_intx_unpacked_to_int8_tensor() - # Any other Q6_K module: fall back to a quantized tensor. - return False, None + +def _resolve_tied_lm_head(model, lm_head_weight, packers): + """Assign a tied lm_head (GGUF ties it to the token embedding).""" + from executorch.examples.models.gemma4_31b.quant import pack_one + + lm_head = getattr(model.lm_head, "weight", None) + if lm_head is None or lm_head.device.type != "meta": + return + if lm_head_weight is not None: + pack_one(model, "lm_head.weight", lm_head_weight, packers) + else: + pack_one( + model, "lm_head.weight", model.embed_tokens.weight.data.clone(), packers + ) def load_gguf_model( @@ -136,27 +122,20 @@ def load_gguf_model( max_seq_len: int = 4096, backend: str = "cuda", ) -> tuple: - """Load a GGUF file, remap keys, and pack for the target backend. + """Load a GGUF file, remap keys, and convert weights for the target backend. Streams tensors one at a time for low peak memory. - GGUF ties ``embed_tokens`` and ``lm_head`` into a single Q4_K tensor. - We untie them so ``lm_head`` keeps the original Q4_K quantization. - On CUDA, the embedding is dequantized to bf16 because ``Int4Tensor`` - does not support the gather op that ``nn.Embedding`` requires. On - MLX, the embedding stays quantized — ``QuantizedEmbeddingHandler`` - handles quantized gather natively. + GGUF ties ``embed_tokens`` and ``lm_head`` into a single tensor. We untie + them so ``lm_head`` keeps its quantization: on MLX it lowers through the GGUF + linear pattern; on CUDA it stays a quantized ``Int4Tensor`` / + ``IntxUnpackedToInt8Tensor``, while the embedding is dequantized to bf16. Returns ``(model, config)``. """ from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig from executorch.examples.models.gemma4_31b.quant import dequantize_weight, pack_one - from executorch.examples.models.gemma4_31b.quant.gguf import ( - _unpack_q6_k, - iter_gguf_tensors, - ) - from torchao.quantization import IntxUnpackedToInt8Tensor - from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + from executorch.extension.llm.export.gguf import ExportableGGUFTensor, iter_gguf if backend == "cuda": from executorch.examples.models.gemma4_31b.quant import DEFAULT_CUDA_PACKERS @@ -175,62 +154,34 @@ def load_gguf_model( with torch.device("meta"): model = Gemma4_31B(config) - embed_quant = None - embed_q6k_raw = None # raw Q6_K embedding blob, reused for a tied lm_head + lm_head_weight = None # weight reused for a tied lm_head n_processed = 0 - from gguf import GGMLQuantizationType - print(f"Streaming GGUF from {gguf_path}...") - for gguf_name, result, gguf_type in iter_gguf_tensors( - gguf_path, q6k_raw=(backend == "mlx") - ): + for gguf_name, value in iter_gguf(gguf_path): model_key = gguf_to_model_key(gguf_name) if model_key is None: continue - if type(result) is torch.Tensor and result.dtype == torch.float32: - result = result.to(torch.bfloat16) - - # MLX Q6_K: Linear weights use the fused gguf_linear op, the token - # embedding uses the fused gguf_embedding op -- both consume the raw - # GGUF blob (group_size=16 has no MLX affine kernel). The embedding's - # raw blob is also kept so a tied lm_head can use gguf_linear. - if backend == "mlx" and gguf_type == GGMLQuantizationType.Q6_K: - processed, raw = _handle_mlx_q6k(model, model_key, result) - if raw is not None: - embed_q6k_raw = raw - if processed: - n_processed += 1 - continue - - # Fallback: unpack Q6_K to a quantized tensor. - from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import ( - Q6K_BLOCK_BYTES, - QK_K, - ) - - n_rows, row_bytes = result.shape - result = _unpack_q6_k( - result.reshape(-1), - [n_rows, (row_bytes // Q6K_BLOCK_BYTES) * QK_K], - ) - - if model_key == "embed_tokens.weight" and isinstance( - result, (Int4Tensor, IntxUnpackedToInt8Tensor) - ): - embed_quant = result - if backend == "cuda": - result = dequantize_weight(result, torch.bfloat16) + if isinstance(value, ExportableGGUFTensor): + weight = _convert_weight(model, model_key, value, backend) + if model_key == "embed_tokens.weight": + # Tied lm_head reuses the embedding weight: MLX wants the raw + # ExportableGGUFTensor (linear pattern), CUDA the quant tensor. + lm_head_weight = value if backend == "mlx" else weight + if backend == "cuda": + weight = dequantize_weight(weight, torch.bfloat16) + value = weight + elif value.dtype == torch.float32: + value = value.to(torch.bfloat16) - pack_one(model, model_key, result, packers) + pack_one(model, model_key, value, packers) n_processed += 1 if n_processed % 100 == 0: print(f" Processed {n_processed} tensors...") - _resolve_tied_lm_head(model, embed_quant, embed_q6k_raw, packers) - del embed_quant + _resolve_tied_lm_head(model, lm_head_weight, packers) _validate_no_meta(model) model.eval() diff --git a/examples/models/gemma4_31b/mlx_gguf_linear.py b/examples/models/gemma4_31b/mlx_gguf_linear.py deleted file mode 100644 index 23146eaf3d3..00000000000 --- a/examples/models/gemma4_31b/mlx_gguf_linear.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""MLX carrier modules for GGUF Q6_K weights. - -Wrap raw GGUF ``block_q6_K`` blobs and dispatch to the fused ``mlx::gguf_linear`` -(matmul) and ``mlx::gguf_embedding`` (gather) Metal kernels, instead of the slow -non-fused dequantize paths that group_size=16 affine quant takes through the MLX -``QUANTIZED_LINEAR`` / quantized-embedding patterns. -""" - -from __future__ import annotations - -# Importing the op modules registers the custom ops. -import executorch.backends.mlx.custom_kernel_ops.gguf.embedding # noqa: F401 -import executorch.backends.mlx.custom_kernel_ops.gguf.linear # noqa: F401 -import torch -import torch.nn as nn -from executorch.backends.mlx.custom_kernel_ops.gguf.q6k import Q6K_BLOCK_BYTES, QK_K - - -class GGUFLinear(nn.Module): - """``y = gguf_linear(x, weight_blob, format)`` for a GGUF-quantized linear. - - The weight is the **raw** GGUF block blob, stored as a uint8 buffer of shape - ``(out_features, n_blocks * block_bytes)``. Gemma linears are bias-free, so - bias is always ``None``. - """ - - def __init__(self, weight_blob: torch.Tensor, format: str = "q6k"): - super().__init__() - if weight_blob.dim() != 2 or weight_blob.dtype != torch.uint8: - raise ValueError( - f"GGUFLinear: weight_blob must be 2-D uint8; got " - f"shape {tuple(weight_blob.shape)} dtype {weight_blob.dtype}" - ) - if format != "q6k": - raise NotImplementedError( - f"GGUFLinear: unsupported format {format!r}; only 'q6k' supported" - ) - row_bytes = int(weight_blob.shape[1]) - if row_bytes % Q6K_BLOCK_BYTES != 0: - raise ValueError( - f"GGUFLinear: weight row bytes {row_bytes} must be a multiple of " - f"{Q6K_BLOCK_BYTES}" - ) - self.format = format - self.out_features = int(weight_blob.shape[0]) - self.in_features = (row_bytes // Q6K_BLOCK_BYTES) * QK_K - # uint8 cannot be a grad-requiring Parameter; store as a buffer so it is - # serialized as a constant in the exported program. - self.register_buffer("weight", weight_blob) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.ops.mlx.gguf_linear(x, self.weight, self.format, None) - - def extra_repr(self) -> str: - return ( - f"in_features={self.in_features}, out_features={self.out_features}, " - f"format={self.format!r}" - ) - - -def replace_with_gguf_linear( - model: nn.Module, - weight_fqn: str, - weight_blob: torch.Tensor, - format: str = "q6k", -) -> None: - """Replace the ``nn.Linear`` owning ``weight_fqn`` with a ``GGUFLinear``. - - ``weight_fqn`` is the fully-qualified name of the ``.weight`` tensor - (e.g. ``model.layers.0.mlp.down_proj.weight``). The parent linear module is - swapped in place on its grandparent module. - """ - parts = weight_fqn.rsplit(".", 1) - if len(parts) != 2 or parts[1] != "weight": - raise ValueError( - f"replace_with_gguf_linear: expected a '*.weight' fqn; got {weight_fqn!r}" - ) - linear_fqn = parts[0] - grandparent_fqn, _, child_name = linear_fqn.rpartition(".") - grandparent = model.get_submodule(grandparent_fqn) if grandparent_fqn else model - setattr(grandparent, child_name, GGUFLinear(weight_blob, format=format)) - - -class GGUFEmbedding(nn.Module): - """``y = gguf_embedding(weight_blob, indices, format)`` for a GGUF table. - - The weight is the **raw** GGUF block blob, stored as a uint8 buffer of shape - ``(num_embeddings, n_blocks * block_bytes)``. ``forward`` returns bfloat16, - matching the model's embedding dtype. - """ - - def __init__(self, weight_blob: torch.Tensor, format: str = "q6k"): - super().__init__() - if weight_blob.dim() != 2 or weight_blob.dtype != torch.uint8: - raise ValueError( - f"GGUFEmbedding: weight_blob must be 2-D uint8; got " - f"shape {tuple(weight_blob.shape)} dtype {weight_blob.dtype}" - ) - if format != "q6k": - raise NotImplementedError( - f"GGUFEmbedding: unsupported format {format!r}; only 'q6k' supported" - ) - row_bytes = int(weight_blob.shape[1]) - if row_bytes % Q6K_BLOCK_BYTES != 0: - raise ValueError( - f"GGUFEmbedding: weight row bytes {row_bytes} must be a multiple of " - f"{Q6K_BLOCK_BYTES}" - ) - self.format = format - self.num_embeddings = int(weight_blob.shape[0]) - self.embedding_dim = (row_bytes // Q6K_BLOCK_BYTES) * QK_K - self.register_buffer("weight", weight_blob) - - def forward(self, indices: torch.Tensor) -> torch.Tensor: - return torch.ops.mlx.gguf_embedding(self.weight, indices, self.format) - - def extra_repr(self) -> str: - return ( - f"num_embeddings={self.num_embeddings}, " - f"embedding_dim={self.embedding_dim}, format={self.format!r}" - ) - - -def replace_with_gguf_embedding( - model: nn.Module, - weight_fqn: str, - weight_blob: torch.Tensor, - format: str = "q6k", -) -> None: - """Replace the ``nn.Embedding`` owning ``weight_fqn`` with a ``GGUFEmbedding``.""" - parts = weight_fqn.rsplit(".", 1) - if len(parts) != 2 or parts[1] != "weight": - raise ValueError( - f"replace_with_gguf_embedding: expected a '*.weight' fqn; " - f"got {weight_fqn!r}" - ) - module_fqn = parts[0] - grandparent_fqn, _, child_name = module_fqn.rpartition(".") - grandparent = model.get_submodule(grandparent_fqn) if grandparent_fqn else model - setattr(grandparent, child_name, GGUFEmbedding(weight_blob, format=format)) diff --git a/examples/models/gemma4_31b/quant/gguf.py b/examples/models/gemma4_31b/quant/gguf.py deleted file mode 100644 index 9eeec1e74d1..00000000000 --- a/examples/models/gemma4_31b/quant/gguf.py +++ /dev/null @@ -1,240 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Unpack GGUF quantized tensors to torchao tensor subclasses. - -Supports Q4_K, Q6_K, F32, and F16 tensor types. Two public APIs: - - - ``unpack_gguf_tensor`` — convert a single tensor - - ``iter_gguf_tensors`` — stream all tensors from a file (low peak memory) - -Model-agnostic. For Gemma 4 31B key mapping and model loading, see -``gguf_loader.py``. -""" - -from collections.abc import Iterator - -import torch - -QK_K = 256 # super-block size for k-quants -Q4_K_GROUPS = 8 # sub-blocks per Q4_K super-block -Q4_K_GROUP_SIZE = QK_K // Q4_K_GROUPS # 32 -Q6_K_GROUPS = 16 # sub-blocks per Q6_K super-block -Q6_K_GROUP_SIZE = QK_K // Q6_K_GROUPS # 16 - - -def _raw_tensor(data: bytes) -> torch.Tensor: - """Wrap a numpy mmap view as a uint8 torch tensor (zero-copy).""" - return torch.frombuffer(memoryview(data), dtype=torch.uint8) - - -def _read_f16(raw: torch.Tensor, col_start: int, col_end: int) -> torch.Tensor: - """Read fp16 field from block bytes, return float32.""" - return raw[:, col_start:col_end].contiguous().view(torch.float16).float() - - -def _unpack_q4_k(data, shape: list[int]) -> torch.Tensor: - """Unpack Q4_K super-blocks into an ``Int4Tensor``. - - Q4_K block layout (144 bytes per 256 values): - - d (2B, fp16): super-block scale - - dmin (2B, fp16): super-block min - - scales (12B): 8 sub-block scales + 8 sub-block mins, 6-bit packed - - qs (128B): 256 4-bit values, two per byte - - Dequant: weight = d * sub_scale * q - dmin * sub_min - """ - from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor - - N, K = shape - assert K % QK_K == 0, f"Q4_K requires K divisible by {QK_K}, got {K}" - n_blocks = N * (K // QK_K) - block_bytes = 2 + 2 + 12 + QK_K // 2 # 144 - raw = _raw_tensor(data).reshape(n_blocks, block_bytes) - - d = _read_f16(raw, 0, 2) - dmin = _read_f16(raw, 2, 4) - s = raw[:, 4:16] - qs = raw[:, 16:144] - - sc = torch.empty(n_blocks, 8, dtype=torch.float32) - mn = torch.empty(n_blocks, 8, dtype=torch.float32) - sc[:, :4] = (s[:, :4] & 0x3F).float() - mn[:, :4] = (s[:, 4:8] & 0x3F).float() - sc[:, 4:] = ((s[:, 8:12] & 0xF) | ((s[:, :4] >> 6) << 4)).float() - mn[:, 4:] = ((s[:, 8:12] >> 4) | ((s[:, 4:8] >> 6) << 4)).float() - del s - - eff_scale = (d * sc).reshape(N, -1) - eff_min = (dmin * mn).reshape(N, -1) - del d, dmin, sc, mn - - zero_std = torch.where( - eff_scale != 0, eff_min / eff_scale, torch.zeros_like(eff_min) - ) - del eff_min - - # GGUF Q4_K nibble order: 32 lows then 32 highs per sub-block pair - low = (qs & 0x0F).to(torch.uint8) - high = ((qs >> 4) & 0x0F).to(torch.uint8) - qdata_unpacked = torch.cat( - [ - low[:, :32], - high[:, :32], - low[:, 32:64], - high[:, 32:64], - low[:, 64:96], - high[:, 64:96], - low[:, 96:128], - high[:, 96:128], - ], - dim=-1, - ).reshape(N, K) - del qs, low, high - - # Nibble-pack for Int4Tensor: even=LOW, odd=HIGH - packed = qdata_unpacked[:, ::2] | (qdata_unpacked[:, 1::2] << 4) - - # Int4Tensor scale/zero layout: (K//gs, N) — transposed - return Int4Tensor( - qdata=packed, - scale=eff_scale.to(torch.bfloat16).t().contiguous(), - zero_point=zero_std.to(torch.bfloat16).t().contiguous(), - block_size=[1, Q4_K_GROUP_SIZE], - shape=torch.Size([N, K]), - ) - - -def _unpack_q6_k(data, shape: list[int]) -> torch.Tensor: - """Unpack Q6_K super-blocks into an ``IntxUnpackedToInt8Tensor``. - - ``data`` may be a raw byte buffer or an already-materialized uint8 tensor of - the block bytes. - - Q6_K block layout (210 bytes per 256 values): - - ql (128B): lower 4 bits of 256 6-bit values - - qh (64B): upper 2 bits of 256 6-bit values - - scales (16B): 16 int8 sub-block scales (groups of 16) - - d (2B, fp16): super-block scale - - Dequant: weight = d * scale_j * (q - 32) - Values are 6-bit [-32, 31], widened to INT8. - """ - from torchao.quantization import IntxUnpackedToInt8Tensor - - N, K = shape - assert K % QK_K == 0, f"Q6_K requires K divisible by {QK_K}, got {K}" - n_blocks = N * (K // QK_K) - block_bytes = 2 + QK_K // 2 + QK_K // 4 + QK_K // 16 # 210 - raw_u8 = data if isinstance(data, torch.Tensor) else _raw_tensor(data) - raw = raw_u8.reshape(n_blocks, block_bytes) - - ql = raw[:, 0:128] - qh = raw[:, 128:192] - sc = raw[:, 192:208] - d = _read_f16(raw, 208, 210) - - qh0 = qh[:, :32] - qh1 = qh[:, 32:64] - qdata = torch.empty(n_blocks, QK_K, dtype=torch.int16) - qdata[:, 0:32] = (ql[:, :32] & 0x0F) | ((qh0 & 0x03) << 4) - qdata[:, 32:64] = (ql[:, 32:64] & 0x0F) | (((qh0 >> 2) & 0x03) << 4) - qdata[:, 64:96] = ((ql[:, :32] >> 4) & 0x0F) | (((qh0 >> 4) & 0x03) << 4) - qdata[:, 96:128] = ((ql[:, 32:64] >> 4) & 0x0F) | (((qh0 >> 6) & 0x03) << 4) - qdata[:, 128:160] = (ql[:, 64:96] & 0x0F) | ((qh1 & 0x03) << 4) - qdata[:, 160:192] = (ql[:, 96:128] & 0x0F) | (((qh1 >> 2) & 0x03) << 4) - qdata[:, 192:224] = ((ql[:, 64:96] >> 4) & 0x0F) | (((qh1 >> 4) & 0x03) << 4) - qdata[:, 224:256] = ((ql[:, 96:128] >> 4) & 0x0F) | (((qh1 >> 6) & 0x03) << 4) - qdata -= 32 - del ql, qh, qh0, qh1 - - # sc bytes are signed int8 scales; reinterpret from uint8 - eff_scale = (d * sc.to(torch.int8).float()).reshape(N, -1) - del d, sc - - return IntxUnpackedToInt8Tensor( - qdata=qdata.reshape(N, K).to(torch.int8), - scale=eff_scale.to(torch.bfloat16), - zero_point=torch.zeros_like(eff_scale, dtype=torch.int8), - target_dtype=torch.int8, - block_size=(1, Q6_K_GROUP_SIZE), - dtype=torch.bfloat16, - activation_quantization=None, - ) - - -def _raw_q6_k(data, shape: list[int]) -> torch.Tensor: - """Return the raw GGUF Q6_K bytes as ``(N, n_blocks*210)`` uint8. - - Unlike ``_unpack_q6_k`` (which dequantizes and deinterleaves into an - ``IntxUnpackedToInt8Tensor``), this preserves the exact GGUF ``block_q6_K`` - byte layout so it can be consumed directly by the fused - ``mlx::gguf_linear`` Metal kernel. - """ - N, K = shape - assert K % QK_K == 0, f"Q6_K requires K divisible by {QK_K}, got {K}" - block_bytes = 2 + QK_K // 2 + QK_K // 4 + QK_K // 16 # 210 - n_blocks = N * (K // QK_K) - raw = _raw_tensor(data).reshape(n_blocks, block_bytes) - return raw.reshape(N, -1).clone() - - -def unpack_gguf_tensor( - tensor_data, - tensor_type, - shape: list[int], -) -> torch.Tensor: - """Unpack a single GGUF tensor. - - Returns an ``Int4Tensor`` for Q4_K, ``IntxUnpackedToInt8Tensor`` for Q6_K, - or a plain ``torch.Tensor`` for F32/F16. - """ - from gguf import GGMLQuantizationType - - if tensor_type == GGMLQuantizationType.Q4_K: - return _unpack_q4_k(tensor_data, shape) - elif tensor_type == GGMLQuantizationType.Q6_K: - return _unpack_q6_k(tensor_data, shape) - elif tensor_type == GGMLQuantizationType.F32: - return _raw_tensor(tensor_data).view(torch.float32).reshape(shape).clone() - elif tensor_type == GGMLQuantizationType.F16: - return ( - _raw_tensor(tensor_data) - .view(torch.float16) - .reshape(shape) - .to(torch.bfloat16) - ) - else: - raise ValueError(f"Unsupported GGUF quant type: {tensor_type}") - - -def iter_gguf_tensors( - path: str, - q6k_raw: bool = False, -) -> Iterator[tuple[str, torch.Tensor, object]]: - """Yield ``(name, result, tensor_type)`` for each tensor in a GGUF file. - - Processes one tensor at a time for low peak memory. Tensor names are - GGUF names (e.g., ``blk.0.attn_q.weight``); the caller handles key - remapping. GGUF shapes are reversed to PyTorch convention automatically. - ``tensor_type`` is the ``gguf.GGMLQuantizationType`` so callers can branch - on the quant format. - - If ``q6k_raw`` is set, Q6_K tensors are yielded as the **raw** - ``(N, n_blocks*210)`` uint8 GGUF block blob (for the fused - ``mlx::gguf_linear`` path) instead of being dequantized/deinterleaved. - """ - from gguf import GGMLQuantizationType, GGUFReader - - reader = GGUFReader(path) - for tensor in reader.tensors: - shape = list(reversed(tensor.shape.tolist())) - if q6k_raw and tensor.tensor_type == GGMLQuantizationType.Q6_K: - # Keep the raw GGUF block bytes for the fused mlx::gguf_linear path. - result = _raw_q6_k(tensor.data, shape) - else: - result = unpack_gguf_tensor(tensor.data, tensor.tensor_type, shape) - yield tensor.name, result, tensor.tensor_type diff --git a/examples/models/gemma4_31b/quant/tests/test_gguf.py b/examples/models/gemma4_31b/quant/tests/test_gguf.py deleted file mode 100644 index 136fcb9ec38..00000000000 --- a/examples/models/gemma4_31b/quant/tests/test_gguf.py +++ /dev/null @@ -1,347 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -"""Unit tests for quant/gguf.py — Q4_K and Q6_K unpacking. - -Tests verify the API contract: dequantized weights match the original -GGUF dequantization formula. Uses synthetic blocks — no GGUF file required. -""" - -import os -import struct -import tempfile -import unittest - -import numpy as np -import torch - -try: - from gguf import GGMLQuantizationType - - _HAS_GGUF = True -except ImportError: - _HAS_GGUF = False - -if _HAS_GGUF: - from executorch.examples.models.gemma4_31b.quant.gguf import unpack_gguf_tensor - -from executorch.examples.models.gemma4_31b.quant.quantize import dequantize_weight -from safetensors import safe_open -from safetensors.torch import save_file -from torchao.prototype.safetensors.safetensors_support import ( - flatten_tensor_state_dict, - unflatten_tensor_state_dict, -) - - -def _make_q4_k_block(d, dmin, sub_scales, sub_mins, qvals): - """Build one Q4_K block (144 bytes) from components.""" - buf = bytearray(144) - struct.pack_into("> 4) << 6 - scales_bytes[j] |= (sub_mins[j] >> 4) << 6 - buf[4:16] = scales_bytes - # GGUF Q4_K nibble order: 32 lows then 32 highs per sub-block pair - for g in range(4): - for i in range(32): - lo_val = qvals[g * 64 + i] - hi_val = qvals[g * 64 + 32 + i] - buf[16 + g * 32 + i] = (lo_val & 0xF) | ((hi_val & 0xF) << 4) - return buf - - -def _make_q6_k_block(d, scales_16, qvals_256): - """Build one Q6_K block (210 bytes) from components. - - ggml processes 128 values at a time. For each 128-value half: - ql: 64 bytes (two groups of 32, low/high nibbles) - qh: 32 bytes (2 bits each for 4 sub-positions) - The qvals_256 array is in output order (position 0..255). - """ - buf = bytearray(210) - # First half (positions 0..127): ql bytes 0..63, qh bytes 0..31 - for i in range(32): - buf[i] = (qvals_256[i] & 0x0F) | ((qvals_256[i + 64] & 0x0F) << 4) - for i in range(32): - buf[32 + i] = (qvals_256[i + 32] & 0x0F) | ((qvals_256[i + 96] & 0x0F) << 4) - for i in range(32): - h0 = (qvals_256[i] >> 4) & 0x03 - h1 = (qvals_256[i + 32] >> 4) & 0x03 - h2 = (qvals_256[i + 64] >> 4) & 0x03 - h3 = (qvals_256[i + 96] >> 4) & 0x03 - buf[128 + i] = h0 | (h1 << 2) | (h2 << 4) | (h3 << 6) - # Second half (positions 128..255): ql bytes 64..127, qh bytes 32..63 - for i in range(32): - buf[64 + i] = (qvals_256[i + 128] & 0x0F) | ((qvals_256[i + 192] & 0x0F) << 4) - for i in range(32): - buf[96 + i] = (qvals_256[i + 160] & 0x0F) | ((qvals_256[i + 224] & 0x0F) << 4) - for i in range(32): - h0 = (qvals_256[i + 128] >> 4) & 0x03 - h1 = (qvals_256[i + 160] >> 4) & 0x03 - h2 = (qvals_256[i + 192] >> 4) & 0x03 - h3 = (qvals_256[i + 224] >> 4) & 0x03 - buf[160 + i] = h0 | (h1 << 2) | (h2 << 4) | (h3 << 6) - # Scales and d - for i in range(16): - buf[192 + i] = scales_16[i] & 0xFF - struct.pack_into(" torch.Tensor: + def _linear(self, N: int, K: int, ggml_type: str) -> nn.Module: from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( + make_q4_k_blob, make_q6_k_blob, ) + from executorch.extension.llm.export.gguf import ExportableGGUFTensor - return make_q6_k_blob(N, K) - - def test_replace_with_gguf_linear_swaps_module(self): - from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( - GGUFLinear, - replace_with_gguf_linear, + blob = (make_q6_k_blob if ggml_type == "q6_k" else make_q4_k_blob)(N, K) + lin = nn.Linear(K, N, bias=False).to(torch.bfloat16) + lin.weight = nn.Parameter( + ExportableGGUFTensor.from_raw(blob, ggml_type, torch.bfloat16), + requires_grad=False, ) + return lin.eval() - model = build_random_tiny_model() - # Pick a Linear whose in_features is a multiple of 256 (Q6_K requires - # K % 256 == 0) -- e.g. down_proj (in = intermediate_size). - target_fqn = None - for name, mod in model.named_modules(): - if isinstance(mod, nn.Linear) and mod.in_features % 256 == 0: - target_fqn = name - N, K = mod.out_features, mod.in_features - break - self.assertIsNotNone(target_fqn, "no Linear with in_features % 256 == 0") - - blob = self._make_blob(N, K) - replace_with_gguf_linear(model, target_fqn + ".weight", blob, format="q6k") - - swapped = model.get_submodule(target_fqn) - self.assertIsInstance(swapped, GGUFLinear) - self.assertEqual(swapped.in_features, K) - self.assertEqual(swapped.out_features, N) - self.assertEqual(swapped.weight.dtype, torch.uint8) - - # Forward dispatches to the op (eager fallback) and matches it exactly. - x = torch.randn(2, K, dtype=torch.bfloat16) - y = swapped(x) - ref = torch.ops.mlx.gguf_linear(x, blob, "q6k", None) - self.assertEqual(y.shape, torch.Size([2, N])) - self.assertTrue(torch.equal(y, ref)) - - def test_gguf_linear_delegates_to_mlx(self): + def _assert_delegated(self, model, example, leftovers): + import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 from executorch.backends.mlx import MLXPartitioner - from executorch.examples.models.gemma4_31b.mlx_gguf_linear import GGUFLinear - from executorch.exir import to_edge_transform_and_lower + from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower from torch.export import Dim, export - N, K = 256, 512 - blob = self._make_blob(N, K) - # GGUFLinear passes bias=None, so the lowered node has 3 args (no bias); - # dynamic seq_len exercises the runtime-routed (IfNode) lowering path. - m = GGUFLinear(blob, format="q6k").eval() seq = Dim("seq", min=1, max=8) - ep = export( - m, - (torch.randn(4, K, dtype=torch.bfloat16),), - dynamic_shapes=({0: seq},), - strict=True, + ep = export(model, example, dynamic_shapes=({0: seq},), strict=True) + et = to_edge_transform_and_lower( + ep, + partitioner=[MLXPartitioner()], + compile_config=EdgeCompileConfig(_check_ir_validity=False), ) - et = to_edge_transform_and_lower(ep, partitioner=[MLXPartitioner()]) remaining = [ str(n.target) for n in et.exported_program().graph.nodes - if n.op == "call_function" and "gguf_linear" in str(n.target) + if n.op == "call_function" and any(t in str(n.target) for t in leftovers) ] - self.assertEqual(remaining, [], "gguf_linear was not delegated to MLX") - + self.assertEqual(remaining, [], f"not delegated to MLX: {remaining}") -class TestGgufEmbeddingMlx(unittest.TestCase): - """Q6_K token embedding routes to the fused mlx::gguf_embedding op.""" - - def _make_blob(self, vocab: int, K: int) -> torch.Tensor: - from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( - make_q6_k_blob, + def test_q6k_linear_delegates(self): + self._assert_delegated( + self._linear(256, 512, "q6_k"), + (torch.randn(4, 512, dtype=torch.bfloat16),), + ("gguf_dequantize", "linear"), ) - return make_q6_k_blob(vocab, K) - - def test_replace_with_gguf_embedding_swaps_module(self): - from executorch.examples.models.gemma4_31b.mlx_gguf_linear import ( - GGUFEmbedding, - replace_with_gguf_embedding, + def test_q4k_linear_delegates(self): + self._assert_delegated( + self._linear(512, 512, "q4_k"), + (torch.randn(4, 512, dtype=torch.bfloat16),), + ("gguf_dequantize", "linear"), ) - model = build_random_tiny_model() - # Q6_K needs embedding_dim % 256 == 0, so use a fixed valid dim (the - # tiny model's hidden_size is smaller); this test exercises the swapped - # module directly, not the full model forward. - vocab, K = 512, 256 - blob = self._make_blob(vocab, K) - replace_with_gguf_embedding(model, "embed_tokens.weight", blob, format="q6k") - - swapped = model.get_submodule("embed_tokens") - self.assertIsInstance(swapped, GGUFEmbedding) - self.assertEqual(swapped.num_embeddings, vocab) - self.assertEqual(swapped.embedding_dim, K) - idx = torch.randint(0, vocab, (2, 4), dtype=torch.int32) - y = swapped(idx) - ref = torch.ops.mlx.gguf_embedding(blob, idx, "q6k") - self.assertEqual(y.shape, torch.Size([2, 4, K])) - self.assertTrue(torch.equal(y, ref)) +class TestGgufEmbeddingMlx(unittest.TestCase): + """GGUF token embeddings (Q6_K + Q4_K) lower through the MLX GGUF pattern.""" - def test_gguf_embedding_delegates_to_mlx(self): + def _assert_delegated(self, ggml_type: str): + import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 from executorch.backends.mlx import MLXPartitioner - from executorch.examples.models.gemma4_31b.mlx_gguf_linear import GGUFEmbedding - from executorch.exir import to_edge_transform_and_lower + from executorch.backends.mlx.custom_kernel_ops.gguf.test.test_linear import ( + make_q4_k_blob, + make_q6_k_blob, + ) + from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + from executorch.extension.llm.export.gguf import ExportableGGUFTensor from torch.export import Dim, export - m = GGUFEmbedding(self._make_blob(512, 256), format="q6k").eval() + vocab, K = 512, 256 + blob = (make_q6_k_blob if ggml_type == "q6_k" else make_q4_k_blob)(vocab, K) + emb = nn.Embedding(vocab, K) + emb.weight = nn.Parameter( + ExportableGGUFTensor.from_raw(blob, ggml_type, torch.bfloat16), + requires_grad=False, + ) + emb = emb.eval() seq = Dim("seq", min=1, max=8) ep = export( - m, - (torch.randint(0, 512, (4,), dtype=torch.int32),), + emb, + (torch.randint(0, vocab, (4,), dtype=torch.int64),), dynamic_shapes=({0: seq},), strict=True, ) - et = to_edge_transform_and_lower(ep, partitioner=[MLXPartitioner()]) + et = to_edge_transform_and_lower( + ep, + partitioner=[MLXPartitioner()], + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) remaining = [ str(n.target) for n in et.exported_program().graph.nodes - if n.op == "call_function" and "gguf_embedding" in str(n.target) + if n.op == "call_function" + and any(t in str(n.target) for t in ("gguf_dequantize", "embedding")) ] - self.assertEqual(remaining, [], "gguf_embedding was not delegated to MLX") + self.assertEqual(remaining, [], f"not delegated to MLX: {remaining}") + + def test_q6k_embedding_delegates(self): + self._assert_delegated("q6_k") + + def test_q4k_embedding_delegates(self): + self._assert_delegated("q4_k") if __name__ == "__main__": diff --git a/extension/llm/export/gguf.py b/extension/llm/export/gguf.py index 2d922472641..9a0b34cdd6c 100644 --- a/extension/llm/export/gguf.py +++ b/extension/llm/export/gguf.py @@ -264,9 +264,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> Tensor: def to_int4_tensor(self) -> Tensor: """Convert a Q4_K tensor to a torchao ``Int4Tensor``.""" - from torchao.quantization.quantize_.workflows.int4.int4_tensor import ( - Int4Tensor, - ) + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor if self.ggml_type != "q4_k": raise NotImplementedError( diff --git a/extension/llm/export/test/test_gguf.py b/extension/llm/export/test/test_gguf.py index e30eb786483..dff987f0749 100644 --- a/extension/llm/export/test/test_gguf.py +++ b/extension/llm/export/test/test_gguf.py @@ -50,7 +50,9 @@ def _fp16_bytes(x: float) -> torch.Tensor: def _make_q4k_raw(N: int, nb: int, seed: int = 0) -> torch.Tensor: """A ``(N, nb*144)`` uint8 Q4_K blob with sane, deterministic magnitudes.""" g = torch.Generator().manual_seed(seed) - blk = torch.randint(0, 256, (N * nb, _Q4_K_BLOCK_BYTES), dtype=torch.uint8, generator=g) + blk = torch.randint( + 0, 256, (N * nb, _Q4_K_BLOCK_BYTES), dtype=torch.uint8, generator=g + ) blk[:, 0:2] = _fp16_bytes(0.01) # d blk[:, 2:4] = _fp16_bytes(0.01) # dmin blk[:, 4:16] = 0x21 # fixed mid-range 6-bit sub-scales/mins (non-zero) @@ -60,7 +62,9 @@ def _make_q4k_raw(N: int, nb: int, seed: int = 0) -> torch.Tensor: def _make_q6k_raw(N: int, nb: int, seed: int = 0) -> torch.Tensor: """A ``(N, nb*210)`` uint8 Q6_K blob with sane, deterministic magnitudes.""" g = torch.Generator().manual_seed(seed) - blk = torch.randint(0, 256, (N * nb, _Q6_K_BLOCK_BYTES), dtype=torch.uint8, generator=g) + blk = torch.randint( + 0, 256, (N * nb, _Q6_K_BLOCK_BYTES), dtype=torch.uint8, generator=g + ) blk[:, 192:208] = 0x10 # fixed int8 sub-scales (non-zero) blk[:, 208:210] = _fp16_bytes(0.01) # d return blk.reshape(N, nb * _Q6_K_BLOCK_BYTES) @@ -71,7 +75,12 @@ def _gguf_ref(raw: torch.Tensor, qtype) -> torch.Tensor: def _int4_to_float(w) -> torch.Tensor: - """Dequantize an ``Int4Tensor`` from its stored fields (no fbgemm needed).""" + """Dequantize an ``Int4Tensor`` from its stored fields. + + ``Int4Tensor`` has no working ``dequantize()`` on CPU (``aten.dequantize`` is + unimplemented and the linear path needs fbgemm), so reconstruct directly + from its public fields (this still exercises our nibble-packing). + """ N, K = int(w.shape[0]), int(w.shape[1]) gs = w.block_size[1] q = torch.empty(N, K, dtype=torch.float32) @@ -82,10 +91,6 @@ def _int4_to_float(w) -> torch.Tensor: return scale * (q - zero) -def _rel_max(a: torch.Tensor, b: torch.Tensor) -> float: - return (a - b).abs().max().item() / (b.abs().max().item() + 1e-9) - - @unittest.skipUnless(_HAS_GGUF, "gguf package not installed") class TestExportableGGUFTensor(unittest.TestCase): def test_dequantize_matches_gguf(self): @@ -109,8 +114,16 @@ def test_to_intx_unpacked_matches_reference(self): t = ExportableGGUFTensor.from_raw(raw, ggml_type) ix = t.to_intx_unpacked_to_int8_tensor() self.assertEqual(tuple(ix.shape), (3, 512)) - rel = _rel_max(ix.dequantize().float(), t.dequantize(torch.float32)) - self.assertLess(rel, 1e-2, f"{ggml_type} to_intx rel err {rel}") + # bf16 storage tolerance. + self.assertTrue( + torch.allclose( + ix.dequantize().float(), + t.dequantize(torch.float32), + rtol=1e-2, + atol=5e-2, + ), + ggml_type, + ) def test_to_int4_tensor_matches_reference(self): raw = _make_q4k_raw(N=3, nb=2) @@ -120,8 +133,14 @@ def test_to_int4_tensor_matches_reference(self): self.assertEqual(list(w.block_size), [1, Q4_K_GROUP_SIZE]) # Int4Tensor has no CPU dequantize(); reconstruct from its packed fields # (this still exercises our nibble-packing) against the gguf reference. - rel = _rel_max(_int4_to_float(w), t.dequantize(torch.float32)) - self.assertLess(rel, 1e-2, f"Q4_K to_int4 rel err {rel}") + self.assertTrue( + torch.allclose( + _int4_to_float(w), + t.dequantize(torch.float32), + rtol=1e-2, + atol=5e-2, + ) + ) def test_gguf_dequantize_op_matches_reference(self): for ggml_type, make in (("q4_k", _make_q4k_raw), ("q6_k", _make_q6k_raw)): @@ -189,9 +208,9 @@ def __init__(self): def forward(self, idx): return torch.nn.functional.embedding(idx, self.w) - ep = torch.export.export( - M(), (torch.tensor([0, 1, 2, 3]),) - ).run_decompositions({}) + ep = torch.export.export(M(), (torch.tensor([0, 1, 2, 3]),)).run_decompositions( + {} + ) self.assertIn("torchao.gguf_dequantize.default", self._targets(ep)) From 15c2b726274c4b12fde743ba65a2855c43dc25cb Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Sun, 7 Jun 2026 11:59:09 -0700 Subject: [PATCH 14/18] up --- backends/mlx/builder/op_helpers.py | 101 ++++++ .../mlx/custom_kernel_ops/gguf/__init__.py | 2 +- .../mlx/custom_kernel_ops/gguf/patterns.py | 26 +- .../mlx/custom_kernel_ops/gguf/q4k/common.py | 4 +- .../custom_kernel_ops/gguf/q4k/embedding.py | 73 +--- .../mlx/custom_kernel_ops/gguf/q4k/linear.py | 9 +- .../custom_kernel_ops/gguf/q6k/__init__.py | 2 +- .../custom_kernel_ops/gguf/q6k/embedding.py | 4 +- .../mlx/custom_kernel_ops/gguf/q6k/linear.py | 4 +- .../gguf/test/test_embedding.py | 4 +- .../gguf/test/test_linear.py | 10 +- backends/mlx/patterns.py | 313 +++++++++++++----- backends/mlx/test/test_ops.py | 155 +++++++++ examples/models/gemma4_31b/export.py | 1 + examples/models/gemma4_31b/quant/pack_mlx.py | 56 +--- .../gemma4_31b/tests/test_mlx_pipeline.py | 67 +++- extension/llm/export/gguf.py | 20 +- extension/llm/export/int4.py | 142 ++++++++ extension/llm/export/test/test_gguf.py | 14 +- extension/llm/export/test/test_int4.py | 125 +++++++ 20 files changed, 890 insertions(+), 242 deletions(-) create mode 100644 extension/llm/export/int4.py create mode 100644 extension/llm/export/test/test_int4.py diff --git a/backends/mlx/builder/op_helpers.py b/backends/mlx/builder/op_helpers.py index be199f75340..2f94a808adc 100644 --- a/backends/mlx/builder/op_helpers.py +++ b/backends/mlx/builder/op_helpers.py @@ -329,6 +329,79 @@ def emit_quantized_biases( return biases +def emit_quantized_gather( + P: MLXProgramBuilder, + out: Slot, + indices_slot: Slot, + qdata_slot: Slot, + scales_slot: Slot, + biases_slot: Optional[Slot], + *, + group_size: int, + bits: int, + mode: str, + out_dtype: torch.dtype, +) -> None: + """Gather quantized rows by index and dequantize them into ``out``. + + Emits ``TakeNode`` for qdata and scales (and biases when present), then a + ``DequantizeNode``. + """ + from executorch.backends.mlx.serialization.mlx_graph_schema import ( + DequantizeNode, + IntOrVidOrTid, + TakeNode, + ) + + ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(indices_slot)) + + _, wq_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(qdata_slot), + index=ids_index, + out=P.slot_to_tid(wq_sel), + axis=0, + ) + ) + + _, sc_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(scales_slot), + index=ids_index, + out=P.slot_to_tid(sc_sel), + axis=0, + ) + ) + + biases_tid = None + if biases_slot is not None: + _, b_sel = P.make_tmp_slot() + P.emit( + TakeNode( + x=P.slot_to_tid(biases_slot), + index=ids_index, + out=P.slot_to_tid(b_sel), + axis=0, + ) + ) + biases_tid = P.slot_to_tid(b_sel) + + P.emit( + DequantizeNode( + w=P.slot_to_tid(wq_sel), + scales=P.slot_to_tid(sc_sel), + out=P.slot_to_tid(out), + biases=biases_tid, + group_size=group_size, + bits=bits, + mode=mode, + dtype=torch_dtype_to_scalar_type(out_dtype), + ) + ) + + def to_mlx_qparams( qdata: torch.Tensor, scale: torch.Tensor, @@ -421,6 +494,34 @@ def parse_dequant_nvfp4_node( return qdata, scale, per_tensor_scale, output_dtype +def parse_dequant_int4_node( + node: Node, +) -> Optional[Tuple[Node, Node, Node, int, Optional[torch.dtype]]]: + """Parse a torchao.dequantize_int4_tensor node. + + Returns (qdata, scale, zero_point, group_size, output_dtype) or None if not a + dequantize_int4_tensor node or the custom op is not registered. + """ + target = get_aten_target(node.target) + try: + import executorch.extension.llm.export.int4 # noqa: F401 + except ImportError: + return None + + if target is not torch.ops.torchao.dequantize_int4_tensor.default: + return None + + qdata, scale, zero_point, group_size = node.args[0:4] + + output_dtype = None + if len(node.args) > 4: + output_dtype = node.args[4] + elif "output_dtype" in node.kwargs: + output_dtype = node.kwargs["output_dtype"] + + return qdata, scale, zero_point, group_size, output_dtype + + def parse_dequant_node( node: Node, ) -> Optional[Tuple[Node, Node, Node, int, int, Optional[torch.dtype], int]]: diff --git a/backends/mlx/custom_kernel_ops/gguf/__init__.py b/backends/mlx/custom_kernel_ops/gguf/__init__.py index 4f8268ae88c..0e6ec0f01a4 100644 --- a/backends/mlx/custom_kernel_ops/gguf/__init__.py +++ b/backends/mlx/custom_kernel_ops/gguf/__init__.py @@ -14,7 +14,7 @@ Metal header) and the fused mat-vec / mat-mat / gather kernels. Importing ``.q6k`` (or ``.q6k.common``) is lightweight and does not touch the registry. * :mod:`.patterns` -- registers MLX pattern handlers that match - ``torchao::gguf_dequantize -> linear/embedding`` (what ``ExportableGGUFTensor`` + ``torchao::dequantize_gguf -> linear/embedding`` (what ``ExportableGGUFTensor`` exports) and lower them to the Q6_K kernels. To enable GGUF lowering, import :mod:`.patterns` for its side effect:: diff --git a/backends/mlx/custom_kernel_ops/gguf/patterns.py b/backends/mlx/custom_kernel_ops/gguf/patterns.py index 10afef99e11..58f6fa41a7e 100644 --- a/backends/mlx/custom_kernel_ops/gguf/patterns.py +++ b/backends/mlx/custom_kernel_ops/gguf/patterns.py @@ -11,10 +11,10 @@ ``ExportableGGUFTensor`` (extension/llm/export/gguf.py) lowers a quantized linear/embedding to:: - linear(x, torchao::gguf_dequantize(weight, ggml_type, out_dtype), bias) - embedding(torchao::gguf_dequantize(weight, ggml_type, out_dtype), indices) + linear(x, torchao::dequantize_gguf(weight, ggml_type, out_dtype), bias) + embedding(torchao::dequantize_gguf(weight, ggml_type, out_dtype), indices) -These handlers match that ``gguf_dequantize -> linear/embedding`` subgraph and +These handlers match that ``dequantize_gguf -> linear/embedding`` subgraph and lower it without materializing the dequantized weight: * **Q6_K** -> fused custom Metal kernels in :mod:`.q6k` (linear + embedding). @@ -46,20 +46,20 @@ _EMBEDDING_TYPES = {"q4_k", "q6_k"} -def parse_gguf_dequantize_node( +def parse_dequantize_gguf_node( node: Node, ) -> Optional[Tuple[Node, str, torch.dtype]]: - """Parse a ``torchao::gguf_dequantize`` node. + """Parse a ``torchao::dequantize_gguf`` node. Returns ``(weight_node, ggml_type, output_dtype)`` or ``None`` if ``node`` is - not a ``gguf_dequantize`` node (or the op isn't registered). + not a ``dequantize_gguf`` node (or the op isn't registered). """ try: import executorch.extension.llm.export.gguf # noqa: F401 registers the op except ImportError: return None - if get_aten_target(node.target) is not torch.ops.torchao.gguf_dequantize.default: + if get_aten_target(node.target) is not torch.ops.torchao.dequantize_gguf.default: return None weight = node.args[0] @@ -74,9 +74,9 @@ def parse_gguf_dequantize_node( @REGISTRY.register_pattern(name="GGUF_QUANTIZED_LINEAR") class GGUFQuantizedLinearHandler(PatternHandler): - """Lower ``gguf_dequantize + linear`` to a fused quantized matmul. + """Lower ``dequantize_gguf + linear`` to a fused quantized matmul. - Matches ``linear(x, gguf_dequantize(weight, ggml_type, out_dtype), bias)`` + Matches ``linear(x, dequantize_gguf(weight, ggml_type, out_dtype), bias)`` and dispatches on ``ggml_type``: Q6_K -> custom Metal kernels, Q4_K -> MLX 4-bit ``quantized_matmul``. """ @@ -96,7 +96,7 @@ def maybe_create(cls, ep: ExportedProgram, head: Node): dequant = head.args[1] if not has_single_user(dequant): return None - parsed = parse_gguf_dequantize_node(dequant) + parsed = parse_dequantize_gguf_node(dequant) if parsed is None: return None weight, ggml_type, output_dtype = parsed @@ -121,11 +121,11 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: @REGISTRY.register_pattern(name="GGUF_QUANTIZED_EMBEDDING") class GGUFQuantizedEmbeddingHandler(PatternHandler): - """Fuse ``gguf_dequantize + embedding`` into the Q6_K gather kernel. + """Fuse ``dequantize_gguf + embedding`` into the Q6_K gather kernel. Matches:: - embedding(gguf_dequantize(weight, "q6_k", out_dtype), indices) + embedding(dequantize_gguf(weight, "q6_k", out_dtype), indices) """ def __init__(self, head, body, weight, ggml_type, output_dtype): @@ -143,7 +143,7 @@ def maybe_create(cls, ep: ExportedProgram, head: Node): dequant = head.args[0] if not has_single_user(dequant): return None - parsed = parse_gguf_dequantize_node(dequant) + parsed = parse_dequantize_gguf_node(dequant) if parsed is None: return None weight, ggml_type, output_dtype = parsed diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/common.py b/backends/mlx/custom_kernel_ops/gguf/q4k/common.py index c59df0e6183..d58a8b71afd 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/common.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/common.py @@ -26,7 +26,9 @@ _BITS = 4 -def _repack_mlx(P: MLXProgramBuilder, weight_node: Node) -> Tuple[Slot, Slot, Slot, int]: +def _repack_mlx( + P: MLXProgramBuilder, weight_node: Node +) -> Tuple[Slot, Slot, Slot, int]: """Unpack a raw Q4_K blob and repack into MLX qparam constants. Returns ``(packed_slot, scales_slot, biases_slot, group_size)``. diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py index 0f3b1feddfa..7b5bbcff0e1 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/embedding.py @@ -8,26 +8,18 @@ """GGUF **Q4_K** embedding lowering via MLX's native 4-bit quantized gather. -Lowers a ``gguf_dequantize -> embedding`` pattern to a quantized gather: gather -the packed quants / scales / biases by index (``TakeNode``), then dequantize the -gathered rows (``DequantizeNode``, mode "affine"). The GGUF blob is repacked into -MLX qparams at export time (see :mod:`.common`). +Lowers a ``dequantize_gguf -> embedding`` pattern to a quantized gather: gather +the packed quants / scales / biases by index, then dequantize the gathered rows +(``DequantizeNode``, mode "affine"). The GGUF blob is repacked into MLX qparams +at export time (see :mod:`.common`). """ from __future__ import annotations -from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type +from executorch.backends.mlx.builder.op_helpers import emit_quantized_gather from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder from executorch.backends.mlx.builder.slot_manager import Slot -from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import ( - _BITS, - _repack_mlx, -) -from executorch.backends.mlx.serialization.mlx_graph_schema import ( - DequantizeNode, - IntOrVidOrTid, - TakeNode, -) +from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import _BITS, _repack_mlx from torch.fx.node import Node @@ -38,7 +30,7 @@ def emit_embedding( indices_node: Node, output_dtype, ) -> Slot: - """Lower a Q4_K ``gguf_dequantize -> embedding`` pattern to a quantized gather. + """Lower a Q4_K ``dequantize_gguf -> embedding`` pattern to a quantized gather. Gathers the packed quants / scales / biases by index, then dequantizes the gathered rows (MLX affine 4-bit) -- the same shape as MLX's generic quantized @@ -46,47 +38,18 @@ def emit_embedding( """ w_slot, scales_slot, biases_slot, group_size = _repack_mlx(P, weight_node) (indices_slot,) = P.slot_map([indices_node]) - ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(indices_slot)) - - _, wq_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(w_slot), - index=ids_index, - out=P.slot_to_tid(wq_sel), - axis=0, - ) - ) - _, sc_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(scales_slot), - index=ids_index, - out=P.slot_to_tid(sc_sel), - axis=0, - ) - ) - _, b_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(biases_slot), - index=ids_index, - out=P.slot_to_tid(b_sel), - axis=0, - ) - ) out = P.make_or_get_slot(head) - P.emit( - DequantizeNode( - w=P.slot_to_tid(wq_sel), - scales=P.slot_to_tid(sc_sel), - out=P.slot_to_tid(out), - biases=P.slot_to_tid(b_sel), - group_size=group_size, - bits=_BITS, - mode="affine", - dtype=torch_dtype_to_scalar_type(output_dtype), - ) + emit_quantized_gather( + P, + out, + indices_slot, + w_slot, + scales_slot, + biases_slot, + group_size=group_size, + bits=_BITS, + mode="affine", + out_dtype=output_dtype, ) return out diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py index 3db41cfc3d1..41d032a2d4a 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/linear.py @@ -8,7 +8,7 @@ """GGUF **Q4_K** linear lowering via MLX's native 4-bit quantized matmul. -Lowers a ``gguf_dequantize -> linear`` pattern to a ``QuantizedMatmulNode`` +Lowers a ``dequantize_gguf -> linear`` pattern to a ``QuantizedMatmulNode`` (mode "affine", group_size 32); the GGUF blob is repacked into MLX qparams at export time (see :mod:`.common`). """ @@ -20,10 +20,7 @@ from executorch.backends.mlx.builder.op_helpers import torch_dtype_to_scalar_type from executorch.backends.mlx.builder.program_builder import MLXProgramBuilder from executorch.backends.mlx.builder.slot_manager import Slot -from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import ( - _BITS, - _repack_mlx, -) +from executorch.backends.mlx.custom_kernel_ops.gguf.q4k.common import _BITS, _repack_mlx from executorch.backends.mlx.serialization.mlx_graph_schema import ( AddNode, AsTypeNode, @@ -39,7 +36,7 @@ def emit_linear( weight_node: Node, bias_node: Optional[Node], ) -> Slot: - """Lower a Q4_K ``gguf_dequantize -> linear`` pattern to MLX 4-bit matmul. + """Lower a Q4_K ``dequantize_gguf -> linear`` pattern to MLX 4-bit matmul. ``weight_node`` is the raw GGUF blob constant; ``head`` is the ``aten.linear`` node. The blob is repacked into MLX qparams at export time, so only the diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py b/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py index 0362a946cc7..2b809e3b16a 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py @@ -14,7 +14,7 @@ * :mod:`.linear` -- Q6_K mat-vec/mat-mat kernels + ``emit_linear`` lowering. * :mod:`.embedding` -- Q6_K gather kernel + ``emit_embedding`` lowering. -The pattern handlers that match ``torchao::gguf_dequantize -> linear/embedding`` +The pattern handlers that match ``torchao::dequantize_gguf -> linear/embedding`` and call these ``emit_*`` functions live one level up in ``custom_kernel_ops.gguf.patterns``. ``.linear`` / ``.embedding`` are intentionally NOT imported here so importing :mod:`.common` for the pure-torch diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py b/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py index 52e68aea427..64177392eb0 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py @@ -11,7 +11,7 @@ Provides the Q6_K embedding lowering used by the MLX GGUF pattern handler (:mod:`..patterns`): -* :func:`emit_embedding` -- lowers a ``gguf_dequantize -> embedding`` pattern to +* :func:`emit_embedding` -- lowers a ``dequantize_gguf -> embedding`` pattern to a fused Q6_K gather Metal kernel. This is the gather counterpart to :mod:`.linear` and exists because MLX's affine @@ -66,7 +66,7 @@ def emit_embedding( indices_node: Node, output_dtype: torch.dtype, ) -> Slot: - """Lower a Q6_K ``gguf_dequantize`` -> ``embedding`` pattern to a fused gather. + """Lower a Q6_K ``dequantize_gguf`` -> ``embedding`` pattern to a fused gather. ``weight_node`` is the raw GGUF blob (the dequantize op's weight input) and ``head`` is the ``aten.embedding`` node that owns the output slot. diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py index c8cf817638f..99a82053e90 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/linear.py @@ -12,7 +12,7 @@ (:mod:`..patterns`): * :func:`eager_linear` -- pure-torch reference (``x @ dequant(weight)^T``). -* :func:`emit_linear` -- lowers a ``gguf_dequantize -> linear`` pattern to fused +* :func:`emit_linear` -- lowers a ``dequantize_gguf -> linear`` pattern to fused Q6_K Metal kernels. Compute is keyed on the activation dtype (matching GGUF/llama.cpp): the Metal @@ -431,7 +431,7 @@ def emit_linear( weight_node: Node, bias_node: Optional[Node], ) -> Slot: - """Lower a Q6_K ``gguf_dequantize`` -> ``linear`` pattern to fused kernels. + """Lower a Q6_K ``dequantize_gguf`` -> ``linear`` pattern to fused kernels. ``weight_node`` is the raw GGUF blob (the dequantize op's weight input) and ``head`` is the ``aten.linear`` node that owns the output slot. diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py index a6586022a87..3f8e60b7aa8 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_embedding.py @@ -9,7 +9,7 @@ Tests for the GGUF Q6_K embedding lowering. An ``nn.Embedding`` whose weight is an ``ExportableGGUFTensor`` exports to -``embedding(torchao::gguf_dequantize(weight, "q6_k", ...), indices)``. The MLX +``embedding(torchao::dequantize_gguf(weight, "q6_k", ...), indices)``. The MLX ``GGUF_QUANTIZED_EMBEDDING`` pattern matches that subgraph and lowers it to the fused Q6_K gather Metal kernel. These tests compare the kernel against the eager reference (``gguf``-package dequant + ``F.embedding``) on the same packed table. @@ -82,7 +82,7 @@ def get_test_configs(cls) -> List["GGUFEmbeddingTest"]: def get_edge_compile_config(self): from executorch.exir import EdgeCompileConfig - # The gguf_dequantize custom op isn't a core ATen op; skip IR validity. + # The dequantize_gguf custom op isn't a core ATen op; skip IR validity. return EdgeCompileConfig(_check_ir_validity=False) def create_model(self) -> nn.Module: diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py index a1663958a09..daaaef1491c 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py @@ -9,7 +9,7 @@ Tests for the GGUF Q6_K linear lowering. A linear whose weight is an ``ExportableGGUFTensor`` (extension/llm/export/gguf) -exports to ``linear(x, torchao::gguf_dequantize(weight, "q6_k", ...), bias)``. +exports to ``linear(x, torchao::dequantize_gguf(weight, "q6_k", ...), bias)``. The MLX ``GGUF_QUANTIZED_LINEAR`` pattern (custom_kernel_ops/gguf/patterns.py) matches that subgraph and lowers it to the fused Q6_K Metal kernels (mat-vec for decode, mat-mat for prefill). These tests compare the fused kernels against the @@ -136,7 +136,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def _edge_compile_config(): from executorch.exir import EdgeCompileConfig - # The gguf_dequantize custom op isn't a core ATen op; skip IR validity. + # The dequantize_gguf custom op isn't a core ATen op; skip IR validity. return EdgeCompileConfig(_check_ir_validity=False) @@ -270,7 +270,7 @@ def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: def _eager_sanity() -> None: - """Quick CPU check: the subclass linear exports to gguf_dequantize.""" + """Quick CPU check: the subclass linear exports to dequantize_gguf.""" model = GGUFLinearModel(_make_gguf_linear_model(4, 512, torch.bfloat16, bias=True)) x = torch.randn(3, 512, dtype=torch.bfloat16) out = model(x) @@ -279,8 +279,8 @@ def _eager_sanity() -> None: ) ep = torch.export.export(model, (x,)).run_decompositions({}) targets = {str(n.target) for n in ep.graph.nodes if n.op == "call_function"} - assert "torchao.gguf_dequantize.default" in targets, targets - print("export contains torchao.gguf_dequantize: OK") + assert "torchao.dequantize_gguf.default" in targets, targets + print("export contains torchao.dequantize_gguf: OK") if __name__ == "__main__": # noqa: C901 diff --git a/backends/mlx/patterns.py b/backends/mlx/patterns.py index 5f74cbea643..dcc4f4d7d30 100644 --- a/backends/mlx/patterns.py +++ b/backends/mlx/patterns.py @@ -21,7 +21,9 @@ import torch from executorch.backends.mlx.builder.op_helpers import ( emit_quantized_biases, + emit_quantized_gather, emit_stop_position, + parse_dequant_int4_node, parse_dequant_node, parse_dequant_nvfp4_node, to_mlx_qparams, @@ -44,7 +46,6 @@ DequantizeNode, IndexCopyNode, IntOrVid, - IntOrVidOrTid, ModIntNode, MultiplyNode, QuantizedMatmulNode, @@ -53,13 +54,40 @@ SliceUpdateNode, SubtractIntNode, SymSizeNode, - TakeNode, TransposeNode, ) from torch.export.exported_program import ExportedProgram from torch.fx.node import Node +def _unpack_int4_to_intx_fields( + qdata_packed: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Convert ``Int4Tensor`` packed fields to the IntxUnpacked layout for + :func:`to_mlx_qparams`. + + Input is the torchao ``Int4Tensor`` layout: ``qdata_packed`` ``(N, K//2)`` uint8 + (two nibbles/byte, even index -> low nibble, unsigned [0, 15]) and ``scale`` / + ``zero_point`` ``(K // gs, N)`` (zero_point unsigned [0, 15]). + + Returns ``(qdata, scale, zero_point)`` where ``qdata`` is ``(N, K)`` int8 in + [-8, 7], and ``scale`` / ``zero_point`` are ``(N, K // gs)`` (zero_point + centered by -8). ``zero_point`` keeps its original (possibly fractional, e.g. + HQQ) dtype -- it is only used in :func:`to_mlx_qparams`'s float bias math, so + it must not be truncated to int. The affine identity ``scale * (q - z)`` is + preserved. + """ + p = qdata_packed.view(torch.uint8) + low = (p & 0x0F).to(torch.int8) + high = ((p >> 4) & 0x0F).to(torch.int8) + q = torch.stack([low, high], dim=-1).reshape(p.shape[0], -1) - 8 + scale_nk = scale.t().contiguous() + zero_point_nk = zero_point.t().contiguous() - 8 + return q, scale_nk, zero_point_nk + + @REGISTRY.register_pattern(name="INDEX_COPY") class IndexCopyHandler(PatternHandler): """ @@ -600,43 +628,18 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: [x_node, self.scale, self.per_tensor_scale, self.qdata] ) - ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(x)) - - # Gather quantized weights by indices - _, wq_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(qdata_slot), - index=ids_index, - out=P.slot_to_tid(wq_sel), - axis=0, - ) - ) - - # Gather scales by indices - _, sc_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(scales_slot), - index=ids_index, - out=P.slot_to_tid(sc_sel), - axis=0, - ) - ) - - # Dequantize the gathered slices out = P.make_or_get_slot(n) - P.emit( - DequantizeNode( - w=P.slot_to_tid(wq_sel), - scales=P.slot_to_tid(sc_sel), - out=P.slot_to_tid(out), - biases=None, - group_size=16, - bits=4, - mode="nvfp4", - dtype=torch_dtype_to_scalar_type(self.output_dtype), - ) + emit_quantized_gather( + P, + out, + x, + qdata_slot, + scales_slot, + None, + group_size=16, + bits=4, + mode="nvfp4", + out_dtype=self.output_dtype, ) if has_per_tensor_scale: @@ -1060,7 +1063,7 @@ def maybe_create( def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: assert n == self.head - w, x = n.args[0:2] + indices_node = n.args[1] qdata_target, qdata = P.get_placeholder_target_and_tensor(self.qdata) zero_point_target, zero_point = P.get_placeholder_target_and_tensor( @@ -1069,62 +1072,25 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: _, scale = P.get_placeholder_target_and_tensor(self.scale) Q, B = to_mlx_qparams(qdata, scale, zero_point, self.bits) - out_scalar_type = torch_dtype_to_scalar_type(self.out_dtype) - w = P.make_or_get_constant(f"{qdata_target}_to_packed", Q) - x, scale_slot = P.slot_map([x, self.scale]) + indices_slot, scale_slot = P.slot_map([indices_node, self.scale]) biases = emit_quantized_biases( P, zero_point_target, scale, zero_point, self.bits, B, scale_slot ) - ids_index = IntOrVidOrTid.from_tid(P.slot_to_tid(x)) - - # Gather quantized weights by ids - _, wq_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(w), - index=ids_index, - out=P.slot_to_tid(wq_sel), - axis=0, - ) - ) - - # Gather scales by ids - _, sc_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(scale_slot), - index=ids_index, - out=P.slot_to_tid(sc_sel), - axis=0, - ) - ) - - # Gather biases by ids - _, b_sel = P.make_tmp_slot() - P.emit( - TakeNode( - x=P.slot_to_tid(biases), - index=ids_index, - out=P.slot_to_tid(b_sel), - axis=0, - ) - ) - # Dequantize the gathered slices out = P.make_or_get_slot(n) - P.emit( - DequantizeNode( - w=P.slot_to_tid(wq_sel), - scales=P.slot_to_tid(sc_sel), - out=P.slot_to_tid(out), - biases=P.slot_to_tid(b_sel), - group_size=self.group_size, - bits=self.bits, - mode="affine", - dtype=out_scalar_type, - ) + emit_quantized_gather( + P, + out, + indices_slot, + w, + scale_slot, + biases, + group_size=self.group_size, + bits=self.bits, + mode="affine", + out_dtype=self.out_dtype, ) return out @@ -1228,3 +1194,174 @@ def __call__(self, P, n): ) return out + + +@REGISTRY.register_pattern(name="INT4_QUANTIZED_LINEAR") +class Int4QuantizedLinearHandler(PatternHandler): + """Fuse dequantize_int4_tensor + linear into QuantizedMatmulNode(mode="affine"). + + Matches:: + + linear(x, dequantize_int4_tensor(qdata, scale, zero_point, group_size), bias) + + The nibble-packed Int4 weight is unpacked and repacked into MLX 4-bit qparams + at export time. + """ + + def __init__(self, head, body, qdata, scale, zero_point, group_size, out_dtype): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + self.out_dtype = out_dtype + + _MIN_FUSED_GROUP_SIZE = 32 + + @staticmethod + def _allow_non_fused() -> bool: + return os.environ.get("ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS", "0") == "1" + + @classmethod + def maybe_create(cls, ep, head): + if not match_target(head, torch.ops.aten.linear.default): + return None + if len(head.args) < 2 or not isinstance(head.args[1], Node): + return None + dequant = head.args[1] + if not has_single_user(dequant): + return None + parsed = parse_dequant_int4_node(dequant) + if parsed is None: + return None + qdata, scale, zero_point, group_size, out_dtype = parsed + return cls(head, [dequant], qdata, scale, zero_point, group_size, out_dtype) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + x_node = n.args[0] + b_node = n.args[2] if len(n.args) > 2 else None + + qdata_target, qdata_packed = P.get_placeholder_target_and_tensor(self.qdata) + zp_target, zero_point = P.get_placeholder_target_and_tensor(self.zero_point) + _, scale = P.get_placeholder_target_and_tensor(self.scale) + + q, scale_nk, zp = _unpack_int4_to_intx_fields(qdata_packed, scale, zero_point) + Q, B = to_mlx_qparams(q, scale_nk, zp, 4) + + w = P.make_or_get_constant(f"{qdata_target}_int4_to_packed", Q) + scale_slot = P.make_or_get_constant(f"{qdata_target}_int4_scales", scale_nk) + biases = emit_quantized_biases(P, zp_target, scale_nk, zp, 4, B, scale_slot) + + x_slot, b_slot = P.slot_map([x_node, b_node]) + out_dtype = ( + x_node.meta["val"].dtype if self.out_dtype is None else self.out_dtype + ) + needs_cast = out_dtype != x_node.meta["val"].dtype + + if self.group_size < self._MIN_FUSED_GROUP_SIZE and not self._allow_non_fused(): + raise ValueError( + f"Int4 quantized linear with group_size={self.group_size} requires " + f"the non-fused path; set ET_MLX_ALLOW_NON_FUSED_QUANTIZED_OPS=1." + ) + + out = P.make_or_get_slot(n) + P.emit( + QuantizedMatmulNode( + x=P.slot_to_tid(x_slot), + w=P.slot_to_tid(w), + scales=P.slot_to_tid(scale_slot), + biases=P.slot_to_tid(biases), + out=P.slot_to_tid(out), + group_size=self.group_size, + bits=4, + mode="affine", + transpose=True, + ) + ) + + if b_node is not None: + P.emit( + AddNode( + a=P.slot_to_tid(out), + b=P.slot_to_tid(b_slot), + out=P.slot_to_tid(out), + ) + ) + + if needs_cast: + P.emit( + AsTypeNode( + x=P.slot_to_tid(out), + out=P.slot_to_tid(out), + scalar_type=torch_dtype_to_scalar_type(out_dtype), + ) + ) + + return out + + +@REGISTRY.register_pattern(name="INT4_QUANTIZED_EMBEDDING") +class Int4QuantizedEmbeddingHandler(PatternHandler): + """Fuse dequantize_int4_tensor + embedding into gather + DequantizeNode(affine). + + Matches:: + + embedding(dequantize_int4_tensor(qdata, scale, zero_point, group_size), ids) + """ + + def __init__(self, head, body, qdata, scale, zero_point, group_size, out_dtype): + super().__init__(head, body) + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + self.out_dtype = out_dtype + + @classmethod + def maybe_create(cls, ep, head): + if not match_target(head, torch.ops.aten.embedding.default): + return None + if len(head.args) < 2 or not isinstance(head.args[0], Node): + return None + dequant = head.args[0] + if not has_single_user(dequant): + return None + parsed = parse_dequant_int4_node(dequant) + if parsed is None: + return None + qdata, scale, zero_point, group_size, out_dtype = parsed + return cls(head, [dequant], qdata, scale, zero_point, group_size, out_dtype) + + def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: + assert n == self.head + indices_node = n.args[1] + + qdata_target, qdata_packed = P.get_placeholder_target_and_tensor(self.qdata) + zp_target, zero_point = P.get_placeholder_target_and_tensor(self.zero_point) + _, scale = P.get_placeholder_target_and_tensor(self.scale) + + q, scale_nk, zp = _unpack_int4_to_intx_fields(qdata_packed, scale, zero_point) + Q, B = to_mlx_qparams(q, scale_nk, zp, 4) + + w = P.make_or_get_constant(f"{qdata_target}_int4_to_packed", Q) + scale_slot = P.make_or_get_constant(f"{qdata_target}_int4_scales", scale_nk) + biases = emit_quantized_biases(P, zp_target, scale_nk, zp, 4, B, scale_slot) + + (indices_slot,) = P.slot_map([indices_node]) + out_dtype = scale.dtype if self.out_dtype is None else self.out_dtype + + out = P.make_or_get_slot(n) + emit_quantized_gather( + P, + out, + indices_slot, + w, + scale_slot, + biases, + group_size=self.group_size, + bits=4, + mode="affine", + out_dtype=out_dtype, + ) + return out diff --git a/backends/mlx/test/test_ops.py b/backends/mlx/test/test_ops.py index 9d07af84268..6ba17cccda7 100644 --- a/backends/mlx/test/test_ops.py +++ b/backends/mlx/test/test_ops.py @@ -7402,3 +7402,158 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: self.batch_size, self.seq_len, self.in_features, dtype=self.dtype ) return (x,) + + +def _make_int4_quantized_weight(weight: torch.Tensor, group_size: int) -> torch.Tensor: + """Groupwise affine 4-bit quantize a ``(N, K)`` weight into an + ``ExportableInt4Tensor`` (torchao ``Int4Tensor`` packed layout).""" + from executorch.extension.llm.export.int4 import ExportableInt4Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + N, K = weight.shape + dtype = weight.dtype + w = weight.float().reshape(N, K // group_size, group_size) + wmin = w.amin(dim=-1) + wmax = w.amax(dim=-1) + scale = ((wmax - wmin) / 15.0).clamp(min=1e-8) + # Fractional zero-point (HQQ-style), exercises the float zero_point repack path. + zero = (-wmin / scale).clamp(0, 15) + q = torch.round(w / scale.unsqueeze(-1) + zero.unsqueeze(-1)).clamp(0, 15) + q = q.reshape(N, K).to(torch.uint8) + # Two nibbles/byte: even index -> low nibble. + packed = (q[:, 0::2] | (q[:, 1::2] << 4)).to(torch.uint8) + it = Int4Tensor( + qdata=packed, + scale=scale.t().contiguous().to(dtype), + zero_point=zero.t().contiguous().to(dtype), + block_size=[1, group_size], + shape=torch.Size([N, K]), + ) + return ExportableInt4Tensor.from_int4_tensor(it) + + +class Int4QuantizedLinearModel(nn.Module): + """Linear layer whose weight is an ``ExportableInt4Tensor``.""" + + def __init__(self, in_features: int, out_features: int, bias: bool = True): + super().__init__() + self.linear = nn.Linear(in_features, out_features, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear(x) + + +@register_test +class Int4QuantizedLinearTest(OpTestCase): + """ExportableInt4Tensor nn.Linear -> MLX 4-bit affine quantized matmul.""" + + name = "int4_quantized_linear" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + in_features: int = 64, + out_features: int = 128, + batch_size: int = 2, + seq_len: int = 16, + bias: bool = True, + group_size: int = 32, + dtype: torch.dtype = torch.bfloat16, + ): + self.in_features = in_features + self.out_features = out_features + self.batch_size = batch_size + self.seq_len = seq_len + self.bias = bias + self.group_size = group_size + self.dtype = dtype + + parts = ["int4_quantized_linear", f"g{group_size}"] + if not bias: + parts.append("no_bias") + if dtype != torch.bfloat16: + parts.append(str(dtype).split(".")[-1]) + self.name = "_".join(parts) + + @classmethod + def get_test_configs(cls) -> List["Int4QuantizedLinearTest"]: + return [ + cls(), + cls(bias=False), + cls(group_size=64), + cls(dtype=torch.float32), + ] + + def get_edge_compile_config(self): + from executorch.exir import EdgeCompileConfig + + return EdgeCompileConfig(_check_ir_validity=False) + + def create_model(self) -> nn.Module: + model = Int4QuantizedLinearModel( + self.in_features, self.out_features, bias=self.bias + ).to(self.dtype) + model.linear.weight = nn.Parameter( + _make_int4_quantized_weight(model.linear.weight.data, self.group_size), + requires_grad=False, + ) + return model + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randn( + self.batch_size, self.seq_len, self.in_features, dtype=self.dtype + ) + return (x,) + + +@register_test +class Int4QuantizedEmbeddingTest(OpTestCase): + """ExportableInt4Tensor nn.Embedding -> MLX 4-bit affine quantized gather.""" + + name = "int4_quantized_embedding" + rtol = 0.1 + atol = 0.1 + + def __init__( + self, + num_embeddings: int = 1000, + embedding_dim: int = 128, + batch_size: int = 2, + seq_len: int = 16, + group_size: int = 32, + dtype: torch.dtype = torch.bfloat16, + ): + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.batch_size = batch_size + self.seq_len = seq_len + self.group_size = group_size + self.dtype = dtype + self.name = f"int4_quantized_embedding_g{group_size}" + + @classmethod + def get_test_configs(cls) -> List["Int4QuantizedEmbeddingTest"]: + return [ + cls(), + cls(group_size=64), + cls(group_size=128), + ] + + def get_edge_compile_config(self): + from executorch.exir import EdgeCompileConfig + + return EdgeCompileConfig(_check_ir_validity=False) + + def create_model(self) -> nn.Module: + model = EmbeddingModel(self.num_embeddings, self.embedding_dim) + model = model.to(self.dtype) + model.embedding.weight = nn.Parameter( + _make_int4_quantized_weight(model.embedding.weight.data, self.group_size), + requires_grad=False, + ) + return model + + def create_inputs(self) -> Tuple[torch.Tensor, ...]: + x = torch.randint(0, self.num_embeddings, (self.batch_size, self.seq_len)) + return (x,) diff --git a/examples/models/gemma4_31b/export.py b/examples/models/gemma4_31b/export.py index bcdd49a9a34..64e55319490 100644 --- a/examples/models/gemma4_31b/export.py +++ b/examples/models/gemma4_31b/export.py @@ -310,6 +310,7 @@ def _export_mlx( # weights lower to the fused Q6_K kernels / Q4_K quantized matmul. import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 import executorch.extension.llm.export.gguf # noqa: F401 + import executorch.extension.llm.export.int4 # noqa: F401 from executorch.backends.mlx import MLXPartitioner from executorch.backends.mlx.passes import get_default_passes diff --git a/examples/models/gemma4_31b/quant/pack_mlx.py b/examples/models/gemma4_31b/quant/pack_mlx.py index d627c9c437c..98f5470f184 100644 --- a/examples/models/gemma4_31b/quant/pack_mlx.py +++ b/examples/models/gemma4_31b/quant/pack_mlx.py @@ -6,11 +6,11 @@ """MLX packer: convert quantized weights to MLX-compatible format. -MLX's ``QuantizedLinearHandler`` matches ``dequantize_affine → linear`` -in the exported graph. ``IntxUnpackedToInt8Tensor`` produces this -pattern naturally, but ``Int4Tensor`` does not (its dispatch calls -CUDA-specific mslk kernels). So INT4 weights are converted to -``IntxUnpackedToInt8Tensor(target_dtype=torch.int4)`` at pack time. +``Int4Tensor`` weights are wrapped as ``ExportableInt4Tensor`` so they export to +``dequantize_int4_tensor -> linear/embedding`` (matched by MLX's Int4 handlers). +``IntxUnpackedToInt8Tensor`` (e.g. int8 / Q6_K) already exports to +``dequantize_affine -> linear`` and is assigned directly, regrouped to an +MLX-compatible group size when needed. The backend-agnostic ``pack_model`` dispatcher lives in ``pack.py``. """ @@ -25,45 +25,6 @@ _MLX_SUPPORTED_GROUP_SIZES = (128, 64, 32, 16) -# --------------------------------------------------------------------------- -# Int4Tensor → IntxUnpackedToInt8Tensor conversion - - -def _int4_to_intx_unpacked(w: torch.Tensor) -> torch.Tensor: - """Convert an ``Int4Tensor`` to ``IntxUnpackedToInt8Tensor``. - - Int4Tensor stores qdata as nibble-packed uint8 ``(N, K/2)`` with - scale/zero transposed to ``(K//gs, N)``. IntxUnpackedToInt8Tensor - stores qdata as int8 ``(N, K)`` with scale/zero as ``(N, K//gs)``. - """ - from torchao.quantization import IntxUnpackedToInt8Tensor - - # Unpack nibbles: packed = even | (odd << 4), unsigned [0, 15] - p = w.qdata.to(torch.uint8) - low = (p & 0x0F).to(torch.int8) - high = ((p >> 4) & 0x0F).to(torch.int8) - qdata = torch.stack([low, high], dim=-1).reshape(w.shape) - - # Shift unsigned [0, 15] → signed [-8, 7] - qdata = qdata - 8 - - gs = w.block_size[-1] - - # Transpose scale/zero from (K//gs, N) → (N, K//gs) - scale = w.scale.t().contiguous() - zero_point = (w.zero_point - 8).t().contiguous() - - return IntxUnpackedToInt8Tensor( - qdata=qdata, - scale=scale, - zero_point=zero_point, - target_dtype=torch.int4, - block_size=(1, gs), - dtype=scale.dtype, - activation_quantization=None, - ) - - # --------------------------------------------------------------------------- # Embedding group_size regrouping @@ -130,13 +91,16 @@ def pack_for_mlx(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: Group sizes ≥ 32 use the fused ``QuantizedMatmulNode``; group_size=16 (e.g. GGUF Q6_K) falls back to ``DequantizeNode`` + matmul at export. """ + from executorch.extension.llm.export.int4 import ExportableInt4Tensor from torchao.quantization import IntxUnpackedToInt8Tensor from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor w = weights["weight"] if isinstance(w, Int4Tensor): - w = _int4_to_intx_unpacked(w) - if isinstance(w, IntxUnpackedToInt8Tensor): + # Int4 group is MLX-native (32); wrap so it exports to + # dequantize_int4_tensor -> linear/embedding. + w = ExportableInt4Tensor.from_int4_tensor(w) + elif isinstance(w, IntxUnpackedToInt8Tensor): gs = w.block_size[-1] K = w.qdata.shape[-1] target_gs = _mlx_group_size(gs, K) diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index c626904b39f..2753c62bb8c 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -364,14 +364,14 @@ def test_q6k_linear_delegates(self): self._assert_delegated( self._linear(256, 512, "q6_k"), (torch.randn(4, 512, dtype=torch.bfloat16),), - ("gguf_dequantize", "linear"), + ("dequantize_gguf", "linear"), ) def test_q4k_linear_delegates(self): self._assert_delegated( self._linear(512, 512, "q4_k"), (torch.randn(4, 512, dtype=torch.bfloat16),), - ("gguf_dequantize", "linear"), + ("dequantize_gguf", "linear"), ) @@ -413,7 +413,7 @@ def _assert_delegated(self, ggml_type: str): str(n.target) for n in et.exported_program().graph.nodes if n.op == "call_function" - and any(t in str(n.target) for t in ("gguf_dequantize", "embedding")) + and any(t in str(n.target) for t in ("dequantize_gguf", "embedding")) ] self.assertEqual(remaining, [], f"not delegated to MLX: {remaining}") @@ -424,5 +424,66 @@ def test_q4k_embedding_delegates(self): self._assert_delegated("q4_k") +class TestInt4Mlx(unittest.TestCase): + """ExportableInt4Tensor linear + embedding lower through the MLX Int4 pattern.""" + + def _make_int4(self, N, K, gs=32, seed=0): + from executorch.extension.llm.export.int4 import ExportableInt4Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + g = torch.Generator().manual_seed(seed) + q = torch.randint(0, 16, (N, K), generator=g, dtype=torch.int32) + packed = (q[:, 0::2] | (q[:, 1::2] << 4)).to(torch.uint8) + scale = (torch.randn(K // gs, N, generator=g) * 0.1).to(torch.bfloat16) + zero = torch.randint(0, 16, (K // gs, N), generator=g).to(torch.bfloat16) + it = Int4Tensor( + qdata=packed, + scale=scale, + zero_point=zero, + block_size=[1, gs], + shape=torch.Size([N, K]), + ) + return ExportableInt4Tensor.from_int4_tensor(it) + + def _assert_delegated(self, model, example, leftovers): + import executorch.backends.mlx.patterns # noqa: F401 + from executorch.backends.mlx import MLXPartitioner + from executorch.exir import EdgeCompileConfig, to_edge_transform_and_lower + from torch.export import Dim, export + + seq = Dim("seq", min=1, max=8) + ep = export(model, example, dynamic_shapes=({0: seq},), strict=True) + et = to_edge_transform_and_lower( + ep, + partitioner=[MLXPartitioner()], + compile_config=EdgeCompileConfig(_check_ir_validity=False), + ) + remaining = [ + str(n.target) + for n in et.exported_program().graph.nodes + if n.op == "call_function" and any(t in str(n.target) for t in leftovers) + ] + self.assertEqual(remaining, [], f"not delegated to MLX: {remaining}") + + def test_int4_linear_delegates(self): + lin = nn.Linear(512, 256, bias=False).to(torch.bfloat16) + lin.weight = nn.Parameter(self._make_int4(256, 512), requires_grad=False) + self._assert_delegated( + lin.eval(), + (torch.randn(4, 512, dtype=torch.bfloat16),), + ("dequantize_int4_tensor", "linear"), + ) + + def test_int4_embedding_delegates(self): + vocab, K = 512, 256 + emb = nn.Embedding(vocab, K) + emb.weight = nn.Parameter(self._make_int4(vocab, K), requires_grad=False) + self._assert_delegated( + emb.eval(), + (torch.randint(0, vocab, (4,), dtype=torch.int64),), + ("dequantize_int4_tensor", "embedding"), + ) + + if __name__ == "__main__": unittest.main() diff --git a/extension/llm/export/gguf.py b/extension/llm/export/gguf.py index 9a0b34cdd6c..fe0e78c51fa 100644 --- a/extension/llm/export/gguf.py +++ b/extension/llm/export/gguf.py @@ -15,16 +15,16 @@ (quantized tensors become ``ExportableGGUFTensor``; F32/F16 become plain tensors). No unpacking happens at load. 2. **Lower (dequantize)**: used as a weight, the subclass dequantizes via the - ``torchao::gguf_dequantize`` custom op (gguf-package eager body) and runs the + ``torchao::dequantize_gguf`` custom op (gguf-package eager body) and runs the plain torch ``linear`` / ``embedding`` (NVFP4-style). A backend can - pattern-match ``gguf_dequantize`` -> linear/embedding to fuse. + pattern-match ``dequantize_gguf`` -> linear/embedding to fuse. 3. **Convert**: ``.to_int4_tensor()`` / ``.to_intx_unpacked_to_int8_tensor()`` unpack into torchao tensor subclasses (``Int4Tensor`` for Q4_K, ``IntxUnpackedToInt8Tensor`` for Q4_K or Q6_K) to take the non-fused (affine-dequant) path instead. The GGUF quant type is identified by a **string** (``"q4_k"``, ``"q6_k"``) -everywhere user-facing (subclass attribute + ``gguf_dequantize`` op argument); the +everywhere user-facing (subclass attribute + ``dequantize_gguf`` op argument); the ``gguf`` package's integer ``GGMLQuantizationType`` ids are an internal lookup detail. @@ -76,7 +76,7 @@ def _read_f16(raw: Tensor, col_start: int, col_end: int) -> Tensor: return raw[:, col_start:col_end].contiguous().view(torch.float16).float() -def _gguf_dequantize(raw: Tensor, ggml_type: str, output_dtype: torch.dtype) -> Tensor: +def _dequantize_gguf(raw: Tensor, ggml_type: str, output_dtype: torch.dtype) -> Tensor: """Dequantize a raw GGUF block blob to a float tensor via the ``gguf`` package. ``raw`` is ``(N, row_bytes)`` uint8; the result is ``(N, K)`` in @@ -99,17 +99,17 @@ def _gguf_dequantize(raw: Tensor, ggml_type: str, output_dtype: torch.dtype) -> # --------------------------------------------------------------------------- -@torch.library.custom_op("torchao::gguf_dequantize", mutates_args=()) -def gguf_dequantize( +@torch.library.custom_op("torchao::dequantize_gguf", mutates_args=()) +def dequantize_gguf( weight: Tensor, ggml_type: str, output_dtype: torch.dtype = torch.bfloat16, ) -> Tensor: """Dequantize a raw GGUF block blob (``(N, row_bytes)`` uint8) to ``(N, K)``.""" - return _gguf_dequantize(weight, ggml_type, output_dtype) + return _dequantize_gguf(weight, ggml_type, output_dtype) -@gguf_dequantize.register_fake +@dequantize_gguf.register_fake def _(weight, ggml_type, output_dtype=torch.bfloat16): K = (weight.shape[1] // _BLOCK_BYTES_BY_TYPE[ggml_type]) * QK_K return torch.empty((weight.shape[0], K), dtype=output_dtype, device=weight.device) @@ -205,7 +205,7 @@ class ExportableGGUFTensor(TorchAOBaseTensor): Stores the exact GGUF ``block_q*_K`` byte layout (no repacking) plus the quant type string (``"q4_k"`` / ``"q6_k"``). ``aten.linear`` / ``aten.embedding`` - dequantize via the ``torchao::gguf_dequantize`` op (then a plain + dequantize via the ``torchao::dequantize_gguf`` op (then a plain linear/embedding); :meth:`to_int4_tensor` / :meth:`to_intx_unpacked_to_int8_tensor` convert to torchao subclasses instead. """ @@ -256,7 +256,7 @@ def from_raw( def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> Tensor: """Dequantize to a plain float tensor using the ``gguf`` package.""" - return torch.ops.torchao.gguf_dequantize( + return torch.ops.torchao.dequantize_gguf( self.raw, self.ggml_type, output_dtype or self.orig_dtype ) diff --git a/extension/llm/export/int4.py b/extension/llm/export/int4.py new file mode 100644 index 00000000000..59251ae1875 --- /dev/null +++ b/extension/llm/export/int4.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Int4 export-compatible quantization. + +Wraps a torchao ``Int4Tensor`` (nibble-packed 4-bit groupwise weight) so it +survives ``torch.export`` / ``run_decompositions``: a ``torchao::dequantize_int4_tensor`` +custom op carries the dequant, and ``aten.linear`` / ``aten.embedding`` desugar to +``dequantize_int4_tensor -> linear/embedding`` (mirroring ``dequantize_nvfp4`` / +``dequantize_gguf``). A backend may pattern-match the op to a low-bit kernel; the +eager body is a plain affine dequant so the representation is portable. + +The tensor stores the ``Int4Tensor`` layout verbatim: + * ``qdata`` ``(N, K // 2)`` uint8, two nibbles/byte (even index -> low nibble), + unsigned values in [0, 15]. + * ``scale`` ``(K // group_size, N)``. + * ``zero_point`` ``(K // group_size, N)``, unsigned values in [0, 15]. +Dequant is ``scale * (q - zero_point)`` per group. +""" + +import torch +from torch import Tensor +from torchao.utils import TorchAOBaseTensor + +aten = torch.ops.aten + + +def _dequantize_int4( + qdata: Tensor, + scale: Tensor, + zero_point: Tensor, + group_size: int, + output_dtype: torch.dtype, +) -> Tensor: + """Eager affine dequant of an ``Int4Tensor``-layout weight to ``(N, K)``.""" + p = qdata.view(torch.uint8) + low = (p & 0x0F).to(torch.int32) + high = ((p >> 4) & 0x0F).to(torch.int32) + # Two nibbles/byte: even index -> low, odd -> high. + q = torch.stack([low, high], dim=-1).reshape(p.shape[0], -1).to(torch.float32) + + # scale / zero_point are (K // gs, N) -> transpose to (N, K // gs) and expand. + s = scale.t().to(torch.float32).repeat_interleave(group_size, dim=-1) + z = zero_point.t().to(torch.float32).repeat_interleave(group_size, dim=-1) + return ((q - z) * s).to(output_dtype) + + +@torch.library.custom_op("torchao::dequantize_int4_tensor", mutates_args=()) +def dequantize_int4_tensor( + qdata: Tensor, + scale: Tensor, + zero_point: Tensor, + group_size: int, + output_dtype: torch.dtype = torch.bfloat16, +) -> Tensor: + """Dequantize a nibble-packed Int4 weight (``(N, K//2)`` uint8) to ``(N, K)``.""" + return _dequantize_int4(qdata, scale, zero_point, group_size, output_dtype) + + +@dequantize_int4_tensor.register_fake +def _(qdata, scale, zero_point, group_size, output_dtype=torch.bfloat16): + K = qdata.shape[1] * 2 # two 4-bit values per byte + return torch.empty(qdata.shape[0], K, dtype=output_dtype, device=qdata.device) + + +class ExportableInt4Tensor(TorchAOBaseTensor): + """Int4 tensor subclass that dequantizes via a registered custom op.""" + + tensor_data_names = ["qdata", "scale", "zero_point"] + tensor_attribute_names = ["group_size", "orig_dtype"] + + def __new__(cls, qdata, scale, zero_point, group_size, orig_dtype): + K = qdata.shape[-1] * 2 # two 4-bit values per byte + shape = (qdata.shape[0], K) + self = torch.Tensor._make_wrapper_subclass( + cls, shape, dtype=orig_dtype, device=qdata.device, requires_grad=False + ) + self.qdata = qdata + self.scale = scale + self.zero_point = zero_point + self.group_size = group_size + self.orig_dtype = orig_dtype + return self + + @classmethod + def from_int4_tensor(cls, w: Tensor) -> "ExportableInt4Tensor": + """Build from a torchao ``Int4Tensor`` (copies its packed fields).""" + return cls( + w.qdata, + w.scale, + w.zero_point, + int(w.block_size[-1]), + w.dtype, + ) + + def dequantize(self, output_dtype=None): + return torch.ops.torchao.dequantize_int4_tensor( + self.qdata, + self.scale, + self.zero_point, + self.group_size, + output_dtype=output_dtype or self.orig_dtype, + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + +implements = ExportableInt4Tensor.implements + + +@implements([aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight = args[0], args[1] + bias = args[2] if len(args) > 2 else None + return torch.nn.functional.linear( + input_tensor, weight.dequantize(input_tensor.dtype), bias + ) + + +@implements([aten.embedding.default]) +def _(func, types, args, kwargs): + weight, indices = args[0], args[1] + return torch.nn.functional.embedding(indices, weight.dequantize()) + + +@implements([aten.t.default]) +def _(func, types, args, kwargs): + return args[0].dequantize().t() + + +@implements([aten.detach.default, aten.alias.default]) +def _(func, types, args, kwargs): + return args[0] + + +@implements([aten._to_copy.default]) +def _(func, types, args, kwargs): + return args[0].dequantize(output_dtype=kwargs.get("dtype", args[0].orig_dtype)) diff --git a/extension/llm/export/test/test_gguf.py b/extension/llm/export/test/test_gguf.py index dff987f0749..f6cd52a69ad 100644 --- a/extension/llm/export/test/test_gguf.py +++ b/extension/llm/export/test/test_gguf.py @@ -142,11 +142,11 @@ def test_to_int4_tensor_matches_reference(self): ) ) - def test_gguf_dequantize_op_matches_reference(self): + def test_dequantize_gguf_op_matches_reference(self): for ggml_type, make in (("q4_k", _make_q4k_raw), ("q6_k", _make_q6k_raw)): raw = make(N=3, nb=2) t = ExportableGGUFTensor.from_raw(raw, ggml_type) - out = torch.ops.torchao.gguf_dequantize(raw, ggml_type, torch.float32) + out = torch.ops.torchao.dequantize_gguf(raw, ggml_type, torch.float32) self.assertTrue(torch.equal(out, t.dequantize(torch.float32))) def test_subclass_linear_dispatches_to_dequant(self): @@ -174,14 +174,14 @@ def test_unsupported_type_raises(self): @unittest.skipUnless(_HAS_GGUF, "gguf package not installed") class TestExportableGGUFTensorExport(unittest.TestCase): """Exporting a module whose weight is an ``ExportableGGUFTensor`` should - lower linear/embedding through the ``torchao::gguf_dequantize`` op after + lower linear/embedding through the ``torchao::dequantize_gguf`` op after ``run_decompositions`` (the subclass dispatch fires during decomposition).""" @staticmethod def _targets(ep): return {str(n.target) for n in ep.graph.nodes if n.op == "call_function"} - def test_linear_exports_with_gguf_dequantize(self): + def test_linear_exports_with_dequantize_gguf(self): t = ExportableGGUFTensor.from_raw(_make_q6k_raw(N=4, nb=1), "q6_k") class M(torch.nn.Module): @@ -195,9 +195,9 @@ def forward(self, x): ep = torch.export.export( M(), (torch.randn(2, 256, dtype=torch.bfloat16),) ).run_decompositions({}) - self.assertIn("torchao.gguf_dequantize.default", self._targets(ep)) + self.assertIn("torchao.dequantize_gguf.default", self._targets(ep)) - def test_embedding_exports_with_gguf_dequantize(self): + def test_embedding_exports_with_dequantize_gguf(self): t = ExportableGGUFTensor.from_raw(_make_q6k_raw(N=8, nb=1), "q6_k") class M(torch.nn.Module): @@ -211,7 +211,7 @@ def forward(self, idx): ep = torch.export.export(M(), (torch.tensor([0, 1, 2, 3]),)).run_decompositions( {} ) - self.assertIn("torchao.gguf_dequantize.default", self._targets(ep)) + self.assertIn("torchao.dequantize_gguf.default", self._targets(ep)) if __name__ == "__main__": diff --git a/extension/llm/export/test/test_int4.py b/extension/llm/export/test/test_int4.py new file mode 100644 index 00000000000..9414248d59a --- /dev/null +++ b/extension/llm/export/test/test_int4.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Tests for ExportableInt4Tensor + the torchao::dequantize_int4_tensor op.""" + +import unittest + +import torch +from executorch.extension.llm.export.int4 import ExportableInt4Tensor + + +def _make_int4_tensor(N: int, K: int, gs: int, seed: int = 0): + """Build a synthetic ``Int4Tensor`` plus the (q, scale, zero_point) it encodes. + + Returns ``(int4_tensor, q_unsigned (N,K), scale (K//gs,N), zero (K//gs,N))``. + """ + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + g = torch.Generator().manual_seed(seed) + q = torch.randint(0, 16, (N, K), generator=g, dtype=torch.int32) # unsigned [0,15] + # Pack two nibbles/byte: even index -> low, odd -> high. + packed = (q[:, 0::2] | (q[:, 1::2] << 4)).to(torch.uint8) + scale = (torch.randn(K // gs, N, generator=g) * 0.1).to(torch.bfloat16) + zero = torch.randint(0, 16, (K // gs, N), generator=g).to(torch.bfloat16) + it = Int4Tensor( + qdata=packed, + scale=scale, + zero_point=zero, + block_size=[1, gs], + shape=torch.Size([N, K]), + ) + return it, q, scale, zero + + +def _reference_dequant(q, scale, zero, gs): + """Independent affine dequant: scale * (q - zero), groups expanded.""" + s = scale.t().to(torch.float32).repeat_interleave(gs, dim=-1) + z = zero.t().to(torch.float32).repeat_interleave(gs, dim=-1) + return (q.to(torch.float32) - z) * s + + +class TestDequantizeInt4Op(unittest.TestCase): + def test_op_matches_reference(self): + it, q, scale, zero = _make_int4_tensor(N=8, K=64, gs=32) + out = torch.ops.torchao.dequantize_int4_tensor( + it.qdata, it.scale, it.zero_point, 32, torch.float32 + ) + ref = _reference_dequant(q, scale, zero, 32) + self.assertEqual(tuple(out.shape), (8, 64)) + self.assertTrue(torch.allclose(out, ref, rtol=1e-2, atol=5e-2)) + + def test_subclass_dequantize_matches_op(self): + it, _, _, _ = _make_int4_tensor(N=8, K=64, gs=32) + t = ExportableInt4Tensor.from_int4_tensor(it) + ref = torch.ops.torchao.dequantize_int4_tensor( + it.qdata, it.scale, it.zero_point, 32, torch.bfloat16 + ) + self.assertTrue(torch.equal(t.dequantize(torch.bfloat16), ref)) + + def test_subclass_linear_dispatches_to_dequant(self): + it, _, _, _ = _make_int4_tensor(N=16, K=64, gs=32) + t = ExportableInt4Tensor.from_int4_tensor(it) + x = torch.randn(2, 64, dtype=torch.bfloat16) + out = torch.nn.functional.linear(x, t) + ref = torch.nn.functional.linear(x, t.dequantize(torch.bfloat16)) + self.assertTrue(torch.equal(out, ref)) + + def test_subclass_embedding_dispatches_to_dequant(self): + it, _, _, _ = _make_int4_tensor(N=16, K=64, gs=32) + t = ExportableInt4Tensor.from_int4_tensor(it) + idx = torch.tensor([0, 3, 7, 1]) + out = torch.nn.functional.embedding(idx, t) + ref = torch.nn.functional.embedding(idx, t.dequantize(torch.bfloat16)) + self.assertTrue(torch.equal(out, ref)) + + +class TestExportableInt4TensorExport(unittest.TestCase): + """Exporting a module whose weight is an ``ExportableInt4Tensor`` should lower + linear/embedding through ``torchao::dequantize_int4_tensor`` after + ``run_decompositions`` (the subclass dispatch fires during decomposition).""" + + @staticmethod + def _targets(ep): + return {str(n.target) for n in ep.graph.nodes if n.op == "call_function"} + + def test_linear_exports_with_dequantize_int4(self): + it, _, _, _ = _make_int4_tensor(N=16, K=64, gs=32) + t = ExportableInt4Tensor.from_int4_tensor(it) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(t, requires_grad=False) + + def forward(self, x): + return torch.nn.functional.linear(x, self.w) + + ep = torch.export.export( + M(), (torch.randn(2, 64, dtype=torch.bfloat16),) + ).run_decompositions({}) + self.assertIn("torchao.dequantize_int4_tensor.default", self._targets(ep)) + + def test_embedding_exports_with_dequantize_int4(self): + it, _, _, _ = _make_int4_tensor(N=16, K=64, gs=32) + t = ExportableInt4Tensor.from_int4_tensor(it) + + class M(torch.nn.Module): + def __init__(self): + super().__init__() + self.w = torch.nn.Parameter(t, requires_grad=False) + + def forward(self, idx): + return torch.nn.functional.embedding(idx, self.w) + + ep = torch.export.export(M(), (torch.tensor([0, 1, 2, 3]),)).run_decompositions( + {} + ) + self.assertIn("torchao.dequantize_int4_tensor.default", self._targets(ep)) + + +if __name__ == "__main__": + unittest.main() From 70a6eb34134cdf2be18497d4840f004794089825 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Sun, 7 Jun 2026 12:45:49 -0700 Subject: [PATCH 15/18] up --- .../mlx/custom_kernel_ops/gguf/__init__.py | 16 ++---- .../mlx/custom_kernel_ops/gguf/patterns.py | 20 ++++--- backends/mlx/custom_kernel_ops/gguf/q4k.py | 0 .../custom_kernel_ops/gguf/q4k/__init__.py | 14 ++--- .../custom_kernel_ops/gguf/q6k/__init__.py | 17 ++---- .../mlx/custom_kernel_ops/gguf/q6k/common.py | 3 +- .../custom_kernel_ops/gguf/q6k/embedding.py | 14 ++--- .../gguf/test/test_linear.py | 42 ++++++++++++++ backends/mlx/test/test_utils.py | 12 +++- examples/models/gemma4_31b/gguf_loader.py | 25 ++++---- extension/llm/export/gguf.py | 57 ++++++------------- extension/llm/export/test/test_gguf.py | 4 +- 12 files changed, 112 insertions(+), 112 deletions(-) delete mode 100644 backends/mlx/custom_kernel_ops/gguf/q4k.py diff --git a/backends/mlx/custom_kernel_ops/gguf/__init__.py b/backends/mlx/custom_kernel_ops/gguf/__init__.py index 0e6ec0f01a4..1b6c1c5373c 100644 --- a/backends/mlx/custom_kernel_ops/gguf/__init__.py +++ b/backends/mlx/custom_kernel_ops/gguf/__init__.py @@ -8,19 +8,11 @@ """GGUF-quantized weight lowering for the MLX backend. -Submodules: - -* :mod:`.q6k` -- shared Q6_K primitives (constants, pure-torch dequant, - Metal header) and the fused mat-vec / mat-mat / gather kernels. Importing - ``.q6k`` (or ``.q6k.common``) is lightweight and does not touch the registry. -* :mod:`.patterns` -- registers MLX pattern handlers that match - ``torchao::dequantize_gguf -> linear/embedding`` (what ``ExportableGGUFTensor`` - exports) and lower them to the Q6_K kernels. - -To enable GGUF lowering, import :mod:`.patterns` for its side effect:: +Import :mod:`.patterns` for its side effect to enable lowering of +``torchao::dequantize_gguf -> linear/embedding`` to the Q6_K / Q4_K kernels:: import executorch.backends.mlx.custom_kernel_ops.gguf.patterns # noqa: F401 -This package ``__init__`` is intentionally side-effect free so importing -``.q6k`` for the pure-torch dequant does not pull in the MLX builder/registry. +This ``__init__`` is side-effect free, so importing ``.q6k`` for the pure-torch +dequant does not pull in the MLX builder/registry. """ diff --git a/backends/mlx/custom_kernel_ops/gguf/patterns.py b/backends/mlx/custom_kernel_ops/gguf/patterns.py index 58f6fa41a7e..7d3a5bc307c 100644 --- a/backends/mlx/custom_kernel_ops/gguf/patterns.py +++ b/backends/mlx/custom_kernel_ops/gguf/patterns.py @@ -17,9 +17,11 @@ These handlers match that ``dequantize_gguf -> linear/embedding`` subgraph and lower it without materializing the dequantized weight: -* **Q6_K** -> fused custom Metal kernels in :mod:`.q6k` (linear + embedding). -* **Q4_K** -> MLX's native 4-bit ``quantized_matmul`` via :mod:`.q4k` (linear); - the GGUF blocks are repacked into MLX qparams at export time. +* **Q6_K** -> fused custom Metal kernels in :mod:`.q6k`. +* **Q4_K** -> MLX's native 4-bit affine ops via :mod:`.q4k` (GGUF blocks + repacked into MLX qparams at export time). + +Both cover linear and embedding. Other quant types are left unmatched (the caller is expected to convert them to a torchao ``Int4Tensor`` / ``IntxUnpackedToInt8Tensor`` first). @@ -40,8 +42,8 @@ from torch.export.exported_program import ExportedProgram from torch.fx.node import Node -# Quant types each pattern can lower (linear has both a custom Q6_K kernel and an -# MLX-native Q4_K path; embedding only has the Q6_K gather kernel). +# Quant types each pattern can lower (Q6_K via custom Metal kernels, Q4_K via +# MLX-native affine ops). _LINEAR_TYPES = {"q4_k", "q6_k"} _EMBEDDING_TYPES = {"q4_k", "q6_k"} @@ -121,11 +123,11 @@ def __call__(self, P: MLXProgramBuilder, n: Node) -> Slot: @REGISTRY.register_pattern(name="GGUF_QUANTIZED_EMBEDDING") class GGUFQuantizedEmbeddingHandler(PatternHandler): - """Fuse ``dequantize_gguf + embedding`` into the Q6_K gather kernel. - - Matches:: + """Lower ``dequantize_gguf + embedding`` to a quantized gather. - embedding(dequantize_gguf(weight, "q6_k", out_dtype), indices) + Matches ``embedding(dequantize_gguf(weight, ggml_type, out_dtype), indices)`` + and dispatches on ``ggml_type``: Q6_K -> custom Metal gather, Q4_K -> MLX + quantized gather. """ def __init__(self, head, body, weight, ggml_type, output_dtype): diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k.py b/backends/mlx/custom_kernel_ops/gguf/q4k.py deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py b/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py index 661d49e39d0..6f89cfe2c82 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py +++ b/backends/mlx/custom_kernel_ops/gguf/q4k/__init__.py @@ -6,15 +6,9 @@ # LICENSE file in the root directory of this source tree. # -"""GGUF **Q4_K** format lowering for the MLX backend. +"""GGUF Q4_K format lowering for the MLX backend (native affine 4-bit). -Q4_K maps onto MLX's native affine 4-bit kernels (no custom Metal): - -* :mod:`.common` -- repack a raw Q4_K blob into MLX qparams. -* :mod:`.linear` -- ``emit_linear`` (``QuantizedMatmulNode``). -* :mod:`.embedding` -- ``emit_embedding`` (gather + ``DequantizeNode``). - -The pattern handlers in ``custom_kernel_ops.gguf.patterns`` call these ``emit_*`` -functions. ``.linear`` / ``.embedding`` are intentionally NOT imported here so -the package import stays light. +See :mod:`.linear` / :mod:`.embedding` for the ``emit_*`` lowerings (called by +``custom_kernel_ops.gguf.patterns``); they are not imported here to keep the +package import light. """ diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py b/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py index 2b809e3b16a..deb39c4d3c0 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/__init__.py @@ -6,19 +6,12 @@ # LICENSE file in the root directory of this source tree. # -"""GGUF **Q6_K** format implementation. +"""GGUF Q6_K format implementation (fused custom Metal kernels). -* :mod:`.common` -- shared primitives (constants + Metal header). Re-exported - here so ``from ...gguf.q6k import Q6K_BLOCK_BYTES`` stays lightweight (no MLX - builder import). -* :mod:`.linear` -- Q6_K mat-vec/mat-mat kernels + ``emit_linear`` lowering. -* :mod:`.embedding` -- Q6_K gather kernel + ``emit_embedding`` lowering. - -The pattern handlers that match ``torchao::dequantize_gguf -> linear/embedding`` -and call these ``emit_*`` functions live one level up in -``custom_kernel_ops.gguf.patterns``. ``.linear`` / ``.embedding`` are -intentionally NOT imported here so importing :mod:`.common` for the pure-torch -dequant does not pull in the builder. +Re-exports the lightweight constants/header from :mod:`.common` so they can be +imported without pulling in the MLX builder. The ``emit_*`` lowerings live in +:mod:`.linear` / :mod:`.embedding` (called by ``custom_kernel_ops.gguf.patterns``) +and are not imported here. """ from executorch.backends.mlx.custom_kernel_ops.gguf.q6k.common import ( # noqa: F401 diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/common.py b/backends/mlx/custom_kernel_ops/gguf/q6k/common.py index 9445fbc5b36..69ddbb0f406 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/common.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/common.py @@ -16,8 +16,7 @@ * ``_Q6K_HEADER`` -- the Metal header (the ``block_q6_K`` struct plus the per-element and vectorized dequant helpers) shared by all Q6_K Metal kernels. -Adding another GGUF format (e.g. Q4_K) should mirror this module and the pattern -handlers in :mod:`..patterns` dispatch on the GGUF quant type. +Q6_K layout Q6_K layout (per 256-element super-block, 210 bytes, see llama.cpp ``block_q6_K`` in ``ggml-common.h``):: diff --git a/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py b/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py index 64177392eb0..2e7401bdaf4 100644 --- a/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py +++ b/backends/mlx/custom_kernel_ops/gguf/q6k/embedding.py @@ -6,17 +6,11 @@ # LICENSE file in the root directory of this source tree. # -"""GGUF **Q6_K** embedding implementation. +"""GGUF **Q6_K** embedding lowering for the MLX GGUF pattern handler. -Provides the Q6_K embedding lowering used by the MLX GGUF pattern handler -(:mod:`..patterns`): - -* :func:`emit_embedding` -- lowers a ``dequantize_gguf -> embedding`` pattern to - a fused Q6_K gather Metal kernel. - -This is the gather counterpart to :mod:`.linear` and exists because MLX's affine -dequantize has no group_size=16 Metal kernel, so a Q6_K embedding (group_size 16) -cannot use the generic quantized-embedding path. +A custom gather Metal kernel is needed because MLX's affine dequantize has no +group_size=16 kernel, so a Q6_K embedding (group_size 16) can't use the generic +quantized-embedding path. """ from __future__ import annotations diff --git a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py index daaaef1491c..4a7defbe107 100644 --- a/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py +++ b/backends/mlx/custom_kernel_ops/gguf/test/test_linear.py @@ -123,6 +123,42 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.linear(x) +def _fp32_linear_reference(model: "GGUFLinearModel", x: torch.Tensor): + """fp32-accumulation reference matching the kernel. + + The kernels accumulate in fp32 and cast to the I/O dtype only at the end, so + a bf16 eager matmul is too noisy an oracle over large K. Dequantize in fp32, + matmul in fp32, then cast back -- differences collapse to ~1 output ULP. + + The reference weight must match the representation the kernel consumes: + Q6_K dequantizes the raw blob in-kernel at full precision (use the gguf-exact + dequant), while Q4_K is repacked into bf16 MLX qparams, so use that repacked + dequant (repack precision vs gguf is covered separately by test_gguf.py). + """ + lin = model.linear + weight = lin.weight + if getattr(weight, "ggml_type", None) == "q4_k": + # Q4_K is repacked into bf16 MLX affine qparams (S, Q, B); reconstruct + # exactly what the kernel dequantizes so the oracle isolates kernel + # accumulation (repack precision vs gguf is covered by test_gguf.py). + from executorch.backends.mlx.builder.op_helpers import to_mlx_qparams + + intx = weight.to_intx_unpacked_to_int8_tensor() + gs = int(intx.block_size[-1]) + Q, B = to_mlx_qparams(intx.qdata, intx.scale, intx.zero_point, 4) + qb = Q.view(torch.uint8) + nibbles = torch.stack([(qb & 0xF).float(), ((qb >> 4) & 0xF).float()], dim=-1) + q_unsigned = nibbles.reshape(intx.qdata.shape[0], -1) + scale = intx.scale.float().repeat_interleave(gs, dim=1) + bias_b = B.float().repeat_interleave(gs, dim=1) + w = scale * q_unsigned + bias_b + else: + w = weight.dequantize(torch.float32) + bias = lin.bias.float() if lin.bias is not None else None + out = torch.nn.functional.linear(x.float(), w, bias) + return [out.to(x.dtype)] + + _DTYPE_TOL = { torch.bfloat16: (2e-2, 2e-2), # The mat-mat (prefill) kernel stores tiles in half precision (as in @@ -210,6 +246,9 @@ def create_inputs(self) -> Tuple[torch.Tensor, ...]: torch.manual_seed(0) return (torch.randn(self.M, self.K, dtype=self.dtype),) + def compute_expected_outputs(self, model, test_inputs): + return _fp32_linear_reference(model, test_inputs[0]) + class GGUFLinearDynamicTest(OpTestCase): """Dynamic seqlen: export once with a symbolic M, run with M=1 (decode / @@ -268,6 +307,9 @@ def create_test_inputs(self) -> Tuple[torch.Tensor, ...]: torch.manual_seed(0) return (torch.randn(self.test_M, self.K, dtype=self.dtype),) + def compute_expected_outputs(self, model, test_inputs): + return _fp32_linear_reference(model, test_inputs[0]) + def _eager_sanity() -> None: """Quick CPU check: the subclass linear exports to dequantize_gguf.""" diff --git a/backends/mlx/test/test_utils.py b/backends/mlx/test/test_utils.py index 5dbc35b824d..1a964bea935 100644 --- a/backends/mlx/test/test_utils.py +++ b/backends/mlx/test/test_utils.py @@ -883,6 +883,16 @@ def get_test_dir(self) -> Path: test_dir.mkdir(parents=True, exist_ok=True) return test_dir + def compute_expected_outputs(self, model, test_inputs): + """Reference outputs the device result is compared against. + + Defaults to the eager ``model`` forward. Override to supply a + higher-precision reference -- e.g. fp32 accumulation matching a kernel + that accumulates in fp32, so bf16 reference noise doesn't dominate the + comparison. + """ + return model(*test_inputs) + def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: """ Generate .pte, input.bin, and expected_output.bin files. @@ -915,7 +925,7 @@ def generate_test_files(self, verbose: bool = False) -> Tuple[Path, Path, Path]: with torch.no_grad(): if isinstance(test_inputs, torch.Tensor): test_inputs = (test_inputs,) - expected_outputs = model(*test_inputs) + expected_outputs = self.compute_expected_outputs(model, test_inputs) if isinstance(expected_outputs, torch.Tensor): expected_outputs = [expected_outputs] else: diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index a5452000124..6122af59944 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -12,13 +12,13 @@ FQNs, handles the tied embed/lm_head, and converts each weight for the target backend: -* **MLX**: linears keep the ``ExportableGGUFTensor`` (lowered by the MLX GGUF - pattern -- Q6_K custom kernels, Q4_K native 4-bit matmul); a Q6_K token - embedding keeps it too (fused gather), while a Q4_K embedding is converted to - ``IntxUnpackedToInt8Tensor`` (MLX quantized gather -- there is no Q4_K gather - kernel). -* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``IntxUnpackedToInt8Tensor``; the - token embedding is dequantized to bf16 (``Int4Tensor`` can't gather). +* **MLX**: every quantized weight stays an ``ExportableGGUFTensor`` and is lowered + by the MLX GGUF pattern (Q6_K custom kernels, Q4_K native affine ops) for both + linear and embedding. ``embed_tokens`` and ``lm_head`` stay tied -- they share + the one quantized tensor. +* **CUDA**: Q4_K -> ``Int4Tensor``, Q6_K -> ``IntxUnpackedToInt8Tensor``; + ``lm_head`` keeps the quantized tensor but the token embedding is dequantized to + bf16 (``Int4Tensor`` can't gather), so they are untied. Usage: model, config = load_gguf_model("model.gguf", backend="cuda") @@ -124,12 +124,11 @@ def load_gguf_model( ) -> tuple: """Load a GGUF file, remap keys, and convert weights for the target backend. - Streams tensors one at a time for low peak memory. - - GGUF ties ``embed_tokens`` and ``lm_head`` into a single tensor. We untie - them so ``lm_head`` keeps its quantization: on MLX it lowers through the GGUF - linear pattern; on CUDA it stays a quantized ``Int4Tensor`` / - ``IntxUnpackedToInt8Tensor``, while the embedding is dequantized to bf16. + Streams tensors one at a time for low peak memory. GGUF ties ``embed_tokens`` + and ``lm_head``: on MLX they stay tied (one shared quantized tensor); on CUDA + they are untied so the embedding can be dequantized for the gather while + ``lm_head`` keeps its quantization. See the module docstring for the + per-backend conversion details. Returns ``(model, config)``. """ diff --git a/extension/llm/export/gguf.py b/extension/llm/export/gguf.py index fe0e78c51fa..1ffb0435eb9 100644 --- a/extension/llm/export/gguf.py +++ b/extension/llm/export/gguf.py @@ -7,31 +7,22 @@ """Export-time GGUF quantized weights. -``ExportableGGUFTensor`` is the canonical **loading representation** for a -GGUF-quantized weight: it wraps the *raw* GGUF block bytes for one tensor and -defers all unpacking. The intended flow: - -1. **Load**: ``load_gguf(path)`` -> ``dict[name -> ExportableGGUFTensor | Tensor]`` - (quantized tensors become ``ExportableGGUFTensor``; F32/F16 become plain - tensors). No unpacking happens at load. -2. **Lower (dequantize)**: used as a weight, the subclass dequantizes via the - ``torchao::dequantize_gguf`` custom op (gguf-package eager body) and runs the - plain torch ``linear`` / ``embedding`` (NVFP4-style). A backend can - pattern-match ``dequantize_gguf`` -> linear/embedding to fuse. -3. **Convert**: ``.to_int4_tensor()`` / ``.to_intx_unpacked_to_int8_tensor()`` - unpack into torchao tensor subclasses (``Int4Tensor`` for Q4_K, - ``IntxUnpackedToInt8Tensor`` for Q4_K or Q6_K) to take the non-fused - (affine-dequant) path instead. - -The GGUF quant type is identified by a **string** (``"q4_k"``, ``"q6_k"``) -everywhere user-facing (subclass attribute + ``dequantize_gguf`` op argument); the -``gguf`` package's integer ``GGMLQuantizationType`` ids are an internal lookup -detail. - -Backend-agnostic; depends on ``torch``, ``torchao``, ``numpy``, and the ``gguf`` -package. The *policy* of which tensors to convert is left to the caller. - -Attribution: the Q4_K / Q6_K block layouts follow llama.cpp / gguf-py +``ExportableGGUFTensor`` wraps the *raw* GGUF block bytes for one tensor and +defers all unpacking, serving as the canonical GGUF loading representation: + +* ``load_gguf(path)`` -> ``{name -> ExportableGGUFTensor | Tensor}`` (quantized + tensors become subclasses; F32/F16 stay plain). No unpacking at load. +* As a weight, it dequantizes via the ``torchao::dequantize_gguf`` custom op + (gguf-package eager body) then a plain ``linear`` / ``embedding`` -- a backend + can pattern-match ``dequantize_gguf`` -> linear/embedding to fuse. +* ``.to_int4_tensor()`` / ``.to_intx_unpacked_to_int8_tensor()`` convert into + torchao subclasses (``Int4Tensor`` / ``IntxUnpackedToInt8Tensor``) instead. + +The quant type is a string (``"q4_k"`` / ``"q6_k"``); the ``gguf`` package's +integer ``GGMLQuantizationType`` ids are an internal lookup detail. Which tensors +to convert is the caller's policy. + +Attribution: Q4_K / Q6_K block layouts follow llama.cpp / gguf-py (``ggml-common.h``), MIT-licensed (Copyright (c) 2023-2024 The ggml authors). """ @@ -46,9 +37,7 @@ aten = torch.ops.aten -# --------------------------------------------------------------------------- # GGUF k-quant constants -# --------------------------------------------------------------------------- QK_K = 256 # super-block size for k-quants @@ -94,9 +83,7 @@ def _dequantize_gguf(raw: Tensor, ggml_type: str, output_dtype: torch.dtype) -> ) -# --------------------------------------------------------------------------- # Fused ops (eager = gguf.dequantize + torch op; a backend may lower to kernels) -# --------------------------------------------------------------------------- @torch.library.custom_op("torchao::dequantize_gguf", mutates_args=()) @@ -115,9 +102,7 @@ def _(weight, ggml_type, output_dtype=torch.bfloat16): return torch.empty((weight.shape[0], K), dtype=output_dtype, device=weight.device) -# --------------------------------------------------------------------------- # Per-type field extraction (used by the to_*_tensor conversions) -# --------------------------------------------------------------------------- def _q4_k_fields(raw: Tensor, N: int, K: int) -> Tuple[Tensor, Tensor, Tensor]: @@ -195,9 +180,7 @@ def _q6_k_fields(raw: Tensor, N: int, K: int) -> Tuple[Tensor, Tensor]: return q.reshape(N, K).to(torch.int8), eff_scale -# --------------------------------------------------------------------------- # Tensor subclass -# --------------------------------------------------------------------------- class ExportableGGUFTensor(TorchAOBaseTensor): @@ -240,8 +223,6 @@ def __new__(cls, raw: Tensor, ggml_type: str, orig_dtype: torch.dtype): self.orig_dtype = orig_dtype return self - # -- construction -------------------------------------------------------- - @classmethod def from_raw( cls, @@ -252,16 +233,12 @@ def from_raw( """Build from a ``(N, row_bytes)`` uint8 GGUF block blob.""" return cls(raw.contiguous(), ggml_type, orig_dtype) - # -- dequant (via gguf package) ------------------------------------------ - def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> Tensor: """Dequantize to a plain float tensor using the ``gguf`` package.""" return torch.ops.torchao.dequantize_gguf( self.raw, self.ggml_type, output_dtype or self.orig_dtype ) - # -- conversions (unpack lives here) ------------------------------------- - def to_int4_tensor(self) -> Tensor: """Convert a Q4_K tensor to a torchao ``Int4Tensor``.""" from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor @@ -365,9 +342,7 @@ def _(func, types, args, kwargs): return args[0].dequantize(output_dtype=kwargs.get("dtype", args[0].orig_dtype)) -# --------------------------------------------------------------------------- # Loader -# --------------------------------------------------------------------------- def iter_gguf( diff --git a/extension/llm/export/test/test_gguf.py b/extension/llm/export/test/test_gguf.py index f6cd52a69ad..13e2dff53fc 100644 --- a/extension/llm/export/test/test_gguf.py +++ b/extension/llm/export/test/test_gguf.py @@ -10,8 +10,8 @@ The reference oracle is the ``gguf`` package's own ``gguf.dequantize`` (which can dequantize Q4_K / Q6_K). We validate that: -* ``ExportableGGUFTensor.dequantize`` (and the fused ``torchao::gguf_*`` ops, - whose eager bodies use ``gguf``) reproduce ``gguf.dequantize``; +* ``ExportableGGUFTensor.dequantize`` (and the ``torchao::dequantize_gguf`` op, + whose eager body uses ``gguf``) reproduces ``gguf.dequantize``; * our hand-written ``to_int4_tensor`` / ``to_intx_unpacked_to_int8_tensor`` unpack matches ``gguf.dequantize`` (within bf16 storage tolerance); * using the subclass as a weight dispatches linear/embedding to the fused ops. From c68157553e3f51cbe9b782954db95035ad0a5d61 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 8 Jun 2026 10:37:07 -0700 Subject: [PATCH 16/18] up --- .github/workflows/mlx.yml | 2 + examples/models/gemma4_31b/gguf_loader.py | 12 +-- examples/models/gemma4_31b/model.md | 6 +- examples/models/gemma4_31b/quant/README.md | 6 +- examples/models/gemma4_31b/quant/pack_mlx.py | 13 ++- .../gemma4_31b/quant/tests/test_pack_mlx.py | 88 ++----------------- .../gemma4_31b/tests/test_cuda_pipeline.py | 54 ++++++++++++ .../gemma4_31b/tests/test_mlx_pipeline.py | 42 +++++++++ .../models/gemma4_31b/tests/test_pipeline.py | 80 +++++++++++++++++ requirements-dev.txt | 3 +- 10 files changed, 206 insertions(+), 100 deletions(-) diff --git a/.github/workflows/mlx.yml b/.github/workflows/mlx.yml index c1c7af82a06..bd6c8f3ed06 100644 --- a/.github/workflows/mlx.yml +++ b/.github/workflows/mlx.yml @@ -13,6 +13,7 @@ on: - backends/mlx/** - extension/llm/export/** - extension/audio/** + - examples/models/gemma4_31b/** - examples/models/parakeet/** - examples/models/voxtral_realtime/** - examples/models/qwen3_5_moe/** @@ -78,6 +79,7 @@ jobs: backends/mlx/test/test_pattern_utils.py \ backends/mlx/test/test_partitioner.py \ backends/mlx/test/test_serialization_dedup.py \ + examples/models/gemma4_31b/quant/tests/test_pack_mlx.py \ examples/models/gemma4_31b/tests/test_mlx_pipeline.py \ -v echo "::endgroup::" diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index 6122af59944..bab314f29f9 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -87,11 +87,6 @@ def _validate_no_meta(model): p.requires_grad_(False) -def _is_embedding(model, model_key: str) -> bool: - parent = model.get_submodule(model_key.rsplit(".", 1)[0]) - return isinstance(parent, torch.nn.Embedding) - - def _convert_weight(model, model_key: str, gtensor, backend: str): """Convert an ``ExportableGGUFTensor`` to the per-backend module weight.""" if backend == "mlx": @@ -121,6 +116,7 @@ def load_gguf_model( gguf_path: str, max_seq_len: int = 4096, backend: str = "cuda", + config=None, ) -> tuple: """Load a GGUF file, remap keys, and convert weights for the target backend. @@ -130,6 +126,9 @@ def load_gguf_model( ``lm_head`` keeps its quantization. See the module docstring for the per-backend conversion details. + ``config`` defaults to the full Gemma 4 31B config; pass a smaller + ``Gemma4_31BConfig`` (e.g. in tests) to load a GGUF for a tiny model. + Returns ``(model, config)``. """ from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig @@ -147,7 +146,8 @@ def load_gguf_model( else: raise ValueError(f"Unsupported backend: {backend!r}. Supported: 'cuda', 'mlx'.") - config = Gemma4_31BConfig(max_seq_len=max_seq_len) + if config is None: + config = Gemma4_31BConfig(max_seq_len=max_seq_len) print("Building model on meta device...") with torch.device("meta"): diff --git a/examples/models/gemma4_31b/model.md b/examples/models/gemma4_31b/model.md index 13207bdbb06..32f407c6b40 100644 --- a/examples/models/gemma4_31b/model.md +++ b/examples/models/gemma4_31b/model.md @@ -154,8 +154,10 @@ Modules in `quant/`: packers dispatch by module type (`nn.Linear`, `nn.Embedding`). CUDA passes Int4Tensor through (dispatch handled by `int4_dispatch.py`); MLX converts Int4Tensor → IntxUnpackedToInt8Tensor and regroups per-axis embeddings. -- **GGUF** (`gguf.py`): `unpack_gguf_tensor` / `iter_gguf_tensors` for - loading community-quantized GGUF files (Q4_K, Q6_K). +- **GGUF**: community-quantized GGUF files (Q4_K, Q6_K) are loaded by the + shared, backend-agnostic `extension/llm/export/gguf.py` (`load_gguf` / + `iter_gguf` → `ExportableGGUFTensor`); `gguf_loader.py` remaps GGUF names to + model FQNs and picks the per-backend weight representation. The quantize-once flow: diff --git a/examples/models/gemma4_31b/quant/README.md b/examples/models/gemma4_31b/quant/README.md index 92ddbf97243..8906a0faede 100644 --- a/examples/models/gemma4_31b/quant/README.md +++ b/examples/models/gemma4_31b/quant/README.md @@ -11,7 +11,9 @@ Quantization framework: **recipe → quantize → pack**. | `pack.py` | **Packing dispatch** — `pack_model` (bulk) and `pack_one` (streaming) | — | | `pack_cuda.py` | **CUDA packing** — passes Int4Tensor/IntxUnpacked through for CUDA dispatch | pack | | `pack_mlx.py` | **MLX packing** — converts Int4Tensor → IntxUnpacked, regroups per-axis embeddings | pack | -| `gguf.py` | **GGUF import** — unpacks Q4_K/Q6_K blocks to torchao subclasses | torchao | + +GGUF import (unpacking Q4_K/Q6_K blocks) now lives in the shared +`extension/llm/export/gguf.py`. ## Data flow @@ -49,4 +51,4 @@ The format is compatible with torchao's `save_pretrained` / `load_pretrained`. ## TODO - `pack_metal.py` — Metal backend packer. -- `gguf.py` — extend with Q5_K, Q8_0 GGUF quant types. +- GGUF quant types (Q5_K, Q8_0): extend `extension/llm/export/gguf.py`. diff --git a/examples/models/gemma4_31b/quant/pack_mlx.py b/examples/models/gemma4_31b/quant/pack_mlx.py index 98f5470f184..22f525accd2 100644 --- a/examples/models/gemma4_31b/quant/pack_mlx.py +++ b/examples/models/gemma4_31b/quant/pack_mlx.py @@ -83,13 +83,12 @@ def _regroup_intx(w: torch.Tensor, new_gs: int) -> torch.Tensor: def pack_for_mlx(module: nn.Module, weights: dict[str, torch.Tensor]) -> None: """Pack a quantized weight for MLX. - ``Int4Tensor`` is converted to ``IntxUnpackedToInt8Tensor`` so the - default dispatch produces the ``dequantize_affine → linear`` pattern - MLX expects. Regroups to a compatible group_size when needed (e.g. - per-axis group_size=5376 → group_size=128) since MLX's - ``parse_dequant_node`` only accepts group_size in {16, 32, 64, 128}. - Group sizes ≥ 32 use the fused ``QuantizedMatmulNode``; group_size=16 - (e.g. GGUF Q6_K) falls back to ``DequantizeNode`` + matmul at export. + ``Int4Tensor`` is wrapped as ``ExportableInt4Tensor`` (exports to + ``dequantize_int4_tensor → linear/embedding``). ``IntxUnpackedToInt8Tensor`` + is assigned directly, regrouped to a compatible group_size when needed (e.g. + per-axis group_size=5376 → 128) since MLX accepts group_size in + {16, 32, 64, 128}. Group sizes ≥ 32 use the fused ``QuantizedMatmulNode``; + group_size=16 (e.g. GGUF Q6_K) falls back to ``DequantizeNode`` + matmul. """ from executorch.extension.llm.export.int4 import ExportableInt4Tensor from torchao.quantization import IntxUnpackedToInt8Tensor diff --git a/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py b/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py index 2e6310b9c10..4ff2c4149cf 100644 --- a/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py +++ b/examples/models/gemma4_31b/quant/tests/test_pack_mlx.py @@ -13,7 +13,6 @@ from executorch.examples.models.gemma4_31b.quant.pack import pack_model from executorch.examples.models.gemma4_31b.quant.pack_mlx import ( - _int4_to_intx_unpacked, _mlx_group_size, DEFAULT_MLX_PACKERS, pack_for_mlx, @@ -25,91 +24,16 @@ from executorch.examples.models.gemma4_31b.quant.recipe import QuantConfig -class TestInt4ToIntxConversion(unittest.TestCase): - """Int4Tensor → IntxUnpackedToInt8Tensor conversion.""" - - def test_symmetric_dequant_matches(self): - """Converted weight dequantizes to same values as original.""" - torch.manual_seed(0) - weight = torch.randn(64, 128, dtype=torch.bfloat16) - config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - int4_w = quantize_weight(weight, config) - intx_w = _int4_to_intx_unpacked(int4_w) - - int4_dense = dequantize_weight(int4_w, torch.float32) - intx_dense = dequantize_weight(intx_w, torch.float32) - self.assertTrue( - torch.allclose(int4_dense, intx_dense, atol=1e-5), - f"max diff: {(int4_dense - intx_dense).abs().max():.6g}", - ) - - def test_asymmetric_dequant_matches(self): - torch.manual_seed(0) - weight = torch.randn(64, 128, dtype=torch.bfloat16) - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - int4_w = quantize_weight(weight, config) - intx_w = _int4_to_intx_unpacked(int4_w) - - int4_dense = dequantize_weight(int4_w, torch.float32) - intx_dense = dequantize_weight(intx_w, torch.float32) - self.assertTrue( - torch.allclose(int4_dense, intx_dense, atol=1e-5), - f"max diff: {(int4_dense - intx_dense).abs().max():.6g}", - ) - - def test_output_type_and_shape(self): - from torchao.quantization import IntxUnpackedToInt8Tensor - - torch.manual_seed(0) - config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") - int4_w = quantize_weight(torch.randn(128, 256, dtype=torch.bfloat16), config) - intx_w = _int4_to_intx_unpacked(int4_w) - - self.assertIsInstance(intx_w, IntxUnpackedToInt8Tensor) - self.assertEqual(intx_w.shape, torch.Size([128, 256])) - self.assertEqual(intx_w.qdata.shape, torch.Size([128, 256])) - self.assertEqual(intx_w.target_dtype, torch.int4) - - def test_different_group_sizes(self): - torch.manual_seed(0) - for gs in (32, 64, 128): - with self.subTest(group_size=gs): - config = QuantConfig( - bits=4, group_size=gs, symmetric=True, method="min_max" - ) - int4_w = quantize_weight( - torch.randn(64, 256, dtype=torch.bfloat16), config - ) - intx_w = _int4_to_intx_unpacked(int4_w) - self.assertEqual(intx_w.shape, torch.Size([64, 256])) - - def test_matmul_approximates_original(self): - torch.manual_seed(0) - weight = torch.randn(256, 128, dtype=torch.bfloat16) - x = torch.randn(1, 128, dtype=torch.bfloat16) - original_out = torch.nn.functional.linear(x, weight) - - config = QuantConfig(bits=4, group_size=32, symmetric=False, method="min_max") - int4_w = quantize_weight(weight, config) - intx_w = _int4_to_intx_unpacked(int4_w) - packed_out = torch.nn.functional.linear(x, intx_w.dequantize()) - - rel_error = ( - packed_out.float() - original_out.float() - ).abs().mean() / original_out.float().abs().mean() - self.assertLess(rel_error.item(), 0.15) - - class TestPackLinearForMlx(unittest.TestCase): - def test_int4_converts_to_intx(self): - from torchao.quantization import IntxUnpackedToInt8Tensor + def test_int4_wraps_exportable(self): + from executorch.extension.llm.export.int4 import ExportableInt4Tensor module = nn.Linear(128, 64, bias=False) config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") w = quantize_weight(torch.randn(64, 128, dtype=torch.bfloat16), config) pack_for_mlx(module, {"weight": w}) - self.assertIsInstance(module.weight.data, IntxUnpackedToInt8Tensor) + self.assertIsInstance(module.weight.data, ExportableInt4Tensor) self.assertEqual(module.weight.shape, torch.Size([64, 128])) self.assertFalse(module.weight.requires_grad) @@ -218,14 +142,14 @@ def test_per_axis_regroups(self): self.assertEqual(module.weight.shape, torch.Size([50, 256])) self.assertEqual(module.weight.data.block_size, (1, 128)) - def test_int4_converts_to_intx(self): - from torchao.quantization import IntxUnpackedToInt8Tensor + def test_int4_wraps_exportable(self): + from executorch.extension.llm.export.int4 import ExportableInt4Tensor module = nn.Embedding(100, 64) config = QuantConfig(bits=4, group_size=32, symmetric=True, method="min_max") w = quantize_weight(torch.randn(100, 64, dtype=torch.bfloat16), config) pack_for_mlx(module, {"weight": w}) - self.assertIsInstance(module.weight.data, IntxUnpackedToInt8Tensor) + self.assertIsInstance(module.weight.data, ExportableInt4Tensor) self.assertEqual(module.weight.shape, torch.Size([100, 64])) diff --git a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py index 505d6f7bdc1..29a28754e1d 100644 --- a/examples/models/gemma4_31b/tests/test_cuda_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_cuda_pipeline.py @@ -28,6 +28,7 @@ export_and_lower, load_prequantized_model, ) +from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model from executorch.examples.models.gemma4_31b.inference import _move_to_cuda, generate from executorch.examples.models.gemma4_31b.model import Gemma4_31B from executorch.examples.models.gemma4_31b.quant import ( @@ -36,8 +37,10 @@ quantize_model, ) from executorch.examples.models.gemma4_31b.tests.test_pipeline import ( + build_gguf_checkpoint, build_hf_checkpoint, DEFAULT_RECIPE, + GGUF_CONFIG, MockTokenizer, save_checkpoint, TINY_CONFIG, @@ -225,5 +228,56 @@ def test_embedding_works(self): self.assertFalse(emb.isnan().any()) +class TestGgufCudaPipeline(unittest.TestCase): + """GGUF -> CUDA load -> inference -> export (mirrors TestGgufLinearMlx).""" + + def setUp(self): + _require_cuda(self) + try: + import gguf # noqa: F401 + except ImportError: + self.skipTest("gguf package required") + + def _load(self, tmp): + path = os.path.join(tmp, "tiny.gguf") + build_gguf_checkpoint(path) + return load_gguf_model(path, backend="cuda", config=GGUF_CONFIG) + + def test_load_converts_weights(self): + """GGUF -> CUDA: Q4_K -> Int4Tensor, Q6_K -> IntxUnpacked, embedding bf16.""" + from torchao.quantization import IntxUnpackedToInt8Tensor + from torchao.quantization.quantize_.workflows.int4.int4_tensor import Int4Tensor + + with tempfile.TemporaryDirectory() as tmp: + model, _ = self._load(tmp) + + self.assertIsInstance(model.layers[0].self_attn.q_proj.weight.data, Int4Tensor) + self.assertIsInstance( + model.layers[0].mlp.down_proj.weight.data, IntxUnpackedToInt8Tensor + ) + # Token embedding is dequantized to bf16 (Int4/Intx can't gather). + self.assertEqual(model.embed_tokens.weight.dtype, torch.bfloat16) + + def test_generate(self): + """GGUF -> CUDA -> eager generate produces valid tokens (inference.py).""" + with tempfile.TemporaryDirectory() as tmp: + model, config = self._load(tmp) + _move_to_cuda(model, config) + model.eval() + tokenizer = MockTokenizer(GGUF_CONFIG.vocab_size) + + torch.manual_seed(0) + out = generate(model, tokenizer, prompt="hi", max_new_tokens=3, temperature=1.0) + self.assertIsInstance(out, str) + self.assertGreater(len(out), 0) + + def test_export(self): + """GGUF -> CUDA -> export_and_lower produces a .pte (export.py).""" + with tempfile.TemporaryDirectory() as tmp, tempfile.TemporaryDirectory() as out_dir: + model, config = self._load(tmp) + export_and_lower(model, config, out_dir) + self.assertTrue(os.path.exists(os.path.join(out_dir, "model.pte"))) + + if __name__ == "__main__": unittest.main() diff --git a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py index 2753c62bb8c..b26e2783aa6 100644 --- a/examples/models/gemma4_31b/tests/test_mlx_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_mlx_pipeline.py @@ -30,8 +30,10 @@ QuantRule, ) from executorch.examples.models.gemma4_31b.tests.test_pipeline import ( + build_gguf_checkpoint, build_random_tiny_model, config_dict, + GGUF_CONFIG, save_checkpoint, TINY_CONFIG, ) @@ -485,5 +487,45 @@ def test_int4_embedding_delegates(self): ) +class TestGgufLoadMlx(unittest.TestCase): + """GGUF file -> load_gguf_model(mlx) -> export (parity with the CUDA test).""" + + def setUp(self): + try: + import gguf # noqa: F401 + except ImportError: + self.skipTest("gguf package required") + + def _load(self, tmp): + from executorch.examples.models.gemma4_31b.gguf_loader import load_gguf_model + + path = os.path.join(tmp, "tiny.gguf") + build_gguf_checkpoint(path) + return load_gguf_model(path, backend="mlx", config=GGUF_CONFIG) + + def test_load_keeps_gguf_tensors_and_ties_lm_head(self): + """MLX keeps weights as ExportableGGUFTensor; lm_head stays tied.""" + from executorch.extension.llm.export.gguf import ExportableGGUFTensor + + with tempfile.TemporaryDirectory() as tmp: + model, _ = self._load(tmp) + + self.assertIsInstance( + model.layers[0].self_attn.q_proj.weight.data, ExportableGGUFTensor + ) + self.assertIsInstance(model.embed_tokens.weight.data, ExportableGGUFTensor) + # GGUF ties embed/lm_head; on MLX they share the one quantized tensor. + self.assertIs(model.lm_head.weight.data, model.embed_tokens.weight.data) + + def test_export(self): + """GGUF -> MLX load -> export_and_lower produces a .pte (export.py).""" + from executorch.examples.models.gemma4_31b.export import export_and_lower + + with tempfile.TemporaryDirectory() as tmp, tempfile.TemporaryDirectory() as out_dir: + model, config = self._load(tmp) + export_and_lower(model, config, out_dir, backend="mlx") + self.assertTrue(os.path.exists(os.path.join(out_dir, "model.pte"))) + + if __name__ == "__main__": unittest.main() diff --git a/examples/models/gemma4_31b/tests/test_pipeline.py b/examples/models/gemma4_31b/tests/test_pipeline.py index a8d9d9cbe34..3275542d402 100644 --- a/examples/models/gemma4_31b/tests/test_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_pipeline.py @@ -158,6 +158,86 @@ def build_hf_checkpoint(output_dir: str) -> None: json.dump(config_dict(), f) +# GGUF-friendly tiny config: Q4_K/Q6_K need in-features that are multiples of 256, +# so hidden/intermediate are 256. Two layers exercises a sliding + a global layer. +GGUF_CONFIG = Gemma4_31BConfig( + vocab_size=256, + hidden_size=256, + intermediate_size=256, + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=1, + head_dim=256, + global_head_dim=512, + sliding_window=16, + max_seq_len=64, +) + + +def _model_to_gguf_key(fqn: str): + """Invert ``gguf_loader._KEY_MAP`` (model FQN -> GGUF tensor name).""" + from executorch.examples.models.gemma4_31b.gguf_loader import _KEY_MAP + + for gguf_pat, model_pat in _KEY_MAP.items(): + if "{}" not in model_pat: + if fqn == model_pat: + return gguf_pat + continue + prefix, suffix = model_pat.split("{}") + if fqn.startswith(prefix) and fqn.endswith(suffix): + idx = fqn[len(prefix) : len(fqn) - len(suffix)] + if idx.isdigit(): + return gguf_pat.replace("{}", idx) + return None + + +def build_gguf_checkpoint(path: str, config: Gemma4_31BConfig = GGUF_CONFIG) -> None: + """Write a tiny GGUF file matching ``config``. + + Linears are Q4_K; ``ffn_down`` / ``token_embd`` are Q6_K (to exercise both + GGUF unpack paths); norms / scalars are F32. Tensor shapes are derived from + the instantiated model so per-layer-type differences (e.g. global layers + having no v_proj / q_norm) are handled automatically. ``output.weight`` is + omitted -- GGUF ties lm_head to the token embedding. Requires the ``gguf`` + package. + """ + import gguf + from executorch.extension.llm.export.gguf import QK_K + from executorch.extension.llm.export.test.test_gguf import ( + _make_q4k_raw, + _make_q6k_raw, + ) + + with torch.device("meta"): + model = Gemma4_31B(config) + + writer = gguf.GGUFWriter(path, "gemma") + for fqn, p in model.named_parameters(): + gguf_key = _model_to_gguf_key(fqn) + if gguf_key is None: + continue + if p.dim() == 2: + N, K = int(p.shape[0]), int(p.shape[1]) + nb = K // QK_K + use_q6 = gguf_key == "token_embd.weight" or gguf_key.endswith( + "ffn_down.weight" + ) + blob = (_make_q6k_raw(N, nb) if use_q6 else _make_q4k_raw(N, nb)).numpy() + raw_dtype = ( + gguf.GGMLQuantizationType.Q6_K + if use_q6 + else gguf.GGMLQuantizationType.Q4_K + ) + writer.add_tensor(gguf_key, blob, raw_dtype=raw_dtype) + else: + arr = (torch.randn(tuple(p.shape), dtype=torch.float32) * 0.1).numpy() + writer.add_tensor(gguf_key, arr) + writer.write_header_to_file() + writer.write_kv_data_to_file() + writer.write_tensors_to_file() + writer.close() + + # --------------------------------------------------------------------------- # Tests (CPU only, no backend dependency) diff --git a/requirements-dev.txt b/requirements-dev.txt index d2c3b5fcc20..71c68c968ec 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,4 +14,5 @@ lintrunner-adapters==0.14.0 pytest<9.0 pytest-xdist pytest-rerunfailures==15.1 -pytest-json-report \ No newline at end of file +pytest-json-report +gguf # For extension/llm/export/test/test_gguf.py (GGUF Q4_K/Q6_K dequant tests). \ No newline at end of file From 1085c195398b072c68ed09cf9deff5819ba92ce6 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 8 Jun 2026 12:08:10 -0700 Subject: [PATCH 17/18] up --- examples/models/gemma4_31b/gguf_loader.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/examples/models/gemma4_31b/gguf_loader.py b/examples/models/gemma4_31b/gguf_loader.py index bab314f29f9..5d7c5ec540d 100644 --- a/examples/models/gemma4_31b/gguf_loader.py +++ b/examples/models/gemma4_31b/gguf_loader.py @@ -131,7 +131,11 @@ def load_gguf_model( Returns ``(model, config)``. """ - from executorch.examples.models.gemma4_31b.model import Gemma4_31B, Gemma4_31BConfig + from executorch.examples.models.gemma4_31b.model import ( + Gemma4_31B, + Gemma4_31BConfig, + materialize_runtime_buffers, + ) from executorch.examples.models.gemma4_31b.quant import dequantize_weight, pack_one from executorch.extension.llm.export.gguf import ExportableGGUFTensor, iter_gguf @@ -182,6 +186,11 @@ def load_gguf_model( _resolve_tied_lm_head(model, lm_head_weight, packers) + # Fill RoPE tables / KV caches / scalar constants (left on meta by the + # streaming load), matching load_prequantized_model so the CUDA and eager + # forward paths get bf16 runtime buffers instead of float32 defaults. + materialize_runtime_buffers(model, dtype=torch.bfloat16) + _validate_no_meta(model) model.eval() From 2743bd6f954b79e0de0a60aed99a4f7a83e18329 Mon Sep 17 00:00:00 2001 From: Scott Roy Date: Mon, 8 Jun 2026 13:22:02 -0700 Subject: [PATCH 18/18] up --- examples/models/gemma4_31b/tests/test_pipeline.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/models/gemma4_31b/tests/test_pipeline.py b/examples/models/gemma4_31b/tests/test_pipeline.py index 3275542d402..f81d68c623a 100644 --- a/examples/models/gemma4_31b/tests/test_pipeline.py +++ b/examples/models/gemma4_31b/tests/test_pipeline.py @@ -232,6 +232,16 @@ def build_gguf_checkpoint(path: str, config: Gemma4_31BConfig = GGUF_CONFIG) -> else: arr = (torch.randn(tuple(p.shape), dtype=torch.float32) * 0.1).numpy() writer.add_tensor(gguf_key, arr) + # Per-layer scalars are buffers (not parameters) but are stored in real + # GGUFs (e.g. blk.N.layer_output_scale.weight). Write the ones that have a + # GGUF mapping so they load as bf16; runtime buffers (RoPE, KV cache, ...) + # map to None and are skipped. + for fqn, b in model.named_buffers(): + gguf_key = _model_to_gguf_key(fqn) + if gguf_key is None: + continue + arr = torch.ones(tuple(b.shape), dtype=torch.float32).numpy() + writer.add_tensor(gguf_key, arr) writer.write_header_to_file() writer.write_kv_data_to_file() writer.write_tensors_to_file()