def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
    """Pack a 2d quantized-weight tensor into XNNPACK's QC4W nibble format.

    Two adjacent elements along the innermost (ic) dimension are packed into
    one uint8: the even-indexed element occupies the low nibble and the
    odd-indexed element the high nibble. Before packing, values are shifted
    by the zero point (+8), so inputs are presumably signed 4-bit values in
    [-8, 7] — TODO confirm against the quantizer that produces them.

    Args:
        inp: weight tensor; must be 2d after squeezing out size-1 dims,
            laid out as (oc, ic). If ic is odd it is zero-padded to even.

    Returns:
        A uint8 tensor of shape (oc, ic // 2) with two 4-bit values per byte
        (ic measured after padding).

    Raises:
        AssertionError: if the input cannot be squeezed down to 2d.
    """
    # Assuming we have a 2d tensor; squeeze away any size-1 dims first.
    if inp.ndim != 2:
        inp = inp.squeeze()
        assert (
            inp.ndim == 2
        ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"

    # Pad ic to an even length so every output byte holds exactly two values.
    if inp.shape[-1] % 2 != 0:
        inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)

    # Shape after padding
    oc, ic = inp.shape
    assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"

    # Adjust inp tensor for zp: shift [-8, 7] into the unsigned nibble
    # range [0, 15].
    inp = inp.to(dtype=torch.uint8) + 8

    # Prepare the result tensor: flatten, pack neighbor pairs (odd index ->
    # high nibble, even index -> low nibble), and reshape to (oc, ic // 2).
    # NOTE: `ic // 2` (floor division) is required — `ic / 2` yields a float,
    # which torch.Tensor.view rejects with a TypeError.
    inp = inp.contiguous().view(-1)
    return (inp[1::2] << 4 | inp[::2]).view(oc, ic // 2)
500
def get_serialized_buffer_index (
514
501
self ,
0 commit comments