Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions bitsandbytes/backends/cpu/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,13 @@ def _(
shape_fallback = shape[-1] % 2 != 0

if avx512_fallback or shape_fallback:
from ..default.ops import _dequantize_4bit_impl
from ..default.ops import _dequantize_4bit_compute
from ..utils import _get_4bit_code

return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
if A.dtype != torch.uint8:
A = A.view(torch.uint8)
code = _get_4bit_code(quant_type, A.device)
return _dequantize_4bit_compute(A.reshape(-1), absmax, code, blocksize, shape, dtype)

# Enable non uint8 dtype
if A.dtype != torch.uint8:
Expand Down
180 changes: 85 additions & 95 deletions bitsandbytes/backends/default/ops.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from collections.abc import Sequence
from functools import wraps
from functools import cache, wraps
from math import sqrt
from typing import Optional

import torch

from ..._ops import register_kernel
from ..utils import CODE
from ..utils import _get_4bit_code


def _try_torch_compile(func=None, **compile_kwargs):
Expand Down Expand Up @@ -179,125 +179,112 @@ def _(A: torch.Tensor, threshold=0.0):

@register_kernel("bitsandbytes::quantize_blockwise", "default")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``A`` block by block against the lookup table ``code``.

    The input is flattened and cast to float32, split into consecutive blocks
    of ``blocksize`` elements (the last block may be shorter), and each block
    is scaled by the reciprocal of its absolute maximum and clamped to
    ``[-1, 1]``. Every scaled value is then mapped to an index into ``code``
    by bucketizing against the midpoints between neighbouring ``code``
    entries.

    Args:
        A: Tensor to quantize; any shape.
        code: 1-D lookup table. It is assumed to be sorted in ascending order
            (the midpoint bucketing below relies on that).
        blocksize: Number of elements per scaling block.

    Returns:
        A ``(indices, absmax)`` tuple: ``indices`` is a uint8 tensor with the
        shape of ``A``; ``absmax`` is a float32 tensor holding one absolute
        maximum per block, including the trailing partial block if any.
    """
    A_flat = A.reshape(-1).float()
    n = A_flat.numel()
    rem = n % blocksize
    full = n - rem
    blocks = full // blocksize

    # Full blocks: one row per block so the per-block max is a single reduction.
    A_com = A_flat[:full].reshape(blocks, blocksize)
    absmax = A_com.abs().max(dim=-1)[0]
    # The divisor is clamped so an all-zero block scales to 0 instead of NaN;
    # the returned absmax keeps the true (possibly zero) value.
    scaled = torch.clamp(A_com * (1.0 / absmax.clamp(min=1e-38).view(-1, 1)), -1, 1).reshape(-1)

    if rem:
        # Trailing partial block gets its own scale. As for full blocks, only
        # the divisor is clamped, so the reported absmax stays consistent.
        tail = A_flat[full:]
        tail_absmax = tail.abs().max()
        absmax = torch.cat([absmax, tail_absmax.unsqueeze(0)])
        scaled = torch.cat([scaled, torch.clamp(tail / tail_absmax.clamp(min=1e-38), -1, 1)])

    # bucketize needs the boundaries on the same device and dtype as the input.
    code = code.to(device=scaled.device, dtype=scaled.dtype)
    bounds = (code[:-1] + code[1:]) / 2
    q = torch.bucketize(scaled, bounds, out_int32=True).to(torch.uint8)
    return q.reshape(A.shape), absmax


@_try_torch_compile(dynamic=True)
def _dequantize_blockwise_compute(
    A_flat: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
):
    """Look up ``code`` for every index in ``A_flat`` and rescale per block.

    ``A_flat`` is treated as a flat sequence of indices into ``code``. The
    looked-up values are grouped into consecutive blocks of ``blocksize``
    elements and block ``i`` is multiplied by ``absmax[i]``. When the length
    is not a multiple of ``blocksize``, the trailing partial block is scaled
    by the entry of ``absmax`` that follows the full blocks. The flat result
    is cast to ``dtype`` before being returned.
    """
    total = A_flat.numel()
    values = code[A_flat.to(torch.int64)]
    tail_len = total % blocksize

    if tail_len:
        head_len = total - tail_len
        num_full = head_len // blocksize
        # Full blocks are scaled row-wise; the short tail uses its own scalar.
        head = values[:head_len].reshape(num_full, blocksize) * absmax[:num_full].view(-1, 1)
        tail = values[head_len:] * absmax[num_full]
        result = torch.cat([head.reshape(head_len), tail])
    else:
        result = (values.reshape(-1, blocksize) * absmax.view(-1, 1)).reshape(total)

    return result.to(dtype)


@register_kernel("bitsandbytes::dequantize_blockwise", "default")
def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
    """Dequantize blockwise-quantized indices ``A`` back to ``dtype``.

    Flattens ``A``, delegates the lookup and per-block rescaling to
    ``_dequantize_blockwise_compute``, and restores the original shape of
    ``A`` on the result.
    """
    return _dequantize_blockwise_compute(A.reshape(-1), absmax, code, blocksize, dtype).reshape(A.shape)


@cache
def _get_4bit_quantize_bounds(quant_type: str, device: torch.device):
    """Return bucket boundaries and the sort permutation for a 4-bit code table.

    The table for ``quant_type`` is fetched via ``_get_4bit_code`` and sorted
    ascending. The boundaries are the midpoints between neighbouring sorted
    entries; ``order`` maps a position in the sorted table back to the index
    in the original table. Results are memoized per ``(quant_type, device)``.
    """
    table = _get_4bit_code(quant_type, device)
    order = torch.argsort(table)
    ascending = table[order]
    midpoints = (ascending[:-1] + ascending[1:]) / 2
    # Per the original note: NF4 order is identity (sorted); FP4 needs remap.
    return midpoints, order


@register_kernel("bitsandbytes::quantize_4bit", "default")
def _(
    A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``A`` to packed 4-bit indices with one scale per block.

    The input is flattened, cast to float32 and split into consecutive blocks
    of ``blocksize`` elements (the last block may be shorter). Each block is
    scaled by the reciprocal of its absolute maximum and clamped to
    ``[-1, 1]``; every scaled value is mapped to a 4-bit table index by
    bucketizing against the midpoints from ``_get_4bit_quantize_bounds``.
    Two indices are packed per byte, the first of each pair in the high
    nibble.

    Args:
        A: Tensor to quantize; any shape.
        blocksize: Number of elements per scaling block.
        quant_type: Name of the 4-bit code table (e.g. ``"nf4"``).
        quant_storage: dtype the packed bytes are reinterpreted as.

    Returns:
        A ``(packed, absmax)`` tuple: ``packed`` is a column tensor of packed
        indices viewed as ``quant_storage``; ``absmax`` is a float32 tensor
        with one absolute maximum per block.
    """
    bounds, order = _get_4bit_quantize_bounds(quant_type, A.device)
    A_flat = A.reshape(-1).float()
    n = A_flat.numel()
    rem = n % blocksize
    full = n - rem
    blocks = full // blocksize

    # Full blocks: one row per block so the per-block max is a single reduction.
    A_com = A_flat[:full].reshape(blocks, blocksize)
    absmax = A_com.abs().max(dim=-1)[0]
    # The divisor is clamped so an all-zero block scales to 0 instead of NaN;
    # the returned absmax keeps the true (possibly zero) value.
    scaled = torch.clamp(A_com * (1.0 / absmax.clamp(min=1e-38).view(-1, 1)), -1, 1).reshape(-1)

    if rem:
        # Trailing partial block gets its own scale. As for full blocks, only
        # the divisor is clamped, so the reported absmax stays consistent.
        tail = A_flat[full:]
        tail_absmax = tail.abs().max()
        absmax = torch.cat([absmax, tail_absmax.unsqueeze(0)])
        scaled = torch.cat([scaled, torch.clamp(tail / tail_absmax.clamp(min=1e-38), -1, 1)])

    # Pad to an even length so every element has a partner when packing.
    if scaled.numel() % 2:
        scaled = torch.nn.functional.pad(scaled, (0, 1))

    q = torch.bucketize(scaled, bounds, out_int32=True)
    if quant_type != "nf4":
        # Bucket positions refer to the sorted table; map them back to the
        # original table indices (skipped for nf4, as in the original code).
        q = order[q]
    q8 = q.to(torch.uint8)

    # Two 4-bit indices per byte: even positions in the high nibble.
    packed = ((q8[::2] << 4) | q8[1::2]).unsqueeze(1)
    if quant_storage != torch.uint8:
        # squeeze(1) rather than squeeze(): a single packed byte must stay 1-D,
        # otherwise the unsqueeze(1) below fails on a 0-d tensor.
        packed = packed.squeeze(1).view(quant_storage).unsqueeze(1)
    return packed, absmax


@_try_torch_compile(dynamic=True)
def _dequantize_4bit_compute(
    A_flat: torch.Tensor,
    absmax: torch.Tensor,
    code: torch.Tensor,
    blocksize: int,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Unpack 4-bit indices from ``A_flat``, look them up and rescale per block.

    Each byte of ``A_flat`` holds two indices: the high nibble comes first,
    the low nibble second. The indices are looked up in ``code``, trimmed to
    the number of elements implied by ``shape`` (dropping any padding nibble),
    and multiplied by the per-block scale from ``absmax``. A trailing partial
    block uses the ``absmax`` entry that follows the full blocks.

    Args:
        A_flat: Flat tensor of packed bytes.
        absmax: One scale per block.
        code: 4-bit lookup table.
        blocksize: Number of elements per scaling block.
        shape: Shape of the dequantized tensor.
        dtype: dtype of the returned tensor.

    Returns:
        Tensor of dtype ``dtype`` reshaped to ``(-1, *shape[1:])``.
    """
    n = 1
    for s in shape:
        n *= s

    out_dq = torch.empty(A_flat.size(0) * 2, dtype=torch.int32, device=A_flat.device)
    out_dq[1::2] = A_flat & 0xF
    out_dq[::2] = A_flat >> 4
    out_dq = code[out_dq][:n]  # stays fp32, matches C++ / CUDA behavior

    rem = n % blocksize
    if rem:
        full = n - rem
        blocks = full // blocksize
        out = torch.empty(n, dtype=torch.float32, device=A_flat.device)
        out[:full] = (out_dq[:full].view(-1, blocksize) * absmax[:blocks].view(-1, 1)).reshape(full)
        out[full:] = out_dq[full:] * absmax[blocks]
    else:
        out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(n)
    return out.reshape(-1, *shape[1:]).to(dtype)


@register_kernel("bitsandbytes::dequantize_4bit", "default")
Expand All @@ -309,7 +296,10 @@ def _(
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype)
if A.dtype != torch.uint8:
A = A.view(torch.uint8)
code = _get_4bit_code(quant_type, A.device)
return _dequantize_4bit_compute(A.reshape(-1), absmax, code, blocksize, shape, dtype)


@register_kernel("bitsandbytes::gemv_4bit", "default")
Expand Down
Loading