Skip to content

Commit f264904

Browse files
digantdesai authored and facebook-github-bot committed
Better 4bit packing (#2649)
Summary: Just use tensor methods; drop the custom op and the previous Python logic.

Differential Revision: D55319010
1 parent a531ca5 commit f264904

File tree

1 file changed

+11
-24
lines changed

1 file changed

+11
-24
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 11 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -478,37 +478,24 @@ def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
478478
# Assuming we have a 2d tensor
479479
if inp.ndim != 2:
480480
inp = inp.squeeze()
481-
assert (
482-
inp.ndim == 2
483-
), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"
484-
oc, ic = inp.shape
481+
assert (
482+
inp.ndim == 2
483+
), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"
485484

486485
# pad ic
487-
if ic % 2 != 0:
486+
if inp.shape[-1] % 2 != 0:
488487
inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
489488

489+
# Shape after padding
490+
oc, ic = inp.shape
491+
assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
492+
490493
# Adjust inp tensor for zp
491494
inp = inp.to(dtype=torch.uint8) + 8
492495

493-
# prepare result tensor
494-
ric = int((ic + 1) / 2)
495-
result = torch.zeros([oc, ric], dtype=torch.uint8)
496-
497-
try:
498-
aot_path = NodeVisitor.find_aot_util_path()
499-
torch.ops.load_library(aot_path)
500-
result = torch.ops.xnnpack.convert_to_qc4w(inp)
501-
except:
502-
# Fallback to python implementation
503-
# TODO Warn the user? They might be developing in-tree and didn't install,
504-
# in which case, this will be very slow for large models.
505-
for o in range(oc):
506-
for i in range(ric):
507-
j = 2 * i
508-
result[o][i] = inp[o][j]
509-
result[o][i] += inp[o][j + 1] << 4
510-
511-
return result
496+
# Prepare the Result tensor
497+
inp = inp.contiguous().view(-1)
498+
return (inp[1::2] << 4 | inp[::2]).view(oc, ic / 2)
512499

513500
def get_serialized_buffer_index(
514501
self,

0 commit comments

Comments
 (0)