
Commit 6ac0dfc

Author: morelos (committed)
Update base for Update on "[ET-VK][Ops] quantization op shaders and impl"
Creates the quantize_per_tensor and quantize_per_token logic shaders and implementations, linked with the testing framework. NOTE: Currently the only supported input types are **half** (fp16) and **float** (fp32), and the only supported output types are **byte** (uint8), **char** (int8), **short** (int16), and **int** (int32).

Differential Revision: [D75959064](https://our.internmc.facebook.com/intern/diff/D75959064/)

[ghstack-poisoned]
2 parents 3fa3891 + 2d09ab8 commit 6ac0dfc
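For orientation, the affine math these ops compute can be sketched in eager PyTorch as below. This is only a reference illustration of per-tensor and per-token quantization with helper names of my own choosing; it is not the Vulkan shader or ExecuTorch kernel code from the commit.

```python
import torch

def quantize_per_tensor_ref(x, scale, zero_point, qmin, qmax, dtype=torch.int8):
    # Affine quantization: q = clamp(round(x / scale) + zero_point, qmin, qmax)
    q = torch.round(x / scale) + zero_point
    return torch.clamp(q, qmin, qmax).to(dtype)

def quantize_per_token_ref(x, scales, zero_points, qmin, qmax, dtype=torch.int8):
    # Per-token quantization: one (scale, zero_point) pair per row of the
    # flattened [num_tokens, hidden_dim] view of the input.
    flat = x.reshape(-1, x.shape[-1])
    q = torch.round(flat / scales.view(-1, 1)) + zero_points.view(-1, 1)
    return torch.clamp(q, qmin, qmax).to(dtype).reshape(x.shape)

# Example: fp32 -> int8, one of the supported dtype combinations noted above.
x = torch.randn(4, 8)
q8 = quantize_per_tensor_ref(x, scale=0.05, zero_point=0, qmin=-128, qmax=127)
qt = quantize_per_token_ref(
    x, scales=torch.full((4,), 0.05), zero_points=torch.zeros(4), qmin=-128, qmax=127
)
```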

35 files changed: +1446 −136 lines

.ci/scripts/utils.sh

Lines changed: 3 additions & 2 deletions
@@ -156,13 +156,14 @@ build_executorch_runner() {
 }
 
 cmake_install_executorch_lib() {
+  build_type="${1:-Release}"
   echo "Installing libexecutorch.a and libportable_kernels.a"
   clean_executorch_install_folders
   retry cmake -DCMAKE_INSTALL_PREFIX=cmake-out \
-          -DCMAKE_BUILD_TYPE=Release \
+          -DCMAKE_BUILD_TYPE=${build_type} \
           -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
           -Bcmake-out .
-  cmake --build cmake-out -j9 --target install --config Release
+  cmake --build cmake-out -j9 --target install --config ${build_type}
 }
 
 download_stories_model_artifacts() {

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ jobs:
         fi
 
         # This has already been cached in the docker image
-        lintrunner init 2> /dev/null
+        lintrunner init
 
         RC=0
         # Run lintrunner on all files

backends/qualcomm/_passes/i64_to_i32.py

Lines changed: 2 additions & 0 deletions
@@ -28,8 +28,10 @@ class I64toI32(ExportPass):
     I64_OPS = {
         exir_ops.edge.aten.argmin.default,
         exir_ops.edge.aten.arange.start_step,
+        exir_ops.edge.aten.cumsum.default,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.scalar_tensor.default,
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     }
     # This dict is to ensure that the input of the OPs are int64 due to Pytorch restrictions.
     # For example, scatter op can only accept args[2], the index, as int64.

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 2 additions & 1 deletion
@@ -86,7 +86,8 @@ def _build_tensor_constant(
             dtype=(
                 node.args[0].meta["val"].dtype
                 if not is_float_tensor(node)
-                and not SCALAR_OPS.get(node.target).use_self_dtype
+                and (info := SCALAR_OPS.get(node.target))
+                and not info.use_self_dtype
                 else node.meta["val"].dtype
             ),
             device=node.meta["val"].device,
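The guard added above appears aimed at op targets missing from SCALAR_OPS, where calling `.use_self_dtype` on a `None` result would raise. A minimal sketch of that None-safe walrus pattern follows; the registry and names here are hypothetical, not the pass itself.

```python
from dataclasses import dataclass

@dataclass
class ScalarOpInfo:
    use_self_dtype: bool

# Hypothetical registry standing in for SCALAR_OPS; real keys are op targets.
SCALAR_OPS = {"aten.add.Scalar": ScalarOpInfo(use_self_dtype=False)}

def keeps_arg_dtype(target) -> bool:
    # Old form: `not SCALAR_OPS.get(target).use_self_dtype` raises
    # AttributeError when the target is unregistered (get() returns None).
    # The walrus guard short-circuits to False instead.
    return bool((info := SCALAR_OPS.get(target)) and not info.use_self_dtype)

print(keeps_arg_dtype("aten.add.Scalar"))  # True
print(keeps_arg_dtype("aten.unknown.op"))  # False, no exception
```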

backends/qualcomm/_passes/replace_inf_values.py

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,12 @@ def call(self, graph_module: torch.fx.GraphModule):
                     arg_list[index] = torch.finfo(torch.float32).min
                 elif arg == float("inf"):
                     arg_list[index] = torch.finfo(torch.float32).max
+
+            if node.target == torch.ops.aten.masked_fill.Scalar:
+                if arg_list[2] == torch.finfo(torch.float32).max:
+                    arg_list[2] = 255
+                elif arg_list[2] == torch.finfo(torch.float32).min:
+                    arg_list[2] = -255
             node.args = tuple(arg_list)
 
         graph_module.recompile()
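For the masked_fill.Scalar special case above, the fill value ends up as ±255 rather than the float32 extremes. My reading of the intent, illustrated below, is that a bounded fill stays representable when the graph is later run in reduced precision while still acting as a "very negative" mask value; this is an interpretation, not code from the pass.

```python
import torch

attn_mask = torch.tensor([[0.0, 1.0], [1.0, 0.0]])
scores = torch.zeros(2, 2)

# A bounded fill such as -255 avoids the float32 min/max extremes that the
# generic inf replacement would otherwise insert as the masked_fill value.
masked = scores.masked_fill(attn_mask != 0, -255.0)
print(masked)
```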

backends/qualcomm/builders/op_cum_sum.py

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ def define_node(
         dim = self.get_param(node, input_tensor)
 
         output_tensor = self.get_tensor(node, node)
+        if output_tensor.dtype == torch.int64:
+            output_tensor = output_tensor.to(torch.int32)
         output_tensor_wrapper = self.define_tensor(
             node,
             node,
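The cast above narrows an int64 cumsum output before wrapping it for QNN, consistent with cumsum being added to the I64toI32 pass earlier in this commit. A trivial illustration of the dtype narrowing (my example, not builder code):

```python
import torch

out = torch.cumsum(torch.randint(0, 10, (4,), dtype=torch.int64), dim=0)
print(out.dtype)                  # torch.int64
print(out.to(torch.int32).dtype)  # torch.int32
```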

backends/qualcomm/tests/models.py

Lines changed: 10 additions & 10 deletions
@@ -1101,6 +1101,16 @@ def forward(self, x):
         return torch.mean(x, (-1, -2))
 
 
+class MaskedFill(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, attn_mask):
+        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
+            attn_mask == 0, float(0.0)
+        )
+
+
 class Maximum(torch.nn.Module):
     def __init__(self):
         super().__init__()
@@ -1751,16 +1761,6 @@ def forward(self, x):
         )
 
 
-class MaskedFill(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, attn_mask):
-        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
-            attn_mask == 0, float(0.0)
-        )
-
-
 # Mimi Decoder has 0D tensor which QNN cannot handle.
 class ZeroDimTensor(torch.nn.Module):
     def __init__(self):

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 64 additions & 3 deletions
@@ -272,9 +272,24 @@ def test_qnn_backend_cos(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_cumsum(self):
-        module = CumSum()  # noqa: F405
-        sample_input = (torch.randn(4),)
-        self.lower_module_and_test_output(module, sample_input)
+        sample_input = ()
+        test_comb = [
+            {
+                QCOM_MODULE: [CumSum()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (torch.randn(4),),
+                    (torch.randint(0, 10, size=(4,)),),
+                ],
+            }
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        self.lower_module_and_test_output(module, sample_input)
+                        index += 1
 
     def test_qnn_backend_einsum_outer_product(self):
         module = EinsumOuterProduct()  # noqa: F405
@@ -311,6 +326,12 @@ def test_qnn_backend_element_wise_add(self):
                 QCOM_MODULE: [AddConstantFloat()],  # noqa: F405
                 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
+            {
+                QCOM_MODULE: [
+                    AddConstantLong(),  # noqa: F405
+                ],
+                QCOM_SAMPLE_INPUTS: [(torch.randint(0, 10, size=(2, 3)),)],
+            },
         ]
 
         index = 0
@@ -4526,6 +4547,40 @@ def test_retinanet(self):
         else:
             self.assertGreaterEqual(msg["mAP"], 0.6)
 
+    def test_roberta(self):
+        if not self.required_envs([self.sentence_dataset]):
+            self.skipTest("missing required envs")
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/roberta.py",
+            "--dataset",
+            self.sentence_dataset,
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--device",
+            self.device,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+        ]
+        if self.host:
+            cmds.extend(["--host", self.host])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                self.assertGreaterEqual(msg["accuracy"], 0.5)
+
     def test_squeezenet(self):
         if not self.required_envs([self.image_dataset]):
             self.skipTest("missing required envs")
@@ -5344,6 +5399,11 @@ def setup_environment():
         help="Location for imagenet dataset",
         type=str,
     )
+    parser.add_argument(
+        "--sentence_dataset",
+        help="Location for sentence dataset",
+        type=str,
+    )
     parser.add_argument(
         "-p",
         "--pretrained_weight",
@@ -5402,6 +5462,7 @@ def setup_environment():
     TestQNN.executorch_root = args.executorch_root
     TestQNN.artifact_dir = args.artifact_dir
     TestQNN.image_dataset = args.image_dataset
+    TestQNN.sentence_dataset = args.sentence_dataset
    TestQNN.pretrained_weight = args.pretrained_weight
    TestQNN.model_name = args.model_name
    TestQNN.online_prepare = args.online_prepare

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 0 deletions
@@ -183,6 +183,7 @@ class TestQNN(unittest.TestCase):
     executorch_root: str = ""
     artifact_dir: str = ""
     image_dataset: str = ""
+    sentence_dataset: str = ""
     pretrained_weight: str = ""
     enable_profile: bool = False
     op_package_dir: str = ""

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 15 additions & 20 deletions
@@ -91,18 +91,10 @@ def is_nchw_node(self, node: torch.fx.Node) -> bool:
         return not self.is_nhwc_node(node)
 
     def requires_nhwc_input(self, node: torch.fx.Node) -> bool:
-        return (
-            node.target in self.memory_sensitive_ops_nhwc
-            or node.name == "output"
-            and not node.args[0][0].meta["val"].is_contiguous()
-        )
+        return node.target in self.memory_sensitive_ops_nhwc
 
     def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
-        return (
-            node.target in self.memory_sensitive_ops_nchw
-            or node.name == "output"
-            and node.args[0][0].meta["val"].is_contiguous()
-        )
+        return node.target in self.memory_sensitive_ops_nchw
 
 
     def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
         # There are two conditions that must be met for a node to be able to
@@ -380,18 +372,21 @@ def call(self, graph_module: torch.fx.GraphModule):  # noqa: C901
                 # This node has no inputs so we don't need to change anything
                 continue
 
-            if self.requires_nhwc_input(node):
+            # Need special case for output node because it can have multiple output dim orders as we can output a tuple multiple nodes
+            if node.op == "output":
+                out_tuple = node.args[0]
+                for out_node in out_tuple:
+                    if out_node.meta["val"].is_contiguous():
+                        self.input_to_nchw(graph_module, out_node, node)
+                    else:
+                        self.input_to_nhwc(graph_module, out_node, node)
+            elif self.requires_nhwc_input(node):
                 # Nodes which enter this branch are ones that require their
                 # first input to be nhwc. This makes this node's output nhwc too
-                # Currently, all nodes like this should have all of their other
-                # inputs as nchw, so fail if this is not true
-                if node.name == "output":
-                    self.input_to_nhwc(graph_module, node.args[0][0], node)
-                else:
-                    self.input_to_nhwc(graph_module, node.args[0], node)
-
-                for input_node in node.all_input_nodes[1:]:
-                    if self.is_nhwc_node(input_node):
+
+                self.input_to_nhwc(graph_module, node.args[0], node)
+                for input_node in node.all_input_nodes:
+                    if input_node.op == "placeholder" and self.is_nhwc_node(input_node):
                         raise AssertionError(
                             f"Expected {input_node} to be NCHW in channels last reshape pass"
                         )
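The rewritten output handling keys off each output tensor's contiguity to decide whether it should be fed as NCHW or NHWC. A standalone check of that property is shown below; it only illustrates the PyTorch memory-format query, not the pass itself.

```python
import torch

x_nchw = torch.randn(1, 3, 8, 8)                        # default contiguous / NCHW
x_nhwc = x_nchw.to(memory_format=torch.channels_last)   # channels-last / NHWC

print(x_nchw.is_contiguous())                                   # True
print(x_nhwc.is_contiguous())                                   # False
print(x_nhwc.is_contiguous(memory_format=torch.channels_last))  # True
```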

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 47 additions & 0 deletions
@@ -335,3 +335,50 @@ def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
             )
             .run_method_and_compare_outputs()
         )
+
+    class ConvAddConvOutput(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 16, 3)
+            self.conv2 = torch.nn.Conv2d(16, 16, 3)
+
+        def forward(self, x):
+            y = self.conv1(x)
+            z = torch.add(y, 1.0)
+            out1 = self.conv2(z)
+            out2 = z
+            return out1, out2
+
+    ConvAddConvOutputModule = ConvAddConvOutput()
+
+    def test_conv_add_conv_output(self):
+        x = torch.randn(1, 3, 8, 8)
+
+        self.run_tester(self.ConvAddConvOutput().eval(), (x,))
+
+        x_cl = x.to(memory_format=torch.channels_last)
+        self.run_tester(self.ConvAddConvOutput().eval(), (x_cl,))
+
+    class ThreeOutputsModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv1 = torch.nn.Conv2d(3, 3, 3)
+            self.conv2 = torch.nn.Conv2d(3, 3, 3)
+            self.linear = torch.nn.Linear(6, 6)
+
+        def forward(self, x):
+            conv1_out = self.conv1(x)
+            conv2_out = self.conv2(x)
+            linear_out = self.linear(x)
+
+            return linear_out, conv1_out, conv2_out
+
+    ThreeOutputsModelModule = ThreeOutputsModel()
+
+    def test_three_outputs_model(self):
+        x = torch.randn(1, 3, 6, 6)
+
+        self.run_tester(self.ThreeOutputsModelModule.eval(), (x,))
+
+        x_cl = x.to(memory_format=torch.channels_last)
+        self.run_tester(self.ThreeOutputsModelModule.eval(), (x_cl,))

examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ list(
     ${CMAKE_CURRENT_LIST_DIR}/runner/token_generator.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/imem_alloc.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/client_mem.h
+    ${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.cpp
+    ${CMAKE_CURRENT_LIST_DIR}/runner/lhd_token_generator.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.cpp
     ${CMAKE_CURRENT_LIST_DIR}/runner/rpc_mem.h
     ${CMAKE_CURRENT_LIST_DIR}/runner/kv_manager.cpp

examples/qualcomm/oss_scripts/llama/README.md

Lines changed: 15 additions & 3 deletions
@@ -4,13 +4,13 @@
 This file provides you the instructions to run LLAMA model with different parameters via Qualcomm HTP backend. We currently support the following models:
 1. LLAMA2 Stories 110M
 2. LLAMA3.2 1B
-3. LLAMA3.2 3B (WIP)
+3. LLAMA3.2 3B
 
 We offer the following modes to execute the model:
 
-KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.
+- KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.
 
-Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache modes to optimize token generation speed. Initially, it uses AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
+- Hybrid Mode: Hybrid mode leverages the strengths of both the AR-N model and KV cache modes to optimize token generation speed. Initially, it uses the AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
   - AR-N model: The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use it to process the prompt and generate the key-value (kv) cache, which serves as a prompt processor in hybrid mode.
   - Prompt processing with AR-N model:
     <figure>
@@ -19,6 +19,7 @@ Hybrid Mode: Hybrid mode leverages the strengths of both AR-N model and KV cache
     </figcaption>
     </figure>
 
+- Lookahead Mode: Lookahead Mode introduces [lookahead decoding](https://arxiv.org/abs/2402.02057) and uses the AR-N model to process the prompt, enhancing token generation speed. While decoding multiple tokens in a single step is infeasible, an LLM can generate multiple guess tokens in parallel. These guess tokens may fit into future parts of the generated sequence. The lookahead decoder generates and verifies these guess tokens, integrating them into the sequence when they fit. In some cases, it can obtain more than one token in a single step. The result is lossless.
 
 ## Instructions
 ### Note
@@ -127,3 +128,14 @@ You can select the KV Cache update mechanism at runtime by setting the `KV_UPDATER` variable.
 ```bash
 python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --kv_updator ${KV_UPDATER}
 ```
+
+You can choose lookahead mode to enhance decoding speed. To use it, specify the following parameters:
+- `--ngram` (n-gram size): The size of the n-grams used in the lookahead process.
+- `--window` (window size): How many future tokens the algorithm attempts to predict in each step.
+- `--gcap` (verification candidates): The maximum number of speculations, or candidate n-grams, that the algorithm considers in each step for verification. It balances the trade-off between computational efficiency and exploring more possibilities.
+
+For more details, please refer to the paper ["Break the Sequential Dependency of LLM Inference Using Lookahead Decoding"](https://arxiv.org/abs/2402.02057).
+
+```bash
+python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode lookahead --prefill_ar_len 32 --max_seq_len 128 --prompt "what is 1+1" --ngram 3 --window 2 --gcap 2
+```
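As a rough illustration of the verification step that `--ngram`, `--window`, and `--gcap` control, a simplified Python sketch follows. It reflects my interpretation of lookahead decoding in general, not the runner's `lhd_token_generator` implementation.

```python
def verify_guesses(confirmed, candidate_ngrams, predict_next):
    """Accept the longest candidate n-gram whose tokens the model would have
    predicted anyway, so more than one token can be committed per step."""
    best = []
    for ngram in candidate_ngrams:            # at most --gcap candidates
        accepted, ctx = [], list(confirmed)
        for tok in ngram:
            if predict_next(ctx) != tok:      # model must agree token-by-token
                break
            accepted.append(tok)
            ctx.append(tok)
        if len(accepted) > len(best):
            best = accepted
    return best

# Toy "model" that always predicts last token + 1; the first guess n-gram
# yields two tokens in a single step, and the output remains lossless.
next_tok = lambda ctx: ctx[-1] + 1
print(verify_guesses([1, 2, 3], [[4, 5, 9], [4, 7]], next_tok))  # [4, 5]
```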

examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 11 additions & 0 deletions
@@ -45,6 +45,17 @@ python_binary(
     ],
 )
 
+python_binary(
+    name = "eval_llama_qnn",
+    srcs = ["eval_llama_qnn.py"],
+    main_function = "executorch.examples.qualcomm.oss_scripts.llama.eval_llama_qnn.main",
+    deps = [
+        ":llama_lib",
+        "//executorch/examples/models/llama:eval_library",
+        "fbsource//third-party/pypi/lm-eval:lm-eval",
+    ],
+)
+
 runtime.command_alias(
     name = "llama_qnn",
     env = {
