Commit 8cf16cb

Update on "[ET-VK] Adding UniformData struct in vTensor class to store uniform data, which will be stored using shared ptr and can be shared with push constants."
This diff adds a new `UniformData` struct to the `vTensor` class to store uniform data, held via a shared pointer so the same data can be shared with push constants. `UniformData` contains the tensor's sizes, strides, and logical limits, as well as its number of elements. The diff also adds an `Attribute` enum to the tensor class to enumerate the attributes supplied to a dispatch, and a `write_attribute` function on `UniformData` that writes an attribute's data to a given memory location.

Differential Revision: [D66733611](https://our.internmc.facebook.com/intern/diff/D66733611/)

[ghstack-poisoned]
Merge commit 8cf16cb (2 parents: ec5c2b9 + df988d0)
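For orientation, here is a rough sketch of the layout the commit message describes. This is an illustrative Python mock-up only; the actual implementation is C++ inside the Vulkan `vTensor` class, and the field names, enum members, and `write_attribute` signature below are hypothetical:

import struct
from dataclasses import dataclass
from enum import Enum

class Attribute(Enum):
    # Hypothetical mirror of the Attribute enum named in the commit message.
    SIZES = 0
    STRIDES = 1
    LOGICAL_LIMITS = 2
    NUMEL = 3

@dataclass
class UniformData:
    # Per the commit message: sizes, strides, logical limits, element count.
    sizes: list
    strides: list
    logical_limits: list
    numel: int

    def write_attribute(self, dst: bytearray, offset: int, attr: Attribute) -> int:
        """Write one attribute's raw bytes into dst at offset; return bytes written."""
        values = {
            Attribute.SIZES: self.sizes,
            Attribute.STRIDES: self.strides,
            Attribute.LOGICAL_LIMITS: self.logical_limits,
            Attribute.NUMEL: [self.numel],
        }[attr]
        raw = struct.pack(f"<{len(values)}i", *values)
        dst[offset : offset + len(raw)] = raw
        return len(raw)

Because one `UniformData` instance is held by shared pointer, the same sizes/strides payload can back both a uniform buffer and a push-constant write without duplicating the source of truth.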


45 files changed: +7,942 -14,930 lines. Large commit: some content is hidden by default, so only a subset of the changed files appears below.

.ci/docker/build.sh

Lines changed: 5 additions & 0 deletions
@@ -41,6 +41,10 @@ case "${IMAGE_NAME}" in
     QNN_SDK=yes
     CLANG_VERSION=12
     ;;
+  executorch-ubuntu-22.04-mediatek-sdk)
+    MEDIATEK_SDK=yes
+    CLANG_VERSION=12
+    ;;
   executorch-ubuntu-22.04-clang12-android)
     LINTRUNNER=""
     CLANG_VERSION=12
@@ -77,6 +81,7 @@ docker build \
   --build-arg "BUILD_DOCS=${BUILD_DOCS}" \
   --build-arg "ARM_SDK=${ARM_SDK:-}" \
   --build-arg "QNN_SDK=${QNN_SDK:-}" \
+  --build-arg "MEDIATEK_SDK=${MEDIATEK_SDK:-}" \
   --build-arg "ANDROID_NDK_VERSION=${ANDROID_NDK_VERSION:-}" \
   -f "${OS}"/Dockerfile \
   "$@" \

.ci/docker/ubuntu/Dockerfile

Lines changed: 2 additions & 0 deletions
@@ -85,5 +85,7 @@ RUN if [ -n "${ARM_SDK}" ]; then git config --global user.email "[email protected]

 ARG QNN_SDK

+ARG MEDIATEK_SDK
+
 USER ci-user
 CMD ["bash"]

.ci/scripts/test_llama.sh

Lines changed: 9 additions & 0 deletions
@@ -110,6 +110,12 @@ else
   COREML=OFF
 fi

+if [[ "${MODE}" =~ .*quantize_kv.* ]]; then
+  QUANTIZE_KV_CACHE=ON
+else
+  QUANTIZE_KV_CACHE=OFF
+fi
+
 echo "COREML option ${COREML}"

 if [[ "${MODE}" =~ .*qnn.* ]]; then
@@ -249,6 +255,9 @@ if [[ "${QNN}" == "ON" ]]; then
     EXPORT_ARGS+=" --tokenizer_path tokenizer.model --pt2e_quantize qnn_16a16w --calibration_tasks wikitext --calibration_limit 1 --calibration_seq_length 128 --calibration_data Once "
   fi
 fi
+if [[ "${QUANTIZE_KV_CACHE}" == "ON" ]]; then
+  EXPORT_ARGS="${EXPORT_ARGS} --quantize_kv_cache"
+fi
 # Add dynamically linked library location
 $PYTHON_EXECUTABLE -m examples.models.llama.export_llama ${EXPORT_ARGS}

.github/workflows/docker-builds.yml

Lines changed: 1 addition & 0 deletions
@@ -42,6 +42,7 @@ jobs:
           - docker-image-name: executorch-ubuntu-22.04-linter
           - docker-image-name: executorch-ubuntu-22.04-arm-sdk
           - docker-image-name: executorch-ubuntu-22.04-qnn-sdk
+          - docker-image-name: executorch-ubuntu-22.04-mediatek-sdk
           - docker-image-name: executorch-ubuntu-22.04-clang12-android
     env:
       DOCKER_IMAGE: 308535385114.dkr.ecr.us-east-1.amazonaws.com/executorch/${{ matrix.docker-image-name }}

.github/workflows/pull.yml

Lines changed: 19 additions & 1 deletion
@@ -86,7 +86,7 @@ jobs:
     strategy:
       matrix:
         dtype: [fp32]
-        mode: [portable, xnnpack+custom, xnnpack+custom+qe]
+        mode: [portable, xnnpack+custom, xnnpack+custom+qe,xnnpack+custom+quantize_kv,xnnpack+quantize_kv]
         include:
           - dtype: bf16
             mode: portable
@@ -504,3 +504,21 @@ jobs:

         # run llama runner in eager mode
         PYTHON_EXECUTABLE=python bash .ci/scripts/test_llama_runner_eager.sh
+
+  test-mediatek-models-linux:
+    name: test-mediatek-models-linux
+    uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
+    strategy:
+      fail-fast: false
+    with:
+      runner: linux.24xlarge
+      docker-image: executorch-ubuntu-22.04-mediatek-sdk
+      submodules: 'true'
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      timeout: 90
+      script: |
+        # The generic Linux job chooses to use base env, not the one setup by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+
+        # placeholder for mediatek to add more tests
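The `CONDA_ENV` line in the new job selects the last environment reported by `conda env list --json`. For clarity, the same selection expressed in Python (this assumes `conda` is on PATH):

import json
import subprocess

# Equivalent of: conda env list --json | jq -r ".envs | .[-1]"
out = subprocess.check_output(["conda", "env", "list", "--json"], text=True)
envs = json.loads(out)["envs"]
print(envs[-1])  # path of the last-listed environment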

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ jobs:
     strategy:
       matrix:
         dtype: [fp32]
-        mode: [portable, xnnpack+kv+custom, mps, coreml]
+        mode: [portable, xnnpack+kv+custom, mps, coreml, xnnpack+custom+quantize_kv]
         include:
           - dtype: bf16
             mode: portable

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
@@ -742,9 +742,9 @@ if(EXECUTORCH_BUILD_PYBIND)
   endif()

   if(EXECUTORCH_BUILD_XNNPACK)
-    # need to explicitly specify XNNPACK here otherwise uses XNNPACK symbols
-    # from libtorch_cpu
-    list(APPEND _dep_libs xnnpack_backend XNNPACK)
+    # need to explicitly specify XNNPACK and microkernels-prod
+    # here otherwise uses XNNPACK and microkernel-prod symbols from libtorch_cpu
+    list(APPEND _dep_libs xnnpack_backend XNNPACK microkernels-prod)
   endif()

   # compile options for pybind

backends/arm/arm_backend.py

Lines changed: 3 additions & 6 deletions
@@ -13,7 +13,7 @@

 import logging
 import os
-from typing import final, List, Optional
+from typing import cast, final, List, Optional

 import serializer.tosa_serializer as ts
 from executorch.backends.arm.arm_vela import vela_compile
@@ -32,6 +32,7 @@
 from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
 from executorch.exir.backend.compile_spec_schema import CompileSpec
 from torch.export.exported_program import ExportedProgram
+from torch.fx import Node

 # TOSA backend debug functionality
 logger = logging.getLogger(__name__)
@@ -269,6 +270,7 @@ def preprocess(  # noqa: C901
         node_visitors = get_node_visitors(edge_program, tosa_spec)
         input_count = 0
         for node in graph_module.graph.nodes:
+            node = cast(Node, node)
             if node.op == "call_function":
                 process_call_function(node, tosa_graph, node_visitors, tosa_spec)
             elif node.op == "placeholder":
@@ -288,9 +290,6 @@ def preprocess(  # noqa: C901
                 "The rank of the input order is not equal to amount of input tensors"
             )

-        # TODO: It would be awesome if this dump could somehow be done on top level and not here.
-        # Problem is that the desc.json has to be created on the tosa_graph object, which we can't
-        # access from top level.
         if artifact_path:
             tag = _get_first_delegation_tag(graph_module)
             dbg_tosa_dump(
@@ -311,6 +310,4 @@ def preprocess(  # noqa: C901
         else:
             raise RuntimeError(f"Unknown format {output_format}")

-        # Continueing from above. Can I put tosa_graph into this function?
-        # debug_handle_map = ...
         return PreprocessResult(processed_bytes=binary)
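The `cast(Node, node)` added above is purely a type-checker aid: iterating `graph_module.graph.nodes` yields loosely typed objects, and `typing.cast` is a runtime no-op that tells static analyzers which API the loop body relies on. A minimal self-contained sketch of the pattern (the `count_call_functions` helper is invented for illustration):

from typing import cast

import torch
from torch.fx import Node

def count_call_functions(graph_module: torch.fx.GraphModule) -> int:
    """Count call_function nodes; cast() documents the node type for checkers."""
    count = 0
    for node in graph_module.graph.nodes:
        node = cast(Node, node)  # no-op at runtime, informs mypy/pyright
        if node.op == "call_function":
            count += 1
    return count

def f(x):
    return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(f)
print(count_call_functions(gm))  # 2: torch.relu and operator.add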

backends/arm/test/common.py

Lines changed: 4 additions & 8 deletions
@@ -74,19 +74,15 @@ def get_tosa_compile_spec_unbuilt(
     the compile spec before calling .build() to finalize it.
     """
     if not custom_path:
-        intermediate_path = maybe_get_tosa_collate_path() or tempfile.mkdtemp(
-            prefix="arm_tosa_"
-        )
-    else:
-        intermediate_path = custom_path
+        custom_path = maybe_get_tosa_collate_path()

-    if not os.path.exists(intermediate_path):
-        os.makedirs(intermediate_path, exist_ok=True)
+    if custom_path is not None:
+        os.makedirs(custom_path, exist_ok=True)
     compile_spec_builder = (
         ArmCompileSpecBuilder()
         .tosa_compile_spec(tosa_version)
         .set_permute_memory_format(permute_memory_to_nhwc)
-        .dump_intermediate_artifacts_to(intermediate_path)
+        .dump_intermediate_artifacts_to(custom_path)
     )

     return compile_spec_builder

backends/arm/test/misc/test_debug_feats.py

Lines changed: 3 additions & 1 deletion
@@ -111,7 +111,9 @@ def test_numerical_diff_prints(self):
             model,
             example_inputs=model.get_inputs(),
             compile_spec=common.get_tosa_compile_spec(
-                "TOSA-0.80.0+MI", permute_memory_to_nhwc=True
+                "TOSA-0.80.0+MI",
+                permute_memory_to_nhwc=True,
+                custom_path=tempfile.mkdtemp("diff_print_test"),
             ),
         )
         .export()

backends/arm/test/ops/test_cat.py

Lines changed: 1 addition & 1 deletion
@@ -124,7 +124,7 @@ def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
     def test_cat_4d_tosa_MI(self):
         square = torch.ones((2, 2, 2, 2))
         for dim in range(-3, 3):
-            test_data = ((square, square), dim)
+            test_data = ((square, square.clone()), dim)
             self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)

     @parameterized.expand(Cat.test_parameters)
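The `square.clone()` change is about object identity rather than values: the old tuple passed the same tensor object twice, while `clone()` produces a second tensor that is equal in value but distinct in storage, presumably so the pipeline treats the two concat operands as independent inputs. A quick check of the distinction:

import torch

square = torch.ones((2, 2, 2, 2))
aliased = (square, square)           # one tensor object, referenced twice
distinct = (square, square.clone())  # two tensors with equal values

assert aliased[0] is aliased[1]
assert distinct[0] is not distinct[1]
assert torch.equal(distinct[0], distinct[1])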

backends/arm/test/ops/test_scalars.py

Lines changed: 1 addition & 1 deletion
@@ -157,7 +157,7 @@ def _test_add_tosa_BI_pipeline(self, module: torch.nn.Module, test_data: tuple):
     def test_MI(self, test_name: str, op: torch.nn.Module, x, y):
         expected_exception = None
         if any(token in test_name for token in ("Sub_int", "Sub__int")):
-            expected_exception = RuntimeError
+            expected_exception = ValueError
         elif test_name.endswith("_st"):
             expected_exception = AttributeError

backends/arm/test/ops/test_select.py

Lines changed: 0 additions & 2 deletions
@@ -93,8 +93,6 @@ def _test_select_tosa_BI_pipeline(
             .check(["torch.ops.quantized_decomposed"])
             .to_edge()
             .partition()
-            .dump_artifact()
-            .dump_operator_distribution()
             .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
             .to_executorch()
             .run_method_and_compare_outputs(inputs=test_data)

backends/arm/test/runner_utils.py

Lines changed: 66 additions & 15 deletions
@@ -16,14 +16,16 @@

 import numpy as np
 import torch
+import tosa_reference_model

 from executorch.backends.arm.test.conftest import is_option_enabled

 from torch.export import ExportedProgram
 from torch.fx.node import Node
+from tosa import TosaGraph

 logger = logging.getLogger(__name__)
-logger.setLevel(logging.WARNING)
+logger.setLevel(logging.CRITICAL)


 class QuantizationParams:
@@ -169,7 +171,7 @@ def __init__(
     ):
         self.intermediate_path = intermediate_path
         self.tosa_ref_model_path = tosa_ref_model_path or "tosa_reference_model"
-        assert os.path.exists(
+        assert self.intermediate_path is None or os.path.exists(
             self.intermediate_path
         ), f"TOSA artifact path don't exist! Path: {self.intermediate_path}"

@@ -332,7 +334,46 @@ def run_corstone(
         tosa_ref_output = np.fromfile(out_path_with_suffix, dtype=np.float32)
         output_shape = self.output_node.args[0][0].meta["val"].shape
         tosa_ref_output = torch.from_numpy(tosa_ref_output).reshape(output_shape)
-        return [tosa_ref_output]
+        return tosa_ref_output
+
+    def run_tosa_graph(
+        self, graph: TosaGraph, inputs: list[np.ndarray] | list[torch.Tensor]
+    ) -> torch.Tensor:
+        """Runs the TOSA reference model with inputs and returns the result."""
+        data_np = [
+            prep_data_for_save(
+                input, self.is_quantized, self.input_names[i], self.qp_input[i]
+            )
+            for i, input in enumerate(inputs)
+        ]
+        # tosa_profile: 0 = Base Inference, 1 = Main Inference, 2 = Main Training.
+        tosa_profile = 0 if self.is_quantized else 1
+        debug_mode = "ALL" if logger.level <= logging.DEBUG else None
+        outputs, status = tosa_reference_model.run(
+            graph,
+            data_np,
+            verbosity=_tosa_refmodel_loglevel(logger.level),
+            tosa_profile=tosa_profile,
+            initialize_variable_tensor_from_numpy=1,  # True
+            debug_mode=debug_mode,
+        )
+
+        assert (
+            status == tosa_reference_model.GraphStatus.TOSA_VALID
+        ), "Non-valid TOSA given to reference model."
+
+        outputs_torch = []
+        for output in outputs:
+            output = torch.from_numpy(output)
+            if self.is_quantized:
+                # Need to dequant back to FP32 for comparison with torch output
+                quant_param = self.qp_output
+                assert (
+                    quant_param is not None
+                ), "There are no quantization parameters, check output parameters"
+                output = (output.to(torch.float32) - quant_param.zp) * quant_param.scale
+            outputs_torch.append(output)
+        return tuple(outputs_torch)

     def run_tosa_ref_model(
         self,
@@ -417,21 +458,13 @@ def run_tosa_ref_model(
         assert (
             shutil.which(self.tosa_ref_model_path) is not None
         ), f"tosa_reference_model tool not found, did you run examples/arm/setup.sh? Path: {self.tosa_ref_model_path}"
-        loglevel_map = {
-            logging.INFO: "INFO",
-            logging.CRITICAL: "LOW",
-            logging.ERROR: "LOW",
-            logging.WARNING: "MED",
-            logging.DEBUG: "HIGH",
-            logging.NOTSET: "MED",
-        }
-        clamped_logging_level = max(min(logger.level // 10 * 10, 50), 0)
+
         cmd_ref_model = [
             self.tosa_ref_model_path,
             "--test_desc",
             desc_file_path,
             "-l",
-            loglevel_map[clamped_logging_level],
+            _tosa_refmodel_loglevel(logger.level),
         ]
         _run_cmd(cmd_ref_model)

@@ -467,7 +500,10 @@ def run_tosa_ref_model(


 def prep_data_for_save(
-    data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
+    data: torch.Tensor,
+    is_quantized: bool,
+    input_name: str,
+    quant_param: QuantizationParams,
 ):
     data_np = np.array(data.detach(), order="C").astype(
         f"{data.dtype}".replace("torch.", "")
@@ -576,7 +612,6 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
     assert os.path.exists(
         tosa_schema_file
     ), f"tosa_schema_file: {tosa_schema_file} does not exist"
-
     assert shutil.which("flatc") is not None
     cmd_flatc = [
         "flatc",
@@ -611,3 +646,19 @@ def dbg_tosa_fb_to_json(tosa_fb: bytes) -> Dict:
         pass

     return json_out
+
+
+def _tosa_refmodel_loglevel(loglevel: int) -> str:
+    """Converts a logging loglevel to tosa_reference_model logginglevel,
+    returned as string.
+    """
+    loglevel_map = {
+        logging.INFO: "INFO",
+        logging.CRITICAL: "LOW",
+        logging.ERROR: "LOW",
+        logging.WARNING: "MED",
+        logging.DEBUG: "HIGH",
+        logging.NOTSET: "MED",
+    }
+    clamped_logging_level = max(min(loglevel // 10 * 10, 50), 0)
+    return loglevel_map[clamped_logging_level]
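Two pieces of the new runner code are easy to sanity-check in isolation: the dequantization in `run_tosa_graph` applies the usual affine mapping real = (q - zero_point) * scale, and `_tosa_refmodel_loglevel` floors a Python logging level to a multiple of 10 and clamps it into [0, 50], so the map lookup can never miss. A standalone illustration (the zero point and scale are made-up values):

import logging

import torch

# Affine dequantization, as in run_tosa_graph:
#   real = (quantized - zero_point) * scale
q = torch.tensor([0, 64, 128, 255], dtype=torch.uint8)
zp, scale = 128, 0.05  # illustrative quantization parameters
real = (q.to(torch.float32) - zp) * scale
print(real)  # tensor([-6.4000, -3.2000,  0.0000,  6.3500])

# Level clamping, as in _tosa_refmodel_loglevel: floor to a decade,
# then clamp into [0, 50] so every input maps to a known key.
for level in (logging.DEBUG, 35, 57, -3):
    print(level, "->", max(min(level // 10 * 10, 50), 0))
# prints: 10 -> 10, 35 -> 30, 57 -> 50, -3 -> 0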
