
Commit 108116c

Update base for Update on "Use std::variant to implement pytree Key"
Key was a struct that should've been a union; std::variant makes using a union much easier.

Differential Revision: [D65575184](https://our.internmc.facebook.com/intern/diff/D65575184/)

[ghstack-poisoned]
2 parents 03b1ef2 + 545535b commit 108116c
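For context on the stacked change this commit rebases: moving Key from a hand-rolled struct to std::variant means a key stores exactly one alternative instead of carrying every possible member plus a tag. Below is a minimal sketch of the idea, with hypothetical names and alternatives; it is not the actual ExecuTorch pytree definition.

#include <cstdint>
#include <string>
#include <variant>

// Hypothetical stand-in for a pytree Key: exactly one alternative is active,
// so there is no separate tag field to keep in sync with the stored value.
using Key = std::variant<std::string, int64_t>;

// Checking which alternative is held is a library call, not a manual tag check.
inline bool is_dict_key(const Key& k) {
  return std::holds_alternative<std::string>(k);
}

inline bool is_index(const Key& k) {
  return std::holds_alternative<int64_t>(k);
}

Compared with a manual tagged struct, the variant fails loudly on misuse: std::get on the wrong alternative throws std::bad_variant_access instead of silently reading a stale member.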

File tree

17 files changed: 229 additions & 591 deletions


.github/workflows/ghstack_land.yml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ on:
     branches:
       - 'gh/cccclai/[0-9]+/base'
       - 'gh/dbort/[0-9]+/base'
+      - 'gh/dvorjackz/[0-9]+/base'
       - 'gh/guangy10/[0-9]+/base'
       - 'gh/helunwencser/[0-9]+/base'
       - 'gh/jorgep31415/[0-9]+/base'

backends/arm/test/runner_utils.py

Lines changed: 7 additions & 2 deletions
@@ -448,16 +448,21 @@ def run_tosa_ref_model(
             ), "There are no quantization parameters, check output parameters"
             tosa_ref_output = (tosa_ref_output - quant_param.zp) * quant_param.scale
 
+        if tosa_ref_output.dtype == np.double:
+            tosa_ref_output = tosa_ref_output.astype("float32")
+
         # tosa_output is a numpy array, convert to torch tensor for comparison
-        tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output.astype("float32")))
+        tosa_ref_outputs.append(torch.from_numpy(tosa_ref_output))
 
     return tosa_ref_outputs
 
 
 def prep_data_for_save(
     data, is_quantized: bool, input_name: str, quant_param: QuantizationParams
 ):
-    data_np = np.array(data.detach(), order="C").astype(np.float32)
+    data_np = np.array(data.detach(), order="C").astype(
+        f"{data.dtype}".replace("torch.", "")
+    )
 
     if is_quantized:
         assert quant_param.node_name in input_name, (

backends/cadence/aot/ops_registrations.py

Lines changed: 54 additions & 0 deletions
@@ -66,6 +66,12 @@
 lib.define(
     "quantized_conv.out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point, Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier, Tensor out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
 )
+lib.define(
+    "quantized_conv.per_tensor(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False) -> (Tensor Z)"
+)
+lib.define(
+    "quantized_conv.per_tensor_out(Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding, int[] dilation, int groups, int input_zero_point, int weight_zero_point, float bias_scale, float out_scale, int out_zero_point, int out_multiplier, int out_shift, bool channel_last=False, *, Tensor(a!) out) -> Tensor(a!)"
+)
 
 lib.define(
     "quantized_matmul(Tensor X, int X_zero_point, Tensor Y, int Y_zero_point, Tensor? bias, int out_multiplier, int out_shift, int out_zero_point, bool transposed=False) -> (Tensor Z)"
@@ -171,6 +177,54 @@ def quantized_conv_meta(
     return input.new_empty(output_size, dtype=input.dtype)
 
 
+@register_fake("cadence::quantized_conv.per_tensor")
+def quantized_conv_per_tensor_meta(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    bias: torch.Tensor,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    groups: int,
+    in_zero_point: int,
+    weight_zero_point: int,
+    bias_scale: float,
+    output_scale: float,
+    output_zero_point: int,
+    out_multiplier: int,
+    out_shift: int,
+    channel_last: bool = False,
+) -> torch.Tensor:
+    if channel_last:
+        out_channels, *kernel_size, _ = weight.shape
+    else:
+        out_channels, _, *kernel_size = weight.shape
+
+    in_size = input.shape
+    # Assert that the input tensor has at least 3 dimensions, and at most 6
+    assert len(in_size) > 2
+    assert len(in_size) < 6
+
+    # Compute the output tensor size
+    output_size = (
+        get_conv1d_output_size(
+            in_size,
+            out_channels,
+            stride[1],
+            padding[1],
+            dilation[1],
+            kernel_size[0],
+            channel_last,
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
+        )
+    )
+
+    return input.new_empty(output_size, dtype=input.dtype)
+
+
 @register_fake("cadence::quantized_layer_norm")
 def quantized_layer_norm_meta(
     input: torch.Tensor,

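The per_tensor meta function above only has to produce an output shape. For reference, this is the standard per-dimension convolution output-size arithmetic that helpers such as get_conv1d_output_size/get_conv2d_output_size are presumably applying; the sketch below is illustrative only and not the Cadence implementation.

#include <cstdio>

// Standard output-size formula for one spatial dimension of a convolution.
int conv_out_dim(int in, int kernel, int stride, int padding, int dilation) {
  return (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1;
}

int main() {
  // A 32-wide input with a 3-wide kernel, stride 1, padding 1, dilation 1
  // keeps its size: prints 32.
  std::printf("%d\n", conv_out_dim(32, 3, 1, 1, 1));
  return 0;
}
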
backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 25 additions & 1 deletion
@@ -540,6 +540,7 @@ def __init__(
         env: Dict[Any, Any],
         glslc_path: Optional[str],
         glslc_flags: str = "",
+        replace_u16vecn: bool = False,
     ) -> None:
         if isinstance(src_dir_paths, str):
             self.src_dir_paths = [src_dir_paths]
@@ -549,6 +550,7 @@ def __init__(
         self.env = env
         self.glslc_path = glslc_path
         self.glslc_flags = glslc_flags
+        self.replace_u16vecn = replace_u16vecn
 
         self.glsl_src_files: Dict[str, str] = {}
         self.template_yaml_files: List[str] = []
@@ -705,6 +707,22 @@ def constructOutputMap(self) -> None:
                 self.create_shader_params(),
             )
 
+    def maybe_replace_u16vecn(self, input_text: str) -> str:
+        """
+        There is a latency benefit to using u16vecn variables to store texture position
+        variables instead of ivecn, likely due to reduced register pressure. However,
+        SwiftShader does not support 16 bit integer types in shaders, so this is a crude
+        way to fallback to using ivecn to store texture positions so that testing with
+        SwiftShader is still possible.
+        """
+        if not self.replace_u16vecn:
+            return input_text
+        if "codegen-nosub" in input_text:
+            return input_text
+
+        input_text = input_text.replace("u16vec", "ivec")
+        return input_text
+
     def generateSPV(self, output_dir: str) -> Dict[str, str]:
         output_file_map = {}
 
@@ -716,6 +734,7 @@ def process_shader(shader_paths_pair):
 
             with codecs.open(source_glsl, "r", encoding="utf-8") as input_file:
                 input_text = input_file.read()
+                input_text = self.maybe_replace_u16vecn(input_text)
                 output_text = preprocess(input_text, shader_params)
 
             glsl_out_path = os.path.join(output_dir, f"{shader_name}.glsl")
@@ -1029,6 +1048,7 @@ def main(argv: List[str]) -> int:
     parser.add_argument("-c", "--glslc-path", required=True, help="")
     parser.add_argument("-t", "--tmp-dir-path", required=True, help="/tmp")
    parser.add_argument("-o", "--output-path", required=True, help="")
+    parser.add_argument("--replace-u16vecn", action="store_true", default=False)
     parser.add_argument("--optimize_size", action="store_true", help="")
     parser.add_argument("--optimize", action="store_true", help="")
     parser.add_argument(
@@ -1056,7 +1076,11 @@ def main(argv: List[str]) -> int:
         glslc_flags += "-O"
 
     shader_generator = SPVGenerator(
-        options.glsl_paths, env, options.glslc_path, glslc_flags
+        options.glsl_paths,
+        env,
+        options.glslc_path,
+        glslc_flags=glslc_flags,
+        replace_u16vecn=options.replace_u16vecn,
     )
     output_spv_files = shader_generator.generateSPV(options.tmp_dir_path)

backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,8 @@
  * LICENSE file in the root directory of this source tree.
  */
 
+// codegen-nosub
+
 #version 450 core
 
 #define PRECISION ${PRECISION}

backends/vulkan/targets.bzl

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@ def vulkan_spv_shader_lib(name, spv_filegroups, is_fbcode = False):
         select({
             "DEFAULT": "",
             "ovr_config//os:android": "--optimize",
+            "ovr_config//os:linux": "--replace-u16vecn",
         })
     )

devtools/inspector/_inspector_utils.py

Lines changed: 1 addition & 0 deletions
@@ -112,6 +112,7 @@ def get_scalar_type_size(scalar_type: ScalarType) -> Tuple[torch.dtype, int]:
         ScalarType.BYTE: (torch.uint8, 1),
         ScalarType.CHAR: (torch.int8, 1),
         ScalarType.BOOL: (torch.bool, 1),
+        ScalarType.BITS16: (torch.uint16, 2),
         ScalarType.SHORT: (torch.int16, 2),
         ScalarType.HALF: (torch.float16, 2),
         ScalarType.INT: (torch.int, 4),

exir/passes/executorch_prim_ops_registry.py

Lines changed: 9 additions & 0 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import math
 import operator
 from typing import Dict, Set, Union
 
@@ -14,6 +15,8 @@
 from torch._ops import OpOverload
 from torch.library import Library
 
+# pyre-unsafe
+
 
 executorch_prims_lib = Library("executorch_prim", "DEF")
 
@@ -91,7 +94,13 @@ def neg(a: _SymScalar) -> _SymScalar:
     return -a # pyre-ignore
 
 
+@bind_pattern_to_op(executorch_prims_lib, "trunc.Scalar(Scalar a) -> Scalar")
+def trunc(a: _SymScalar) -> _SymScalar:
+    return math.trunc(a) # pyre-ignore
+
+
 _PYTHON_SYM_OPS_TO_EXECUTORCH_SYM_OPS: Dict[OpOverload, OpOverload] = {
+    math.trunc: ops.backend.executorch_prim.trunc.Scalar,
     operator.sub: ops.backend.executorch_prim.sub.Scalar,
     operator.mul: ops.backend.executorch_prim.mul.Scalar,
     operator.add: ops.backend.executorch_prim.add.Scalar,

extension/llm/custom_ops/targets.bzl

Lines changed: 7 additions & 2 deletions
@@ -1,10 +1,14 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+load(
+    "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
+    "get_vec_preprocessor_flags",
+    "get_vec_deps",
+)
 load(
     "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
     "get_compiler_optimization_flags",
 )
 
-
 def define_common_targets():
     """Defines targets that should be shared between fbcode and xplat.
 
@@ -26,6 +30,7 @@ def define_common_targets():
                 "op_sdpa.h",
                 "op_update_quantized_cache.h",
             ],
+            preprocessor_flags = get_vec_preprocessor_flags(),
             exported_deps = [
                 "//executorch/runtime/kernel:kernel_includes",
                 "//executorch/kernels/portable/cpu:scalar_utils",
@@ -38,7 +43,7 @@ def define_common_targets():
             deps = [
                 "//executorch/kernels/portable/cpu/util:reduce_util",
                 "//executorch/extension/llm/custom_ops/spinquant:fast_hadamard_transform",
-            ],
+            ] + get_vec_deps(),
            compiler_flags = ["-Wno-missing-prototypes", "-Wno-global-constructors"] + get_compiler_optimization_flags(),
             visibility = [
                 "//executorch/...",

kernels/optimized/lib_defs.bzl

Lines changed: 37 additions & 9 deletions
@@ -15,16 +15,44 @@ load(
 # functions in order to declare the required compiler flags needed in order to
 # access CPU vector intrinsics.
 
-def get_vec_android_preprocessor_flags():
-    preprocessor_flags = [
-        (
-            "^android-arm64.*$",
-            [
+def get_vec_preprocessor_flags():
+    if not runtime.is_oss:
+        # various ovr_configs are not available in oss
+        preprocessor_flags = select({
+            "ovr_config//os:linux-x86_64": [
                 "-DET_BUILD_ARM_VEC256_WITH_SLEEF",
-            ],
-        ),
-    ]
-    return preprocessor_flags
+            ] if not runtime.is_oss else [],
+            "ovr_config//os:iphoneos-arm64": [
+                "-DET_BUILD_ARM_VEC256_WITH_SLEEF",
+            ] if not runtime.is_oss else [],
+            "ovr_config//os:macos-arm64": [
+                "-DET_BUILD_ARM_VEC256_WITH_SLEEF",
+            ] if not runtime.is_oss else [],
+            "ovr_config//os:android-arm64": [
+                "-DET_BUILD_ARM_VEC256_WITH_SLEEF",
+            ] if not runtime.is_oss else [],
+            "DEFAULT": [],
+        })
+        return preprocessor_flags
+    return []
+
+def get_vec_deps():
+    if not runtime.is_oss:
+        # various ovr_configs are not available in oss
+        deps = select({
+            "ovr_config//os:iphoneos-arm64": [
+                "fbsource//third-party/sleef:sleef_arm",
+            ] if not runtime.is_oss else [],
+            "ovr_config//os:macos-arm64": [
+                "fbsource//third-party/sleef:sleef_arm",
+            ] if not runtime.is_oss else [],
+            "ovr_config//os:android-arm64": [
+                "fbsource//third-party/sleef:sleef_arm",
+            ] if not runtime.is_oss else [],
+            "DEFAULT": [],
+        })
+        return deps
+    return []
 
 def get_vec_cxx_preprocessor_flags():
     preprocessor_flags = [

kernels/optimized/op_registration_util.bzl

Lines changed: 4 additions & 3 deletions
@@ -2,7 +2,8 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load("@fbsource//xplat/executorch/build:selects.bzl", "selects")
 load(
     "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
-    "get_vec_android_preprocessor_flags",
+    "get_vec_preprocessor_flags",
+    "get_vec_deps",
 )
 load(
     "@fbsource//xplat/executorch/kernels/portable:op_registration_util.bzl",
@@ -94,8 +95,8 @@ def define_op_library(name, deps):
         compiler_flags = ["-Wno-missing-prototypes"] + get_compiler_optimization_flags(),
         deps = [
             "//executorch/runtime/kernel:kernel_includes",
-        ] + augmented_deps,
-        fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(),
+        ] + augmented_deps + get_vec_deps(),
+        preprocessor_flags = get_vec_preprocessor_flags(),
         # sleef needs to be added as a direct dependency of the operator target when building for Android,
         # or a linker error may occur. Not sure why this happens; it seems that fbandroid_platform_deps of
         # dependencies are not transitive

kernels/optimized/test/targets.bzl

Lines changed: 2 additions & 2 deletions
@@ -1,7 +1,7 @@
 load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
 load(
     "@fbsource//xplat/executorch/kernels/optimized:lib_defs.bzl",
-    "get_vec_android_preprocessor_flags",
+    "get_vec_preprocessor_flags",
     "get_vec_cxx_preprocessor_flags",
 )
 load("@fbsource//xplat/executorch/kernels/test:util.bzl", "define_supported_features_lib")
@@ -27,7 +27,7 @@ def _lib_test_bin(name, extra_deps = [], in_cpu = False):
             "//executorch/kernels/optimized{}:{}".format(cpu_path, lib_root),
         ] + extra_deps,
         cxx_platform_preprocessor_flags = get_vec_cxx_preprocessor_flags(),
-        fbandroid_platform_preprocessor_flags = get_vec_android_preprocessor_flags(),
+        preprocessor_flags = get_vec_preprocessor_flags(),
     )
 
 def define_common_targets():

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 16 additions & 0 deletions
@@ -12,6 +12,8 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/kernel/operator_registry.h>
 
+#include <cmath>
+
 using torch::executor::function::et_copy_index;
 
 namespace torch {
@@ -301,6 +303,20 @@ static Kernel prim_ops[] = {
           }
         }),
 
+    // trunc.Scalar(Scalar a) -> Scalar
+    Kernel(
+        "executorch_prim::trunc.Scalar",
+        [](KernelRuntimeContext& context, EValue** stack) {
+          (void)context;
+          EValue& a = *stack[0];
+          EValue& out = *stack[1];
+          if (a.isDouble()) {
+            out = EValue(static_cast<int64_t>(trunc(a.toDouble())));
+          } else {
+            ET_CHECK_MSG(false, "%zu", (size_t)a.tag);
+          }
+        }),
+
     // executorch_prim::et_copy_index.tensor(tensor, tensor) -> tensor
     Kernel("executorch_prim::et_copy_index.tensor", &et_copy_index),
     // executorch_prim::et_view.default(Tensor, int[]) -> Tensor

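The new trunc.Scalar kernel above maps a Double EValue to an Int EValue by truncating toward zero, matching the math.trunc binding added in exir/passes/executorch_prim_ops_registry.py. A standalone sketch of that rounding behavior (illustrative only, outside the ExecuTorch runtime):

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // trunc() drops the fractional part and rounds toward zero, unlike floor().
  const double samples[] = {2.9, -2.9, 0.5};
  for (double v : samples) {
    std::printf("trunc(%.1f) = %lld\n", v, static_cast<long long>(std::trunc(v)));
  }
  // Prints: trunc(2.9) = 2, trunc(-2.9) = -2, trunc(0.5) = 0
  return 0;
}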