
Commit 1bbb2aa

Merge remote-tracking branch 'origin/main' into tiktoken
2 parents (f12edc8 + f9efb05) · commit 1bbb2aa

File tree: 27 files changed (+717, -34 lines)


backends/arm/arm_backend.py

Lines changed: 7 additions & 3 deletions
@@ -17,8 +17,8 @@
 from executorch.backends.arm.arm_vela import vela_compile
 from executorch.backends.arm.operators.node_visitor import get_node_visitors
 from executorch.backends.arm.operators.op_placeholder import process_placeholder
-from executorch.backends.arm.tosa_mapping import TosaArg
-from executorch.backends.arm.tosa_quant_utils import is_quant_node
+from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
+from executorch.backends.arm.tosa_quant_utils import get_quant_node_dtype, is_quant_node
 from executorch.backends.arm.tosa_utils import (
     dbg_fail,
     dbg_tosa_dump,
@@ -280,7 +280,11 @@ def preprocess( # noqa: C901
                 if is_permute_node_before_addmm(node)
                 else output.shape
             ),
-            ts.DType.INT8 if is_quant_node(node) else output.dtype,
+            (
+                map_dtype(get_quant_node_dtype(node))
+                if is_quant_node(node)
+                else output.dtype
+            ),
         )

         # Visiting each Node
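This hunk stops hard-coding INT8 for quantized nodes and instead derives the TOSA dtype from the node's quantization parameters. A minimal sketch of the selection logic, assuming map_dtype translates a torch dtype (e.g. torch.int8) into the matching ts.DType value; select_output_dtype is a hypothetical helper, not part of the patch:

# Sketch only: how the output dtype is chosen for a node after this change.
def select_output_dtype(node, output):
    if is_quant_node(node):
        # Quantized node: use the dtype recorded on its quantize/dequantize op.
        return map_dtype(get_quant_node_dtype(node))
    # Non-quantized node: keep the original output dtype.
    return output.dtype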

backends/arm/operators/op_placeholder.py

Lines changed: 2 additions & 1 deletion
@@ -8,6 +8,7 @@
 import torch
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_quant_utils import (
+    get_quant_arg_dtype,
     get_quant_node_args,
     is_quant_arg,
     q_op,
@@ -166,7 +167,7 @@ def process_placeholder(
     tensor = ts.TosaSerializerTensor(
         inputs[0].name,
         input_shape,
-        ts.DType.INT8 if is_quant_arg(node) else inputs[0].dtype,
+        get_quant_arg_dtype(node) if is_quant_arg(node) else inputs[0].dtype,
         data=None,
         placeholderFilename=inputs[0].name + ".npy",
     )

backends/arm/test/runner_utils.py

Lines changed: 29 additions & 8 deletions
@@ -23,13 +23,24 @@


 class QuantizationParams:
-    __slots__ = ["node_name", "zp", "scale"]
+    __slots__ = ["node_name", "zp", "scale", "qmin", "qmax", "dtype"]

     # todo: zps and scales can be per tensors or per channel => a list??
-    def __init__(self, node_name: str, zp: int, scale: float):
+    def __init__(
+        self,
+        node_name: str,
+        zp: int,
+        scale: float,
+        qmin: int,
+        qmax: int,
+        dtype: torch.dtype,
+    ):
         self.node_name = node_name  # not need I think, but good for error check
         self.zp = zp
         self.scale = scale
+        self.qmin = qmin
+        self.qmax = qmax
+        self.dtype = dtype


 def _get_input_names(program: ExportedProgram) -> list[str]:
@@ -74,7 +85,12 @@ def _get_input_quantization_params(
             and node.args[0].name in input_names
         ):
             qp = QuantizationParams(
-                node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
+                node_name=node.args[0].name,
+                scale=node.args[1],
+                zp=node.args[2],
+                qmin=node.args[3],
+                qmax=node.args[4],
+                dtype=node.args[5],
             )
             quant_params.append(qp)
         if (
@@ -122,7 +138,12 @@ def _get_output_quantization_params(
             and node == output_node.args[0][0]
         ):
             quant_params = QuantizationParams(
-                node_name=node.args[0].name, scale=node.args[1], zp=node.args[2]
+                node_name=node.args[0].name,
+                scale=node.args[1],
+                zp=node.args[2],
+                qmin=node.args[3],
+                qmax=node.args[4],
+                dtype=node.args[5],
             )
             break  # break early, there's only one output node
     if quant_params is None:
@@ -376,13 +397,13 @@ def prep_data_for_save(
     assert (
         quant_param.node_name == input_name
     ), "These quantization params do not match the input tensor name"
-    int8_max = np.iinfo(np.int8).max
-    int8_min = np.iinfo(np.int8).min
     data_np = (
         ((data_np / np.float32(quant_param.scale)) + quant_param.zp)
         .round()
-        .clip(int8_min, int8_max)
-        .astype(np.int8)
+        .clip(quant_param.qmin, quant_param.qmax)
+        .astype(
+            f"{quant_param.dtype}".replace("torch.", "")
+        )  # Use string format of dtype to convert to numpy dtype
     )
     return data_np

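QuantizationParams now carries a torch.dtype, while numpy's astype expects a numpy dtype or its name; the .astype call above bridges the two through the string form of the torch dtype. A small self-contained illustration of that conversion (values are arbitrary):

import numpy as np
import torch

dtype = torch.int8
np_name = f"{dtype}".replace("torch.", "")  # "torch.int8" -> "int8"
arr = np.zeros(4).astype(np_name)           # numpy accepts the dtype name string
assert arr.dtype == np.int8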

backends/arm/tosa_quant_utils.py

Lines changed: 31 additions & 1 deletion
@@ -10,7 +10,7 @@

 import serializer.tosa_serializer as ts
 import torch.fx
-from executorch.backends.arm.tosa_mapping import TosaArg
+from executorch.backends.arm.tosa_mapping import map_dtype, TosaArg
 from executorch.exir.dialects._ops import ops as exir_ops
 from serializer.tosa_serializer import TosaOp, TosaSerializerTensor

@@ -45,11 +45,41 @@ def is_quant_node(node: torch.fx.Node):
     )


+def get_quant_node_dtype(node: torch.fx.Node):
+    if "tosa" in node.target.__name__:
+        return node.meta["val"].dtype
+
+    if node.target in dq_q_ops:
+        return node.args[5]
+
+    # if not a tosa node, nor a q/dq op, walk the graph until we find a q op
+    consumer_node = list(node.users)[0]
+    while True:
+        if consumer_node.target in dq_q_ops:
+            return consumer_node.args[5]
+
+        # Try to move on to the next node
+        if len(consumer_node.users) == 0:
+            raise RuntimeError("No quantized node found in graph")
+        consumer_node = list(consumer_node.users)[0]
+
+
 def is_quant_arg(arg):
     consumer_node = list(arg.users)[0]
     return consumer_node.target == q_op


+def get_quant_arg_dtype(node: torch.fx.Node):
+    consumer_node = list(node.users)[0]
+
+    # Get type of quant node, args differ from per_tensor and per_channel.
+    if consumer_node.target == q_op:
+        if is_quant_arg(node):
+            return map_dtype(consumer_node.args[5])
+        else:
+            raise RuntimeError("Quantization argument not found")
+
+
 def get_quant_node_args(node: torch.fx.Node):
     """
     Get the quantization parameters from a quant node.
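Both new helpers read the dtype from a quantize/dequantize node's argument list. For the decomposed q/dq ops handled here, the arguments are ordered (input, scale, zero_point, quant_min, quant_max, dtype), which is why args[5] is returned; runner_utils.py above relies on the same layout. A rough illustration with a hand-written args tuple (the values are hypothetical, not taken from a real graph):

import torch

# quantize_per_tensor(input, scale, zero_point, quant_min, quant_max, dtype)
example_args = ("x", 0.02, 128, 0, 255, torch.uint8)
dtype = example_args[5]  # what get_quant_node_dtype returns for a q/dq node
assert dtype == torch.uint8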

backends/transforms/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ runtime.python_library(
         "//executorch/backends/...",
         "//executorch/examples/...",
         "//executorch/extension/llm/...",
+        "@EXECUTORCH_CLIENTS",
     ],
     deps = [
         "//caffe2:torch",

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@ def __contains__(self, op):
 BINARY_OPS = [
     exir_ops.edge.aten.add.Tensor,
     exir_ops.edge.aten.sub.Tensor,
+    exir_ops.edge.aten.minimum.default,
     exir_ops.edge.aten.mul.Tensor,
     exir_ops.edge.aten.div.Tensor,
     exir_ops.edge.aten.div.Tensor_mode,

backends/vulkan/runtime/gen_vulkan_spv.py

Lines changed: 23 additions & 2 deletions
@@ -525,8 +525,29 @@ def generateVariantCombinations(
             if param_name not in exclude_params:
                 param_values = []
                 for value in value_list:
-                    suffix = value.get("SUFFIX", value["VALUE"])
-                    param_values.append((param_name, suffix, value["VALUE"]))
+                    if "RANGE" in value:
+                        value_range = value["RANGE"]
+                        suffix = value.get("SUFFIX", "")
+                        if isinstance(value_range, list) and len(value_range) == 2:
+                            for i in range(value_range[0], value_range[1] + 1):
+                                curr_suffix = (
+                                    suffix + "_" + str(i) if suffix else str(i)
+                                )
+                                param_values.append((param_name, curr_suffix, str(i)))
+                        else:
+                            raise ValueError(
+                                f"{value['RANGE']} is not a valid range. Must be in format [start, end] (inclusive)."
+                            )
+
+                    elif "VALUE" in value:
+                        suffix = value.get("SUFFIX", value["VALUE"])
+                        param_values.append((param_name, suffix, value["VALUE"]))
+
+                    else:
+                        raise KeyError(
+                            "Parameter must be 'VALUE: string' or 'RANGE: [a, b]'"
+                        )
+
                 all_iterated_params.append(param_values)

         return list(product(*all_iterated_params))
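The new branch lets a shader parameter expand over an integer range instead of listing every value. A minimal sketch of the expansion for one parameter entry, assuming it has been parsed from YAML into a dict; the parameter name and suffix are illustrative only:

# Illustration: how a RANGE entry expands into (name, suffix, value) tuples.
value = {"RANGE": [1, 4], "SUFFIX": "tile"}  # hypothetical YAML entry
param_name = "TILE_SIZE"                      # hypothetical parameter name

param_values = []
lo, hi = value["RANGE"]
for i in range(lo, hi + 1):  # inclusive range, matching the code above
    suffix = f'{value["SUFFIX"]}_{i}' if value.get("SUFFIX") else str(i)
    param_values.append((param_name, suffix, str(i)))

# param_values == [("TILE_SIZE", "tile_1", "1"), ..., ("TILE_SIZE", "tile_4", "4")]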

backends/vulkan/runtime/graph/ops/glsl/binary_op.yaml

Lines changed: 2 additions & 0 deletions
@@ -28,3 +28,5 @@ binary_op:
       OPERATOR: pow(X, Y)
     - NAME: binary_floor_divide
       OPERATOR: floor(X / Y)
+    - NAME: binary_minimum
+      OPERATOR: min(X, Y)

backends/vulkan/runtime/graph/ops/impl/BinaryOp.cpp

Lines changed: 2 additions & 0 deletions
@@ -118,6 +118,7 @@ DEFINE_BINARY_OP_WITH_ALPHA_FN(floor_divide);
 DEFINE_BINARY_OP_FN(mul);
 DEFINE_BINARY_OP_FN(div);
 DEFINE_BINARY_OP_FN(pow);
+DEFINE_BINARY_OP_FN(minimum);

 REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.add.Tensor, add);
@@ -126,6 +127,7 @@ REGISTER_OPERATORS {
   VK_REGISTER_OP(aten.div.Tensor, div);
   VK_REGISTER_OP(aten.div.Tensor_mode, floor_divide);
   VK_REGISTER_OP(aten.pow.Tensor_Tensor, pow);
+  VK_REGISTER_OP(aten.minimum.default, minimum);
 }

 } // namespace vkcompute

backends/vulkan/runtime/vk_api/QueryPool.cpp

Lines changed: 17 additions & 0 deletions
@@ -248,5 +248,22 @@ unsigned long QueryPool::get_total_shader_ns(std::string kernel_name) {
   }
   return 0;
 }
+
+unsigned long QueryPool::get_mean_shader_ns(std::string kernel_name) {
+  uint64_t total_ns = 0;
+  uint32_t count = 0;
+  for (ShaderDuration& entry : shader_durations_) {
+    if (entry.kernel_name == kernel_name) {
+      std::chrono::duration<size_t, std::nano> exec_duration_ns(
+          entry.execution_duration_ns);
+      total_ns += exec_duration_ns.count();
+      count++;
+    }
+  }
+  if (count == 0) {
+    return 0;
+  }
+  return total_ns / count;
+}
 } // namespace vkapi
 } // namespace vkcompute

backends/vulkan/runtime/vk_api/QueryPool.h

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ class QueryPool final {
   std::string generate_string_report();
   void print_results();
   unsigned long get_total_shader_ns(std::string kernel_name);
+  unsigned long get_mean_shader_ns(std::string kernel_name);

   operator bool() const {
     return querypool_ != VK_NULL_HANDLE;

backends/vulkan/test/op_tests/cases.py

Lines changed: 16 additions & 0 deletions
@@ -1022,3 +1022,19 @@ def get_constant_pad_nd_inputs():
         ]
     )
     return test_suite
+
+
+@register_test_suite("aten.minimum.default")
+def get_minimum_inputs():
+    test_suite = VkTestSuite(
+        [
+            ((M1, M2), (M2)),
+            ((M1, M2), (M1, M2)),
+            ((M1, M2, M), (M2, M)),
+            ((M1, M1, S1, S2), (M1, M1, S1, S2)),
+            ((S1, S1, S2, S), (S1, S2, S)),
+            ((M1, S1, S2), (L, M1, S1, S2)),
+            ((S1, S2), (L, M1, S1, S2)),
+        ]
+    )
+    return test_suite
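The shape pairs above exercise broadcasting for aten.minimum (M*, S*, and L are size constants defined elsewhere in cases.py). A tiny standalone check of the same behavior with concrete, illustrative sizes:

import torch

a = torch.rand(3, 5, 6, 4)
b = torch.rand(6, 4)
out = torch.minimum(a, b)  # b broadcasts over the leading dimensions
assert out.shape == (3, 5, 6, 4)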

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 19 additions & 0 deletions
@@ -1072,6 +1072,25 @@ def forward(self, x):
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )

+    def test_vulkan_backend_minimum(self):
+        class MinimumModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y):
+                return torch.minimum(x, y)
+
+        sample_inputs = (
+            torch.rand(size=(3, 5, 6, 4), dtype=torch.float32),
+            torch.rand(size=(6, 4), dtype=torch.float32),
+        )
+
+        self.lower_module_and_test_output(
+            MinimumModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
     def test_vulkan_backend_reshape(self):
         class ReshapeModule(torch.nn.Module):
             def __init__(self):

backends/vulkan/tools/gpuinfo/TARGETS

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+load("@fbcode_macros//build_defs:native_rules.bzl", "buck_filegroup")
+load("@fbsource//tools/build_defs:fb_xplat_cxx_binary.bzl", "fb_xplat_cxx_binary")
+load(
+    "@fbsource//tools/build_defs:platform_defs.bzl",
+    "ANDROID",
+)
+load(
+    "@fbsource//xplat/executorch/backends/vulkan:targets.bzl",
+    "vulkan_spv_shader_lib",
+)
+
+oncall("executorch")
+
+buck_filegroup(
+    name = "gpuinfo_shaders",
+    srcs = glob([
+        "glsl/*",
+    ]),
+    visibility = [
+        "PUBLIC",
+    ],
+)
+
+vulkan_spv_shader_lib(
+    name = "gpuinfo_shader_lib",
+    spv_filegroups = {
+        ":gpuinfo_shaders": "glsl",
+    },
+)
+
+fb_xplat_cxx_binary(
+    name = "vulkan_gpuinfo",
+    srcs = glob([
+        "**/*.cpp",
+    ]),
+    headers = glob([
+        "**/*.h",
+    ]),
+    header_namespace = "/include",
+    include_directories = ["/include"],
+    platforms = ANDROID,
+    raw_headers = glob([
+        "**/*.h",
+    ]),
+    deps = [
+        ":gpuinfo_shader_lib",
+        "//executorch/backends/vulkan:vulkan_graph_runtime",
+    ],
+)
