Skip to content

Commit 13efb5c

Browse files
digantdesaifacebook-github-bot
authored andcommitted
Better 4bit packing (#2649)
Summary: Just use tensor methods, drop custom op and the previous python logic Differential Revision: D55319010
1 parent 67a7d20 commit 13efb5c

File tree

1 file changed

+11
-36
lines changed

1 file changed

+11
-36
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -449,18 +449,6 @@ def define_tensor(
449449
if quant_params is not None:
450450
vals_to_ids[quant_params.q_input] = id_out
451451

452-
@staticmethod
453-
def find_aot_util_path() -> str:
454-
# Look for .so installed by wheel (OSS).
455-
rel_path = "executorch/extension/pybindings/libaot_util.so"
456-
for sys_path in sys.path:
457-
so_path = Path(sys_path) / rel_path
458-
if so_path.exists():
459-
return str(so_path.absolute().as_posix())
460-
461-
# Fall back to buck.
462-
return "//executorch/extension/aot_util:aot_util"
463-
464452
@staticmethod
465453
def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
466454
"""
@@ -478,37 +466,24 @@ def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
478466
# Assuming we have a 2d tensor
479467
if inp.ndim != 2:
480468
inp = inp.squeeze()
481-
assert (
482-
inp.ndim == 2
483-
), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"
484-
oc, ic = inp.shape
469+
assert (
470+
inp.ndim == 2
471+
), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"
485472

486473
# pad ic
487-
if ic % 2 != 0:
474+
if inp.shape[-1] % 2 != 0:
488475
inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
489476

477+
# Shape after padding
478+
oc, ic = inp.shape
479+
assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
480+
490481
# Adjust inp tensor for zp
491482
inp = inp.to(dtype=torch.uint8) + 8
492483

493-
# prepare result tensor
494-
ric = int((ic + 1) / 2)
495-
result = torch.zeros([oc, ric], dtype=torch.uint8)
496-
497-
try:
498-
aot_path = NodeVisitor.find_aot_util_path()
499-
torch.ops.load_library(aot_path)
500-
result = torch.ops.xnnpack.convert_to_qc4w(inp)
501-
except:
502-
# Fallback to python implementation
503-
# TODO Warn the user? They might be developing in-tree and didn't install,
504-
# in which case, this will be very slow for large models.
505-
for o in range(oc):
506-
for i in range(ric):
507-
j = 2 * i
508-
result[o][i] = inp[o][j]
509-
result[o][i] += inp[o][j + 1] << 4
510-
511-
return result
484+
# Prepare the Result tensor
485+
inp = inp.contiguous().view(-1)
486+
return (inp[1::2] << 4 | inp[::2]).view(oc, int(ic / 2))
512487

513488
def get_serialized_buffer_index(
514489
self,

0 commit comments

Comments
 (0)