Skip to content

Commit 781ba31

Browse files
digantdesai and facebook-github-bot
authored and committed
Better 4bit packing (#2649)
Summary: Just use tensor methods; drop the custom op and the previous Python logic. Differential Revision: D55319010
1 parent a531ca5 commit 781ba31

File tree

1 file changed

+9
-21
lines changed

1 file changed

+9
-21
lines changed

backends/xnnpack/operators/node_visitor.py

Lines changed: 9 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -481,34 +481,22 @@ def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
481481
assert (
482482
inp.ndim == 2
483483
), f"convert_to_qc4w: expecting input tensor to be 2d, got {inp.ndim}"
484-
oc, ic = inp.shape
484+
assert inp.ndim == 2, "convert_to_qc4w: expecting input tensor to be 2d"
485485

486486
# pad ic
487-
if ic % 2 != 0:
487+
if inp.shape[-1] % 2 != 0:
488488
inp = F.pad(input=inp, pad=(0, 1, 0, 0), mode="constant", value=0)
489489

490+
# Shape after padding
491+
oc, ic = inp.shape
492+
assert ic % 2 == 0, "convert_to_qc4w: expecting ic to be even"
493+
490494
# Adjust inp tensor for zp
491495
inp = inp.to(dtype=torch.uint8) + 8
492496

493-
# prepare result tensor
494-
ric = int((ic + 1) / 2)
495-
result = torch.zeros([oc, ric], dtype=torch.uint8)
496-
497-
try:
498-
aot_path = NodeVisitor.find_aot_util_path()
499-
torch.ops.load_library(aot_path)
500-
result = torch.ops.xnnpack.convert_to_qc4w(inp)
501-
except:
502-
# Fallback to python implementation
503-
# TODO Warn the user? They might be developing in-tree and didn't install,
504-
# in which case, this will be very slow for large models.
505-
for o in range(oc):
506-
for i in range(ric):
507-
j = 2 * i
508-
result[o][i] = inp[o][j]
509-
result[o][i] += inp[o][j + 1] << 4
510-
511-
return result
497+
# Prepare the Result tensor
498+
inp = inp.contiguous().view(-1)
499+
return (inp[1::2] << 4 | inp[::2]).view(oc, ic / 2)
512500

513501
def get_serialized_buffer_index(
514502
self,

0 commit comments

Comments
 (0)