
jedt/metal-quant-ext2


metal-quant-ext2

metal-quant-ext2 is a repository of my research on PyTorch MPS kernel extensions written with Apple Metal. The "quant" in the name refers to quantization.

The goal is to develop MPS PyTorch extensions for efficient local fine-tuning of Hugging Face PyTorch models with TRL.

The study includes:

Requirements

  • macOS 15.3.1 or later
  • Python 3.12
  • PyTorch

Usage

pip3 install -r requirements.txt
pip3 install --ignore-installed .

Blockwise Quantization 8-bit

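blockwise_quant applies symmetric blockwise 8-bit quantization to a PyTorch tensor. It writes its results into preallocated buffers: an int8 tensor for the quantized values, plus one scale and one offset per block.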
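The README describes the scheme as symmetric, and the usage example below asserts that quantized values stay in [-127, 127]. A minimal pure-Python sketch of one common way to get that behavior, absmax scaling with one float scale per block, follows. It is a CPU reference for intuition only, not the Metal kernel; `blockwise_quant_ref` and its default `block_size` are hypothetical and not part of `metal_quant_ext2`.

```python
import math
from typing import List, Tuple


def blockwise_quant_ref(values: List[float], block_size: int = 64) -> Tuple[List[int], List[float]]:
    """CPU reference sketch of symmetric (absmax) blockwise int8 quantization.

    block_size=64 is a hypothetical default, not taken from metal_quant_ext2.
    """
    num_blocks = math.ceil(len(values) / block_size)  # the last block may be partial
    quantized: List[int] = []
    scales: List[float] = []
    for b in range(num_blocks):
        block = values[b * block_size:(b + 1) * block_size]
        absmax = max(abs(v) for v in block)
        # One scale per block maps [-absmax, absmax] onto [-127, 127]; an
        # all-zero block gets scale 1.0 to avoid dividing by zero.
        scale = absmax / 127.0 if absmax > 0 else 1.0
        scales.append(scale)
        for v in block:
            # Round to the nearest integer step, then clamp to the symmetric
            # int8 range checked by the README's assertion.
            quantized.append(max(-127, min(127, round(v / scale))))
    return quantized, scales


q, s = blockwise_quant_ref([0.504, -0.3, 1.27, 0.0], block_size=4)
print(q)  # [50, -30, 127, 0]
```

Because each value is rounded to the nearest step, every element lands within half a step (`scale / 2`) of its original value, which is a convenient bound to assert when testing a kernel against a reference.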

import torch
from metal_quant_ext2 import blockwise_quant, dequantize

mps_device = torch.device("mps")
cpu_device = torch.device("cpu")

input_tensor = torch.randn(1024, device=mps_device, dtype=torch.float32)

# One scale/offset per block. block_size is an assumed value (not stated here);
# it must match the block size used by the Metal kernel.
block_size = 64
num_blocks = (input_tensor.numel() + block_size - 1) // block_size

quantized = torch.empty_like(input_tensor, dtype=torch.int8)  # inherits device from input_tensor (MPS)
scales = torch.empty(num_blocks, device=cpu_device, dtype=torch.float32)
offsets = torch.empty(num_blocks, device=cpu_device, dtype=torch.float32)

# Quantize on the GPU via the Metal kernel
blockwise_quant(input_tensor, quantized, scales, offsets)

print(f"quantized: {quantized}")
assert torch.all(quantized.cpu() >= -127) and torch.all(quantized.cpu() <= 127)

# Dequantize via the Metal kernel; move scales to the MPS device first
scales = scales.to(mps_device)
output = torch.empty_like(input_tensor)
dequantize(quantized, scales, output)
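Assuming the kernel follows an absmax scheme, the `dequantize` call above reconstructs each element by multiplying its int8 value by its block's scale. The pure-Python sketch below shows that step plus an error helper for comparing results; `dequantize_ref`, `max_abs_error`, and the default `block_size` are hypothetical illustrations, not part of `metal_quant_ext2`.

```python
from typing import List


def dequantize_ref(quantized: List[int], scales: List[float], block_size: int = 64) -> List[float]:
    """Reference sketch: reconstruct x_hat = q * scale of the element's block.

    block_size=64 is a hypothetical default and must match the quantizer.
    """
    # Element i lives in block i // block_size, so it uses scales[i // block_size].
    return [q * scales[i // block_size] for i, q in enumerate(quantized)]


def max_abs_error(original: List[float], reconstructed: List[float]) -> float:
    """Largest absolute element-wise difference between two non-empty, equal-length lists."""
    return max(abs(a - b) for a, b in zip(original, reconstructed))


print(dequantize_ref([127, -64, 0, 32], [0.5], block_size=4))  # [63.5, -32.0, 0.0, 16.0]
```

To sanity-check the Metal output, one could compare `max_abs_error(input_tensor.cpu().tolist(), output.cpu().tolist())` with `scales.max().item() / 2`; if the kernel uses the absmax scheme assumed here, the error should stay within that bound up to float rounding.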

Testing

See test-blockwise-quant.py for tests with assertions.

Blockwise Quantization Example

Below is a Python script that helped me understand blockwise quantization:

code-samples/blockwise-quantization.py
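For intuition, here is the arithmetic for a single block $b$ under a symmetric absmax assumption (not necessarily the exact formulation used in the script above), followed by a small worked example:

```latex
s_b = \frac{\max_{i \in b} |x_i|}{127}, \qquad
q_i = \operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{x_i}{s_b}\right), -127, 127\right), \qquad
\hat{x}_i = q_i \, s_b, \qquad
|x_i - \hat{x}_i| \le \frac{s_b}{2}

\begin{aligned}
x &= (0.504,\ -0.3,\ 1.27,\ 0.0), & s_b &= 1.27 / 127 = 0.01, \\
x / s_b &= (50.4,\ -30,\ 127,\ 0), & q &= (50,\ -30,\ 127,\ 0), \\
\hat{x} &= (0.50,\ -0.30,\ 1.27,\ 0.00), & \max_i |x_i - \hat{x}_i| &= 0.004 \le 0.005.
\end{aligned}
```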

About

PyTorch extension with a Metal MPS kernel
