
Commit 812375e

Update base for Update on "Dtype selective build for optimized ops"

Add dtype selective build for optimized ops, following the same process as portable: copy the source files and rebuild the library.

1. Generalize the copy genrule across portable/optimized sources and headers.
2. Copy the optimized source files and headers.
3. Build the optimized ops from the copied source files, their dependencies, and the portable headers.
4. Add a test confirming that we can run addmul with float dtypes (and that the test fails once those dtypes are removed from the selection).

Differential Revision: [D74688554](https://our.internmc.facebook.com/intern/diff/D74688554/)

[ghstack-poisoned]
2 parents: d68c90d + 9663bfb
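For context, here is a minimal sketch of the kind of float addmul model the new test exercises. This is illustrative only: it uses the standard torch.export/exir flow, not the Buck test added in this commit, and the module itself is hypothetical.

    import torch
    from torch.export import export
    from executorch.exir import to_edge


    # Hypothetical "addmul"-style module; the commit's actual test model is not shown here.
    class AddMul(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            return torch.add(torch.mul(x, y), x)


    example_inputs = (torch.randn(2, 2), torch.randn(2, 2))  # float32 throughout
    program = to_edge(export(AddMul(), example_inputs)).to_executorch()

With dtype selective build, the optimized kernels are compiled only for the dtypes selected at build time, so a float32 model like this runs when float is selected and fails once float is removed, which is what the new test checks.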

File tree: 42 files changed (+2956 / −266 lines)


.ci/scripts/build-qnn-sdk.sh

mode change: 100644 → 100755

Lines changed: 4 additions & 9 deletions

@@ -11,17 +11,12 @@ set -o xtrace
 
 build_qnn_backend() {
   echo "Start building qnn backend."
-  export ANDROID_NDK_ROOT=/opt/ndk
-  export QNN_SDK_ROOT=/tmp/qnn/2.28.0.241029
+  export ANDROID_NDK_ROOT=${ANDROID_NDK_ROOT:-/opt/ndk}
+  export QNN_SDK_ROOT=${QNN_SDK_ROOT:-/tmp/qnn/2.28.0.241029}
   export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." && pwd)"
 
-  # Workaround to avoid issues around missing flatccrt library (depending on the
-  # number of jobs used), see issue #7300:
-  # Build twice (second time with `--no_clean`) to make sure libflatccrt.a is
-  # available.
-  # TODO: Remove this workaround once the underlying issue is fixed.
-  bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release || \
-    bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release --no_clean
+  parallelism=$(( $(nproc) - 1 ))
+  bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number ${parallelism} --release
 }
 
 set_up_aot() {

.github/workflows/build-presets.yml

Lines changed: 39 additions & 1 deletion

@@ -6,6 +6,8 @@ on:
     branches:
       - main
       - release/*
+    paths:
+      - .github/workflows/build-presets.yml
   workflow_dispatch:
 
 concurrency:
@@ -16,15 +18,51 @@ jobs:
   apple:
     uses: pytorch/test-infra/.github/workflows/macos_job.yml@main
     strategy:
+      fail-fast: false
       matrix:
-        preset: [macos-arm64]
+        preset: [macos-arm64, pybind]
     with:
       job-name: build
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
       runner: macos-latest-xlarge
       python-version: 3.12
       submodules: recursive
+      timeout: 90
      script: |
        set -eux
        ${CONDA_RUN} ./install_requirements.sh > /dev/null
        ${CONDA_RUN} cmake --preset ${{ matrix.preset }}
        ${CONDA_RUN} cmake --build cmake-out --parallel
+
+  linux:
+    uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
+    strategy:
+      fail-fast: false
+      matrix:
+        preset: [pybind]
+        runner: [linux.2xlarge, linux.arm64.2xlarge]
+        docker-image: [executorch-ubuntu-22.04-clang12, executorch-ubuntu-22.04-gcc11-aarch64]
+        # Excluding specific runner + docker image combinations that don't make sense:
+        # - Excluding the ARM64 gcc image on the x86 runner (linux.2xlarge)
+        # - Excluding the x86 clang image on the ARM64 runner (linux.arm64.2xlarge)
+        exclude:
+          - runner: linux.2xlarge
+            docker-image: executorch-ubuntu-22.04-gcc11-aarch64
+          - runner: linux.arm64.2xlarge
+            docker-image: executorch-ubuntu-22.04-clang12
+    with:
+      job-name: build
+      ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
+      runner: ${{ matrix.runner }}
+      docker-image: ${{ matrix.docker-image }}
+      submodules: recursive
+      timeout: 90
+      script: |
+        set -eux
+        # The generic Linux job chooses to use the base env, not the one set up by the image
+        CONDA_ENV=$(conda env list --json | jq -r ".envs | .[-1]")
+        conda activate "${CONDA_ENV}"
+
+        ./install_requirements.sh > /dev/null
+        cmake --preset ${{ matrix.preset }}
+        cmake --build cmake-out --parallel

CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -582,6 +582,7 @@ if(EXECUTORCH_BUILD_PYBIND)
       ${TORCH_PYTHON_LIBRARY}
       bundled_program
       etdump
+      flatccrt
       executorch
       extension_data_loader
       util

CMakePresets.json

Lines changed: 15 additions & 0 deletions

@@ -15,6 +15,7 @@
     },
     {
       "name": "macos-arm64",
+      "displayName": "Build everything buildable on macOS arm64",
       "inherits": ["common"],
       "generator": "Xcode",
       "cacheVariables": {
@@ -28,6 +29,20 @@
         "type": "equals",
         "rhs": "Darwin"
       }
+    },
+    {
+      "name": "pybind",
+      "displayName": "Build pybindings exported in the wheel",
+      "inherits": ["common"],
+      "cacheVariables": {
+        "EXECUTORCH_BUILD_PRESET_FILE": "${sourceDir}/tools/cmake/preset/pybind.cmake",
+        "CMAKE_OSX_DEPLOYMENT_TARGET": "10.15"
+      },
+      "condition": {
+        "type": "inList",
+        "string": "${hostSystemName}",
+        "list": ["Darwin", "Linux", "Windows"]
+      }
     }
   ]
 }

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 40 additions & 32 deletions

@@ -8,7 +8,7 @@
 
 
 import unittest
-from typing import Tuple
+from typing import Final, List, Tuple
 
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
@@ -281,25 +281,23 @@ def forward(self, x):
         )
 
     def test_no_replace_quant_permute_dequant_with_requantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                x = torch.ops.quantized_decomposed.quantize_per_tensor(
-                    x, 1.2, 3, 0, 127, torch.int8
-                )
-                x = torch.permute(x, [2, 0, 1, 3])
-                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 4.5, 6, 0, 127, torch.int8
-                )
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6)
-        model = M()
-        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
-
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
+        quant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 1.2, 3, 0, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(quant, [2, 0, 1, 3])
+        )
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(permute, 4.5, 6, 0, 127, torch.int8),
+        )
+        builder.output(dequant)
+        graph_module = FuseQuantDequantToRequantizePass(
+            force_quant_dequant_fusion=False
+        )(builder.get_graph_module()).graph_module
         self.check_op_counts(
             graph_module,
             expected_op_counts={
@@ -436,18 +434,28 @@ def forward(self, x):
         )
 
     def test_fuse_mul_into_dequant(self):
-        class M(torch.nn.Module):
-            def forward(self, x):
-                x0 = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 1.5, 0, 0, 255, torch.uint8
-                )
-                x1 = torch.full([4, 32], 3, dtype=torch.float32)
-                x2 = x0 * x1
-                return x2
+        INPUT_SHAPE: Final[List[int]] = [4, 32]
+        DEQUANT_SCALE: Final[float] = 1.5
+        FULL_VALUE: Final[float] = 3
 
-        inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),)
-        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
-        graph_module = FuseMulTensorIntoDequantPass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*INPUT_SHAPE, dtype=torch.float32))
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, DEQUANT_SCALE, 0, 0, 255, torch.uint8),
+        )
+        full = builder.call_operator(
+            op=exir_ops.edge.aten.full.default,
+            args=(INPUT_SHAPE, FULL_VALUE),
+        )
+        mul = builder.call_operator(
+            op=exir_ops.edge.aten.mul.Tensor,
+            args=(dequant, full),
+        )
+        builder.output(mul)
+        graph_module = FuseMulTensorIntoDequantPass()(
+            builder.get_graph_module()
+        ).graph_module
 
         # verify that the mul and full ops were removed
         self.check_op_counts(
@@ -466,7 +474,7 @@ def forward(self, x):
                 == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
             ):
                 deq_scale = node.args[1]
-                self.assertEqual(deq_scale, 4.5)
+                self.assertEqual(deq_scale, DEQUANT_SCALE * FULL_VALUE)
 
     def test_fuse_mul_scalar_into_dequant(self):
         dequant_scale = 0.006
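The rewritten tests above replace eager nn.Module tracing with direct graph construction. Distilled to its core, the GraphBuilder pattern they use looks like this; a minimal sketch, assuming the Cadence AoT GraphBuilder import path, with a stand-in relu op rather than any op from the diff:

    import torch
    from executorch.backends.cadence.aot.graph_builder import GraphBuilder  # assumed path
    from executorch.exir.dialects._ops import ops as exir_ops

    builder = GraphBuilder()
    # Declare a graph input, using a concrete tensor to fix shape and dtype.
    x = builder.placeholder("x", torch.randn(4, 32, dtype=torch.float32))
    # Insert an edge-dialect op node; args mirror the operator schema.
    y = builder.call_operator(op=exir_ops.edge.aten.relu.default, args=(x,))
    builder.output(y)
    # The resulting GraphModule feeds straight into the pass under test.
    graph_module = builder.get_graph_module()

Building the graph explicitly pins down exactly which ops the pass sees, so the expected op counts no longer depend on export-time decomposition behavior.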

backends/xnnpack/CMakeLists.txt

Lines changed: 0 additions & 18 deletions

@@ -25,24 +25,6 @@ endif()
 
 include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
 
-# NB: Enabling this will serialize execution of delegate instances Keeping this
-# OFF by default to maintain existing behavior, to be revisited.
-option(EXECUTORCH_XNNPACK_SHARED_WORKSPACE
-       "Enable workspace sharing across different delegate instances" ON
-)
-# Keeping this OFF by default due to regressions in decode and model load with
-# kleidi kernels
-option(EXECUTORCH_XNNPACK_ENABLE_KLEIDI "Enable Arm Kleidi kernels" OFF)
-
-# Turning this on cache weights between partitions and methods. If weights
-# are shared across methods/partitions then this can reduce load time and
-# memory usage
-
-# Keeping this off maintains existing behavior. Turning this on serializes
-# execution and initialization of delegates, to be revisited
-option(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE
-       "Enable weights cache to cache and manage all packed weights" OFF)
-
 if(EXECUTORCH_XNNPACK_ENABLE_WEIGHT_CACHE)
   add_definitions(-DENABLE_XNNPACK_WEIGHTS_CACHE)
 endif()

codegen/api/__init__.py

Whitespace-only changes.

codegen/api/custom_ops.py

Lines changed: 151 additions & 0 deletions

@@ -0,0 +1,151 @@
+from __future__ import annotations
+
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING
+
+from torchgen import dest
+
+
+# disable import sorting to avoid circular dependency.
+from torchgen.api.types import DispatcherSignature  # usort: skip
+from torchgen.context import method_with_native_function
+from torchgen.model import BaseTy, BaseType, DispatchKey, NativeFunction, Variant
+from torchgen.utils import concatMap, Target
+
+
+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+    from executorch.codegen.model import ETKernelIndex
+    from torchgen.selective_build.selector import SelectiveBuilder
+
+
+# Generates RegisterKernelStub.cpp, which provides placeholder kernels for custom operators. This will be used at
+# model authoring side.
+@dataclass(frozen=True)
+class ComputeNativeFunctionStub:
+    @method_with_native_function
+    def __call__(self, f: NativeFunction) -> str | None:
+        if Variant.function not in f.variants:
+            return None
+
+        sig = DispatcherSignature.from_schema(
+            f.func, prefix=f"wrapper_CPU_{f.func.name.overload_name}_", symint=False
+        )
+        assert sig is not None
+        if len(f.func.returns) == 0:
+            ret_name = ""
+        elif len(f.func.returns) == 1:
+            if f.func.arguments.out:
+                ret_name = f.func.arguments.out[0].name
+            else:
+                ret_name = next(
+                    (
+                        a.name
+                        for a in f.func.arguments.flat_non_out
+                        if a.type == f.func.returns[0].type
+                    ),
+                    "",
+                )
+                if not ret_name:
+                    # if return type is tensor
+                    if f.func.returns[0].type == BaseType(BaseTy.Tensor):
+                        # Returns an empty tensor
+                        ret_name = "at::Tensor()"
+                    else:
+                        raise Exception(  # noqa: TRY002
+                            f"Can't handle this return type {f.func}"
+                        )  # noqa: TRY002
+        elif len(f.func.arguments.out) == len(f.func.returns):
+            # Returns a tuple of out arguments
+            tensor_type = "at::Tensor &"
+            comma = ", "
+            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
+                {comma.join([r.name for r in f.func.arguments.out])}
+            )"""
+        else:
+            assert all(
+                a.type == BaseType(BaseTy.Tensor) for a in f.func.returns
+            ), f"Only support tensor returns but got {f.func.returns}"
+            # Returns a tuple of empty tensors
+            tensor_type = "at::Tensor"
+            comma = ", "
+            ret_name = f"""::std::tuple<{comma.join([tensor_type] * len(f.func.returns))}>(
+                {comma.join(["at::Tensor()" for _ in f.func.returns])}
+            )"""
+        ret_str = f"return {ret_name};" if len(f.func.returns) > 0 else ""
+        return f"""
+{sig.defn()} {{
+    {ret_str}
+}}
+    """
+
+
+def gen_custom_ops_registration(
+    *,
+    native_functions: Sequence[NativeFunction],
+    selector: SelectiveBuilder,
+    kernel_index: ETKernelIndex,
+    rocm: bool,
+) -> tuple[str, str]:
+    """
+    Generate custom ops registration code for dest.RegisterDispatchKey.
+
+    :param native_functions: a sequence of `NativeFunction`
+    :param selector: for selective build.
+    :param kernel_index: kernels for all the ops.
+    :param rocm: bool for dest.RegisterDispatchKey.
+    :return: generated C++ code to register custom operators into PyTorch
+    """
+
+    # convert kernel index to BackendIndex. This is because we can't handle ETKernelIndex yet.
+    # TODO larryliu: evaluate if this code is still needed. If yes let it handle ETKernelIndex.
+
+    dispatch_key = DispatchKey.CPU
+    backend_index = kernel_index._to_backend_index()
+    static_init_dispatch_registrations = ""
+    ns_grouped_native_functions: dict[str, list[NativeFunction]] = defaultdict(list)
+    for native_function in native_functions:
+        ns_grouped_native_functions[native_function.namespace].append(native_function)
+
+    for namespace, functions in ns_grouped_native_functions.items():
+        if len(functions) == 0:
+            continue
+        dispatch_registrations_body = "\n".join(
+            list(
+                concatMap(
+                    dest.RegisterDispatchKey(
+                        backend_index,
+                        Target.REGISTRATION,
+                        selector,
+                        rocm=rocm,
+                        symint=False,
+                        class_method_name=None,
+                        skip_dispatcher_op_registration=False,
+                    ),
+                    functions,
+                )
+            )
+        )
+        static_init_dispatch_registrations += f"""
+TORCH_LIBRARY_IMPL({namespace}, {dispatch_key}, m) {{
+{dispatch_registrations_body}
+}}"""
+    anonymous_definition = "\n".join(
+        list(
+            concatMap(
+                dest.RegisterDispatchKey(
+                    backend_index,
+                    Target.ANONYMOUS_DEFINITION,
+                    selector,
+                    rocm=rocm,
+                    symint=False,
+                    class_method_name=None,
+                    skip_dispatcher_op_registration=False,
+                ),
+                native_functions,
+            )
+        )
+    )
+    return anonymous_definition, static_init_dispatch_registrations
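A short sketch of how the stub generator above might be driven; illustrative only, with the yaml paths as placeholders, parse_native_yaml taken from torchgen, and the module path assumed from this commit's codegen/ layout:

    from torchgen.gen import parse_native_yaml
    from executorch.codegen.api.custom_ops import ComputeNativeFunctionStub  # assumed path

    # Parse the operator schemas; both file paths are hypothetical.
    parsed = parse_native_yaml("custom_ops.yaml", "tags.yaml")
    stub_gen = ComputeNativeFunctionStub()
    # Each stub is a placeholder C++ kernel with the correct signature; functions
    # without a function variant yield None and are skipped.
    stubs = [src for f in parsed.native_functions if (src := stub_gen(f)) is not None]
    register_kernel_stub_body = "\n".join(stubs)

Concatenated, the stubs form the body of RegisterKernelStub.cpp, letting model-authoring code register custom ops with PyTorch without pulling in the real kernel implementations.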
