@@ -449,18 +449,6 @@ def define_tensor(
449
449
if quant_params is not None :
450
450
vals_to_ids [quant_params .q_input ] = id_out
451
451
452
- @staticmethod
453
- def find_aot_util_path () -> str :
454
- # Look for .so installed by wheel (OSS).
455
- rel_path = "executorch/extension/pybindings/libaot_util.so"
456
- for sys_path in sys .path :
457
- so_path = Path (sys_path ) / rel_path
458
- if so_path .exists ():
459
- return str (so_path .absolute ().as_posix ())
460
-
461
- # Fall back to buck.
462
- return "//executorch/extension/aot_util:aot_util"
463
-
464
452
@staticmethod
465
453
def convert_to_qc4w (inp : torch .Tensor ) -> torch .Tensor :
466
454
"""
@@ -478,37 +466,24 @@ def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
478
466
# Assuming we have a 2d tensor
479
467
if inp .ndim != 2 :
480
468
inp = inp .squeeze ()
481
- assert (
482
- inp .ndim == 2
483
- ), f"convert_to_qc4w: expecting input tensor to be 2d, got { inp .ndim } "
484
- oc , ic = inp .shape
469
+ assert (
470
+ inp .ndim == 2
471
+ ), f"convert_to_qc4w: expecting input tensor to be 2d, got { inp .ndim } "
485
472
486
473
# pad ic
487
- if ic % 2 != 0 :
474
+ if inp . shape [ - 1 ] % 2 != 0 :
488
475
inp = F .pad (input = inp , pad = (0 , 1 , 0 , 0 ), mode = "constant" , value = 0 )
489
476
477
+ # Shape after padding
478
+ oc , ic = inp .shape
479
+ assert ic % 2 == 0 , "convert_to_qc4w: expecting ic to be even"
480
+
490
481
# Adjust inp tensor for zp
491
482
inp = inp .to (dtype = torch .uint8 ) + 8
492
483
493
- # prepare result tensor
494
- ric = int ((ic + 1 ) / 2 )
495
- result = torch .zeros ([oc , ric ], dtype = torch .uint8 )
496
-
497
- try :
498
- aot_path = NodeVisitor .find_aot_util_path ()
499
- torch .ops .load_library (aot_path )
500
- result = torch .ops .xnnpack .convert_to_qc4w (inp )
501
- except :
502
- # Fallback to python implementation
503
- # TODO Warn the user? They might be developing in-tree and didn't install,
504
- # in which case, this will be very slow for large models.
505
- for o in range (oc ):
506
- for i in range (ric ):
507
- j = 2 * i
508
- result [o ][i ] = inp [o ][j ]
509
- result [o ][i ] += inp [o ][j + 1 ] << 4
510
-
511
- return result
484
+ # Prepare the Result tensor
485
+ inp = inp .contiguous ().view (- 1 )
486
+ return (inp [1 ::2 ] << 4 | inp [::2 ]).view (oc , int (ic / 2 ))
512
487
513
488
def get_serialized_buffer_index (
514
489
self ,
0 commit comments