diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py index ed6803eda..de07e74e3 100755 --- a/bitsandbytes/backends/cpu/ops.py +++ b/bitsandbytes/backends/cpu/ops.py @@ -149,9 +149,13 @@ def _( shape_fallback = shape[-1] % 2 != 0 if avx512_fallback or shape_fallback: - from ..default.ops import _dequantize_4bit_impl + from ..default.ops import _dequantize_4bit_compute + from ..utils import _get_4bit_code - return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) + if A.dtype != torch.uint8: + A = A.view(torch.uint8) + code = _get_4bit_code(quant_type, A.device) + return _dequantize_4bit_compute(A.reshape(-1), absmax, code, blocksize, shape, dtype) # Enable non uint8 dtype if A.dtype != torch.uint8: diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index 80d86321f..9aa2d60c2 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -1,12 +1,12 @@ from collections.abc import Sequence -from functools import wraps +from functools import cache, wraps from math import sqrt from typing import Optional import torch from ..._ops import register_kernel -from ..utils import CODE +from ..utils import _get_4bit_code def _try_torch_compile(func=None, **compile_kwargs): @@ -179,125 +179,112 @@ def _(A: torch.Tensor, threshold=0.0): @register_kernel("bitsandbytes::quantize_blockwise", "default") def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]: - n = A.numel() + A_flat = A.reshape(-1).float() + n = A_flat.numel() rem = n % blocksize - has_rem = rem > 0 - blocks = n // blocksize + has_rem - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - A_reshaped = A.reshape(n) - A_com = A_reshaped[: n - rem] - A_com_reshaped = A_com.reshape(n // blocksize, blocksize) - absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] - scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) - scaled_A = scaled_A.reshape(-1) - if has_rem: - absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() - scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) - scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) - - diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) - out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape) - - return out, absmax + full = n - rem + blocks = full // blocksize + A_com = A_flat[:full].reshape(blocks, blocksize) + absmax = A_com.abs().max(dim=-1)[0] + scaled = torch.clamp(A_com * (1.0 / absmax.clamp(min=1e-38).view(-1, 1)), -1, 1).reshape(-1) + if rem: + am = A_flat[full:].abs().max().clamp(min=1e-38) + absmax = torch.cat([absmax, am.unsqueeze(0)]) + scaled = torch.cat([scaled, torch.clamp(A_flat[full:] / am, -1, 1)]) + bounds = (code[:-1] + code[1:]) / 2 # code is always sorted (same assumption as CUDA kernel) + q = torch.bucketize(scaled, bounds, out_int32=True).to(torch.uint8) + return q.reshape(A.shape), absmax + + +@_try_torch_compile(dynamic=True) +def _dequantize_blockwise_compute( + A_flat: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype +): + n = A_flat.numel() + out = code[A_flat.to(torch.int64)] + rem = n % blocksize + if rem == 0: + out = (out.reshape(-1, blocksize) * absmax.view(-1, 1)).reshape(n) + else: + full = n - rem + blocks = full // blocksize + out = torch.cat( + [ + (out[:full].reshape(blocks, blocksize) * absmax[:blocks].view(-1, 1)).reshape(full), + out[full:] * absmax[blocks], + ] + ) + return out.to(dtype) @register_kernel("bitsandbytes::dequantize_blockwise", "default") def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor: - out = code[A.reshape(-1).int()] - blocks = out.shape[-1] // blocksize - res = out.shape[-1] % blocksize - if res != 0: - out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0) - out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) - out = out[: blocks * blocksize + res] - out = out.reshape(A.shape) + return _dequantize_blockwise_compute(A.reshape(-1), absmax, code, blocksize, dtype).reshape(A.shape) + - return out +@cache +def _get_4bit_quantize_bounds(quant_type: str, device: torch.device): + code = _get_4bit_code(quant_type, device) + order = torch.argsort(code) + midpoints = (code[order[:-1]] + code[order[1:]]) / 2 + return midpoints, order # NF4 order is identity (sorted); FP4 needs remap @register_kernel("bitsandbytes::quantize_4bit", "default") def _( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: - n = A.numel() - full_blocks = n // blocksize + bounds, order = _get_4bit_quantize_bounds(quant_type, A.device) + A_flat = A.reshape(-1).float() + n = A_flat.numel() rem = n % blocksize - blocks = full_blocks + 1 if rem else full_blocks - absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) - A_flattened = A.reshape(n) - - # Scale full blocks of the tensor to [-1, 1] - A_full_blocks = A_flattened[: n - rem].reshape(n // blocksize, blocksize) - absmax[:full_blocks] = torch.abs(A_full_blocks).max(dim=-1)[0] - scaled = torch.clamp(A_full_blocks * (1 / absmax[:full_blocks].view(-1, 1)), -1, 1).reshape(-1) - - # Scale any partial block + full = n - rem + blocks = full // blocksize + A_com = A_flat[:full].reshape(blocks, blocksize) + absmax = A_com.abs().max(dim=-1)[0] + scaled = torch.clamp(A_com * (1.0 / absmax.clamp(min=1e-38).view(-1, 1)), -1, 1).reshape(-1) if rem: - A_rem = A_flattened[-rem:] - absmax[-1] = torch.abs(A_rem).max() - scaled_rem = torch.clamp(A_rem * (1 / absmax[-1]), -1, 1) - scaled = torch.cat([scaled, scaled_rem], dim=0) - - # Quantize with the lookup table - code = CODE[quant_type].to(scaled.device).to(scaled.dtype) - # Pad to even length so packing pairs all elements - if scaled.numel() % 2 != 0: - scaled = torch.nn.functional.pad(scaled, (0, 1), value=0.0) - quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - code), dim=-1, keepdim=True).to(torch.uint8) - - # Pack two quantized values per byte - packed = quantized[::2] << 4 | quantized[1::2] - + am = A_flat[full:].abs().max().clamp(min=1e-38) + absmax = torch.cat([absmax, am.unsqueeze(0)]) + scaled = torch.cat([scaled, torch.clamp(A_flat[full:] / am, -1, 1)]) + if scaled.numel() % 2: + scaled = torch.nn.functional.pad(scaled, (0, 1)) + q = torch.bucketize(scaled, bounds, out_int32=True) + if quant_type != "nf4": + q = order[q] + q8 = q.to(torch.uint8) + packed = ((q8[::2] << 4) | q8[1::2]).unsqueeze(1) if quant_storage != torch.uint8: packed = packed.squeeze().view(quant_storage).unsqueeze(1) - - return packed, absmax.float() + return packed, absmax -def _dequantize_4bit_impl( - A: torch.Tensor, +@_try_torch_compile(dynamic=True) +def _dequantize_4bit_compute( + A_flat: torch.Tensor, absmax: torch.Tensor, + code: torch.Tensor, blocksize: int, - quant_type: str, shape: Sequence[int], dtype: torch.dtype, -) -> torch.Tensor: - # Enable non uint8 dtype - if A.dtype != torch.uint8: - A = A.view(torch.uint8) - - A = A.reshape(-1) - # Map nf4 to [-1, 1] - out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) - out_dq[1::2] = A & 0xF - out_dq[::2] = A >> 4 - # code is fp32, cast to dtype to avoid the mismatch issue - code = CODE[quant_type].to(dtype).to(A.device) - out_dq = code[out_dq] - - # Use the actual output size, not the unpacked size (which may include padding) +): n = 1 for s in shape: n *= s - # Trim any extra elements from padding during quantization - out_dq = out_dq[:n] - - # Apply scales - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 + out_dq = torch.empty(A_flat.size(0) * 2, dtype=torch.int32, device=A_flat.device) + out_dq[1::2] = A_flat & 0xF + out_dq[::2] = A_flat >> 4 + out_dq = code[out_dq][:n] # stays fp32, matches C++ / CUDA behavior rem = n % blocksize - has_rem = rem > 0 - - out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) - if has_rem: - out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) - out[n - rem :] = out_dq[n - rem :] * absmax[-1] + if rem: + full = n - rem + blocks = full // blocksize + out = torch.empty(n, dtype=torch.float32, device=A_flat.device) + out[:full] = (out_dq[:full].view(-1, blocksize) * absmax[:blocks].view(-1, 1)).reshape(full) + out[full:] = out_dq[full:] * absmax[blocks] else: - out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) - - out = out.reshape(-1, *shape[1:]).to(dtype) - - return out + out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(n) + return out.reshape(-1, *shape[1:]).to(dtype) @register_kernel("bitsandbytes::dequantize_4bit", "default") @@ -309,7 +296,10 @@ def _( shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) + if A.dtype != torch.uint8: + A = A.view(torch.uint8) + code = _get_4bit_code(quant_type, A.device) + return _dequantize_4bit_compute(A.reshape(-1), absmax, code, blocksize, shape, dtype) @register_kernel("bitsandbytes::gemv_4bit", "default")