@@ -481,34 +481,24 @@ def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
     assert (
         inp.ndim == 2
     ), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"
-    oc, ic = inp.shape
+    assert inp.ndim == 2, "convert_to_qc4w: expecting input tensor to be 2d"
 
     # pad ic
-    if ic % 2 != 0:
+    if inp.shape[-1] % 2 != 0:
         inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
 
+
+    # Shape after padding
+    oc, ic = inp.shape
+    assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
+
     # Adjust inp tensor for zp
     inp = inp.to(dtype=torch.uint8) + 8
 
-    # prepare result tensor
-    ric = int((ic + 1) / 2)
-    result = torch.zeros([oc, ric], dtype=torch.uint8)
-
-    try:
-        aot_path = NodeVisitor.find_aot_util_path()
-        torch.ops.load_library(aot_path)
-        result = torch.ops.xnnpack.convert_to_qc4w(inp)
-    except:
-        # Fallback to python implementation
-        # TODO Warn the user? They might be developing in-tree and didn't install,
-        # in which case, this will be very slow for large models.
-        for o in range(oc):
-            for i in range(ric):
-                j = 2 * i
-                result[o][i] = inp[o][j]
-                result[o][i] += inp[o][j + 1] << 4
-
-    return result
+    # Prepare the result tensor
+    inp = inp.contiguous().view(-1)
+    return (inp[1::2] << 4 | inp[::2]).view(oc, ic // 2)
+
 
 def get_serialized_buffer_index(
     self,
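
For reference, here is a minimal self-contained sketch of the packing scheme this diff vectorizes, with the removed Python loop and the new strided-slice version side by side so their equivalence can be checked. The helper names pack_qc4w_loop and pack_qc4w_vectorized are hypothetical, introduced only for this illustration; both follow the QC4W layout above, where byte i of a row stores column 2*i in the low nibble and column 2*i + 1 in the high nibble.

# A minimal sketch (not the ExecuTorch source): both packers assume a 2D
# uint8 tensor whose values already fit in 4 bits and whose width is even.
import torch
import torch.nn.functional as F


def pack_qc4w_loop(inp: torch.Tensor) -> torch.Tensor:
    # Removed path: explicit Python loop, one output byte per column pair.
    oc, ic = inp.shape
    ric = (ic + 1) // 2
    result = torch.zeros([oc, ric], dtype=torch.uint8)
    for o in range(oc):
        for i in range(ric):
            j = 2 * i
            result[o][i] = inp[o][j]            # low nibble: even column
            result[o][i] += inp[o][j + 1] << 4  # high nibble: odd column
    return result


def pack_qc4w_vectorized(inp: torch.Tensor) -> torch.Tensor:
    # New path: the same packing via strided slices over a flattened view.
    oc, ic = inp.shape
    flat = inp.contiguous().view(-1)
    return (flat[1::2] << 4 | flat[::2]).view(oc, ic // 2)


# Build an input the way the diff does: pad ic to even, then shift the
# signed 4-bit range [-8, 7] into [0, 15] with the +8 zero-point offset.
w = torch.randint(-8, 8, (4, 5), dtype=torch.int8)
if w.shape[-1] % 2 != 0:
    w = F.pad(input=w, pad=(0, 1, 0, 0), mode="constant", value=0)
w = w.to(dtype=torch.uint8) + 8

assert torch.equal(pack_qc4w_loop(w), pack_qc4w_vectorized(w))

# Unpacking the nibbles recovers the offset input exactly.
packed = pack_qc4w_vectorized(w)
lo, hi = packed & 0xF, packed >> 4
assert torch.equal(torch.stack([lo, hi], dim=-1).view(w.shape), w)

The +8 offset maps the signed 4-bit range [-8, 7] into [0, 15] so every value fits in a nibble, which is why the round-trip check compares the unpacked bytes directly against the offset input.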