We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent d71783c commit 219fe79 — Copy full SHA for 219fe79
quantization/qops.py
@@ -395,15 +395,9 @@ def _prepare_weight_and_scales_and_zeros(
395
weight_int32, scales_and_zeros = group_quantize_tensor(
396
weight_bf16, n_bit=4, groupsize=groupsize
397
)
398
- if weight_bf16.device.type == "mps":
399
- # There are still no MPS-accelerated conversion OP
400
- weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
401
- weight_int32.cpu(), inner_k_tiles
402
- ).to("mps")
403
- else:
404
405
- weight_int32, inner_k_tiles
406
- )
+ weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(
+ weight_int32, inner_k_tiles
+ )
407
return weight_int4pack, scales_and_zeros
408
409
@classmethod
0 commit comments