Commit 219fe79

Use int4mm weight packing mps kernel
1 parent d71783c commit 219fe79

quantization/qops.py

Lines changed: 3 additions & 9 deletions
@@ -395,15 +395,9 @@ def _prepare_weight_and_scales_and_zeros(
         weight_int32, scales_and_zeros = group_quantize_tensor(
             weight_bf16, n_bit=4, groupsize=groupsize
         )
-        if weight_bf16.device.type == "mps":
-            # There are still no MPS-accelerated conversion OP
-            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                weight_int32.cpu(), inner_k_tiles
-            ).to("mps")
-        else:
-            weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
-                weight_int32, inner_k_tiles
-            )
+        weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+            weight_int32, inner_k_tiles
+        )
         return weight_int4pack, scales_and_zeros

     @classmethod
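The MPS special case could be dropped because PyTorch now ships an MPS kernel for torch.ops.aten._convert_weight_to_int4pack, so the group-quantized weight no longer needs to be packed on the CPU and copied back to the "mps" device. Below is a minimal sketch of that op in isolation, not part of the commit: it assumes a PyTorch build whose MPS kernel accepts the int32 group-quantized layout used here (the op's expected input dtype has shifted across PyTorch versions), and the tensor shape and inner_k_tiles value are illustrative.

import torch

# Illustrative stand-in for the output of group_quantize_tensor: 4-bit values
# stored as int32, already resident on the MPS device. The 128x128 shape and
# inner_k_tiles = 2 are assumptions for this sketch, not values from the repo.
weight_int32 = torch.randint(0, 16, (128, 128), dtype=torch.int32, device="mps")
inner_k_tiles = 2

# With the MPS kernel available, packing runs directly on the device,
# replacing the old weight_int32.cpu() ... .to("mps") round-trip.
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
print(weight_int4pack.device, weight_int4pack.shape, weight_int4pack.dtype)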
