Skip to content

Commit 5c6cefc

Browse files
kimishpatel authored and facebook-github-bot committed
Use custom cpp op for packing 4 bit weights (#1899)
Summary: Pull Request resolved: #1899 It is extremely slow otherwise ghstack-source-id: 215452811 exported-using-ghexport Reviewed By: digantdesai Differential Revision: D53594767 fbshipit-source-id: a7af8e4aea86c6ef7dec6036d0257dbc7b323a59
1 parent b601b49 commit 5c6cefc

File tree

5 files changed

+123
-5
lines changed

5 files changed

+123
-5
lines changed

backends/xnnpack/operators/TARGETS

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,18 @@ runtime.python_library(
1010
"@EXECUTORCH_CLIENTS",
1111
],
1212
deps = [
13+
":convert_to_qc4w",
1314
"//executorch/backends/xnnpack/utils:xnnpack_utils",
1415
"//executorch/exir:graph_module",
1516
"//executorch/exir/backend:backend_details",
1617
],
1718
)
19+
20+
# Custom C++ op that packs 4-bit weights natively; the pure-Python
# packing path is far too slow for real models.
runtime.cxx_library(
    name = "convert_to_qc4w",
    srcs = ["convert_to_qc4w.cpp"],
    visibility = ["//executorch/..."],
    external_deps = ["libtorch"],
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <ATen/ATen.h>
10+
#include <torch/library.h>
11+
12+
at::Tensor convert_to_qc4w(at::Tensor x) {
13+
std::vector<int64_t> sizes = x.sizes().vec();
14+
TORCH_CHECK(sizes.size() == 2, "Expecting 2D tensor");
15+
TORCH_CHECK(sizes[1] % 2 == 0);
16+
TORCH_CHECK(
17+
x.options().dtype() == at::kByte, "Input tensor must be of type uint8.");
18+
sizes[1] = sizes[1] / 2;
19+
at::Tensor output = at::empty(sizes, x.options().dtype());
20+
uint8_t* x_ptr = x.data_ptr<uint8_t>();
21+
uint8_t* output_ptr = output.data_ptr<uint8_t>();
22+
for (int i = 0; i < output.numel(); ++i) {
23+
int32_t input_i = i * 2;
24+
int32_t input_i_plus_1 = i * 2 + 1;
25+
output_ptr[i] = (x_ptr[input_i_plus_1] << 4) | (x_ptr[input_i]);
26+
}
27+
return output;
28+
}
29+
30+
TORCH_LIBRARY_FRAGMENT(xnnpack, m) {
31+
m.def("convert_to_qc4w", &convert_to_qc4w);
32+
}

backends/xnnpack/operators/node_visitor.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -409,11 +409,19 @@ def convert_to_qc4w(inp: torch.Tensor) -> torch.Tensor:
409409
ric = int((ic + 1) / 2)
410410
result = torch.zeros([oc, ric], dtype=torch.uint8)
411411

412-
for o in range(oc):
413-
for i in range(ric):
414-
j = 2 * i
415-
result[o][i] = inp[o][j]
416-
result[o][i] += inp[o][j + 1] << 4
412+
try:
413+
# TODO(): Enable this in OSS
414+
torch.ops.load_library(
415+
"//executorch/backends/xnnpack/operators:convert_to_qc4w"
416+
)
417+
result = torch.ops.xnnpack.convert_to_qc4w(inp)
418+
except:
419+
# Fallback to python implementation
420+
for o in range(oc):
421+
for i in range(ric):
422+
j = 2 * i
423+
result[o][i] = inp[o][j]
424+
result[o][i] += inp[o][j + 1] << 4
417425

418426
return result
419427

backends/xnnpack/test/TARGETS

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,15 @@ runtime.python_test(
7474
"//executorch/backends/xnnpack:xnnpack_preprocess",
7575
],
7676
)
77+
78+
# Unit test comparing the custom C++ qc4w packing op against a
# pure-Python reference implementation.
runtime.python_test(
    name = "test_custom_convert_qc4w_op",
    srcs = ["ops/test_custom_convert_to_qc4w.py"],
    external_deps = ["libtorch"],
    deps = [
        "//caffe2:torch",
        "//executorch/backends/xnnpack/operators:convert_to_qc4w",
    ],
)
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
11+
12+
class TestCustomQC4WConvert(unittest.TestCase):
    """Tests for the custom torch.ops.xnnpack.convert_to_qc4w packing op."""

    def setUp(self) -> None:
        # Registers torch.ops.xnnpack.convert_to_qc4w for the tests below.
        torch.ops.load_library(
            "//executorch/backends/xnnpack/operators:convert_to_qc4w"
        )

    @staticmethod
    def _ref_output(inp: torch.Tensor) -> torch.Tensor:
        """Pure-Python reference: pack column pairs into bytes.

        Even column goes to the low nibble, odd column to the high nibble.
        """
        oc, ic = inp.shape
        if ic % 2 != 0:
            raise ValueError("Number of input channels not divisible by 2.")
        ric = (ic + 1) // 2
        result = torch.zeros([oc, ric], dtype=torch.uint8)
        for o in range(oc):
            for i in range(ric):
                j = 2 * i
                result[o][i] = inp[o][j]
                result[o][i] += inp[o][j + 1] << 4
        return result

    def test_convert(self) -> None:
        inp = torch.randint(low=0, high=15, size=(20, 42), dtype=torch.uint8)
        result = torch.ops.xnnpack.convert_to_qc4w(inp)
        # assertTrue instead of a bare `assert`: bare asserts are stripped
        # when Python runs with -O, silently disabling the test.
        self.assertTrue(
            torch.equal(result, self._ref_output(inp)), "Outputs dont match"
        )

    def test_convert_throws(self) -> None:
        # assertRaises replaces the previous bare-except flag pattern, which
        # also swallowed KeyboardInterrupt/SystemExit. The op's TORCH_CHECK
        # failures presumably surface as RuntimeError; Exception is used to
        # stay as permissive as the original bare except.

        # Odd number of columns must be rejected.
        inp = torch.randint(low=0, high=15, size=(20, 41), dtype=torch.uint8)
        with self.assertRaises(Exception):
            torch.ops.xnnpack.convert_to_qc4w(inp)

        # Non-uint8 dtype must be rejected.
        with self.assertRaises(Exception):
            torch.ops.xnnpack.convert_to_qc4w(torch.rand((20, 41)))

0 commit comments

Comments
 (0)