Commit 645907d

Author: morelos
Update on "[ET-VK][Ops] choose_qparams op shaders and impl"
Creating the choose_qparams per_tensor and per_token logic shaders and impl which are linked with the testing framework Differential Revision: [D76436933](https://our.internmc.facebook.com/intern/diff/D76436933/) [ghstack-poisoned]
2 parents: 5ed4fa0 + 149a7a5 · commit 645907d

47 files changed: +1805 −387 lines
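For context on what the new Vulkan shaders compute: below is a minimal Python sketch of the affine quantization math behind choose_qparams, assuming the standard asymmetric int8 scheme (the shaders' exact eps and clamping details may differ; all names here are illustrative, not the ExecuTorch API).

def choose_qparams_per_tensor(values, qmin=-128, qmax=127, eps=1e-5):
    # Widen the observed range to include zero so that zero is exactly
    # representable after quantization.
    min_val = min(min(values), 0.0)
    max_val = max(max(values), 0.0)
    # scale maps the float range onto the integer range; eps avoids a
    # zero scale for constant inputs.
    scale = max((max_val - min_val) / (qmax - qmin), eps)
    # zero_point is the integer that dequantizes back to 0.0, clamped
    # into the representable range.
    zero_point = int(round(qmin - min_val / scale))
    return scale, max(qmin, min(qmax, zero_point))

def choose_qparams_per_token(tokens, **kw):
    # per_token: one (scale, zero_point) pair per row instead of one
    # pair for the whole tensor.
    return [choose_qparams_per_tensor(row, **kw) for row in tokens]

print(choose_qparams_per_tensor([-1.5, 0.25, 3.0]))          # one pair
print(choose_qparams_per_token([[0.1, 0.2], [-4.0, 4.0]]))   # pair per row

The shaders parallelize the min/max reduction across workgroup threads (the shared_min/shared_max arrays in the diffs below), then apply this scalar math once per tensor or once per token.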

.ci/scripts/utils.sh

Lines changed: 3 additions & 2 deletions
@@ -156,13 +156,14 @@ build_executorch_runner() {
 }
 
 cmake_install_executorch_lib() {
+  build_type="${1:-Release}"
   echo "Installing libexecutorch.a and libportable_kernels.a"
   clean_executorch_install_folders
   retry cmake -DCMAKE_INSTALL_PREFIX=cmake-out \
-      -DCMAKE_BUILD_TYPE=Release \
+      -DCMAKE_BUILD_TYPE=${build_type} \
       -DPYTHON_EXECUTABLE="$PYTHON_EXECUTABLE" \
       -Bcmake-out .
-  cmake --build cmake-out -j9 --target install --config Release
+  cmake --build cmake-out -j9 --target install --config ${build_type}
 }
 
 download_stories_model_artifacts() {
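With the new optional positional parameter, CI callers can now select the build type, e.g. `cmake_install_executorch_lib Debug`; calling it with no argument preserves the previous Release behavior.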

.github/workflows/lint.yml

Lines changed: 1 addition & 1 deletion
@@ -46,7 +46,7 @@ jobs:
           fi
 
           # This has already been cached in the docker image
-          lintrunner init 2> /dev/null
+          lintrunner init
 
           RC=0
           # Run lintrunner on all files
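Removing the `2> /dev/null` redirect means any errors from `lintrunner init` now show up in the CI log instead of being silently discarded.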

backends/qualcomm/_passes/i64_to_i32.py

Lines changed: 2 additions & 0 deletions
@@ -28,8 +28,10 @@ class I64toI32(ExportPass):
     I64_OPS = {
         exir_ops.edge.aten.argmin.default,
         exir_ops.edge.aten.arange.start_step,
+        exir_ops.edge.aten.cumsum.default,
         exir_ops.edge.aten.full.default,
         exir_ops.edge.aten.scalar_tensor.default,
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
     }
     # This dict is to ensure that the input of the OPs are int64 due to Pytorch restrictions.
     # For example, scatter op can only accept args[2], the index, as int64.
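Adding `cumsum` and `_to_dim_order_copy` to `I64_OPS` registers them with the I64toI32 pass, presumably so their int64 results are rewritten to int32 for QNN; this pairs with the explicit int32 cast added to `op_cum_sum.py` below.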

backends/qualcomm/_passes/lift_constant_scalar_operands.py

Lines changed: 2 additions & 1 deletion
@@ -86,7 +86,8 @@ def _build_tensor_constant(
         dtype=(
             node.args[0].meta["val"].dtype
             if not is_float_tensor(node)
-            and not SCALAR_OPS.get(node.target).use_self_dtype
+            and (info := SCALAR_OPS.get(node.target))
+            and not info.use_self_dtype
             else node.meta["val"].dtype
         ),
         device=node.meta["val"].device,
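The one-line fix guards against `SCALAR_OPS.get(node.target)` returning `None` for an unregistered target, which previously raised `AttributeError` on the `.use_self_dtype` access. A self-contained sketch of the pattern (registry contents and dtypes are illustrative):

from types import SimpleNamespace

# Illustrative stand-in for the real SCALAR_OPS registry.
SCALAR_OPS = {"aten.add.Scalar": SimpleNamespace(use_self_dtype=False)}

def pick_dtype(target, arg_dtype, self_dtype):
    # Before: `not SCALAR_OPS.get(target).use_self_dtype` crashed when
    # get() returned None. The walrus assignment binds the lookup result
    # and lets the `and` chain short-circuit on a miss, so the attribute
    # access is never reached for unregistered targets.
    return (
        arg_dtype
        if (info := SCALAR_OPS.get(target)) and not info.use_self_dtype
        else self_dtype
    )

print(pick_dtype("aten.add.Scalar", "int64", "float32"))  # int64
print(pick_dtype("aten.mul.Scalar", "int64", "float32"))  # float32, no crash

Note the behavior change: a missing registry entry now falls through to `node.meta["val"].dtype` rather than raising.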

backends/qualcomm/_passes/replace_inf_values.py

Lines changed: 6 additions & 0 deletions
@@ -30,6 +30,12 @@ def call(self, graph_module: torch.fx.GraphModule):
                         arg_list[index] = torch.finfo(torch.float32).min
                     elif arg == float("inf"):
                         arg_list[index] = torch.finfo(torch.float32).max
+
+                if node.target == torch.ops.aten.masked_fill.Scalar:
+                    if arg_list[2] == torch.finfo(torch.float32).max:
+                        arg_list[2] = 255
+                    elif arg_list[2] == torch.finfo(torch.float32).min:
+                        arg_list[2] = -255
                 node.args = tuple(arg_list)
 
         graph_module.recompile()
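For `masked_fill.Scalar`, the fill value that the generic branch just mapped to the float32 extremes is narrowed further to ±255, presumably because the float32 min/max constants overflow or lose precision once the backend quantizes the graph.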

backends/qualcomm/builders/op_cum_sum.py

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ def define_node(
         dim = self.get_param(node, input_tensor)
 
         output_tensor = self.get_tensor(node, node)
+        if output_tensor.dtype == torch.int64:
+            output_tensor = output_tensor.to(torch.int32)
         output_tensor_wrapper = self.define_tensor(
             node,
             node,
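Casting an int64 cumsum output down to int32 before defining the tensor wrapper lines up with the `cumsum` entry added to `I64_OPS` above; the QNN backend works in int32 where PyTorch would default to int64.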

backends/qualcomm/tests/models.py

Lines changed: 10 additions & 10 deletions
@@ -1101,6 +1101,16 @@ def forward(self, x):
         return torch.mean(x, (-1, -2))
 
 
+class MaskedFill(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, attn_mask):
+        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
+            attn_mask == 0, float(0.0)
+        )
+
+
 class Maximum(torch.nn.Module):
     def __init__(self):
         super().__init__()

@@ -1751,16 +1761,6 @@ def forward(self, x):
         )
 
 
-class MaskedFill(torch.nn.Module):
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, attn_mask):
-        return attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
-            attn_mask == 0, float(0.0)
-        )
-
-
 # Mimi Decoder has 0D tensor which QNN cannot handle.
 class ZeroDimTensor(torch.nn.Module):
     def __init__(self):
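The `MaskedFill` module itself is unchanged; the two hunks simply move it earlier in the file so the model classes stay in alphabetical order (before `Maximum`).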

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 64 additions & 3 deletions
@@ -272,9 +272,24 @@ def test_qnn_backend_cos(self):
         self.lower_module_and_test_output(module, sample_input)
 
     def test_qnn_backend_cumsum(self):
-        module = CumSum()  # noqa: F405
-        sample_input = (torch.randn(4),)
-        self.lower_module_and_test_output(module, sample_input)
+        sample_input = ()
+        test_comb = [
+            {
+                QCOM_MODULE: [CumSum()],  # noqa: F405
+                QCOM_SAMPLE_INPUTS: [
+                    (torch.randn(4),),
+                    (torch.randint(0, 10, size=(4,)),),
+                ],
+            }
+        ]
+
+        index = 0
+        for comb in test_comb:
+            for module in comb[QCOM_MODULE]:
+                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
+                    with self.subTest(i=index):
+                        self.lower_module_and_test_output(module, sample_input)
+                        index += 1
 
     def test_qnn_backend_einsum_outer_product(self):
         module = EinsumOuterProduct()  # noqa: F405

@@ -311,6 +326,12 @@ def test_qnn_backend_element_wise_add(self):
                 QCOM_MODULE: [AddConstantFloat()],  # noqa: F405
                 QCOM_SAMPLE_INPUTS: [(torch.randn(2, 5, 1, 3),)],
             },
+            {
+                QCOM_MODULE: [
+                    AddConstantLong(),  # noqa: F405
+                ],
+                QCOM_SAMPLE_INPUTS: [(torch.randint(0, 10, size=(2, 3)),)],
+            },
         ]
 
         index = 0

@@ -4526,6 +4547,40 @@ def test_retinanet(self):
         else:
             self.assertGreaterEqual(msg["mAP"], 0.6)
 
+    def test_roberta(self):
+        if not self.required_envs([self.sentence_dataset]):
+            self.skipTest("missing required envs")
+        cmds = [
+            "python",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/roberta.py",
+            "--dataset",
+            self.sentence_dataset,
+            "--artifact",
+            self.artifact_dir,
+            "--build_folder",
+            self.build_folder,
+            "--device",
+            self.device,
+            "--model",
+            self.model,
+            "--ip",
+            self.ip,
+            "--port",
+            str(self.port),
+        ]
+        if self.host:
+            cmds.extend(["--host", self.host])
+
+        p = subprocess.Popen(cmds, stdout=subprocess.DEVNULL)
+        with Listener((self.ip, self.port)) as listener:
+            conn = listener.accept()
+            p.communicate()
+            msg = json.loads(conn.recv())
+            if "Error" in msg:
+                self.fail(msg["Error"])
+            else:
+                self.assertGreaterEqual(msg["accuracy"], 0.5)
+
     def test_squeezenet(self):
         if not self.required_envs([self.image_dataset]):
             self.skipTest("missing required envs")

@@ -5344,6 +5399,11 @@ def setup_environment():
         help="Location for imagenet dataset",
         type=str,
     )
+    parser.add_argument(
+        "--sentence_dataset",
+        help="Location for sentence dataset",
+        type=str,
+    )
     parser.add_argument(
         "-p",
         "--pretrained_weight",

@@ -5402,6 +5462,7 @@ def setup_environment():
     TestQNN.executorch_root = args.executorch_root
     TestQNN.artifact_dir = args.artifact_dir
     TestQNN.image_dataset = args.image_dataset
+    TestQNN.sentence_dataset = args.sentence_dataset
     TestQNN.pretrained_weight = args.pretrained_weight
     TestQNN.model_name = args.model_name
     TestQNN.online_prepare = args.online_prepare
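The rewritten cumsum test adopts the suite's combination-driven pattern so one module can be exercised against several inputs, including the new integer input, without one failure hiding the rest. A self-contained sketch of that pattern (`abs` stands in for a lowered module, since the real harness needs a device):

import unittest

# Illustrative stand-ins for the suite's real dict keys.
QCOM_MODULE = "module"
QCOM_SAMPLE_INPUTS = "sample_inputs"

class Demo(unittest.TestCase):
    def test_combinations(self):
        test_comb = [
            {QCOM_MODULE: [abs], QCOM_SAMPLE_INPUTS: [(-1,), (2.5,)]},
        ]
        index = 0
        for comb in test_comb:
            for module in comb[QCOM_MODULE]:
                for sample_input in comb[QCOM_SAMPLE_INPUTS]:
                    # subTest reports each (module, input) pair separately,
                    # so a failing case does not abort the remaining ones.
                    with self.subTest(i=index):
                        self.assertGreaterEqual(module(*sample_input), 0)
                        index += 1

if __name__ == "__main__":
    unittest.main()

`test_roberta` follows the same remote-runner protocol as the surrounding model tests: launch the example script, accept the device connection, and assert on the reported accuracy.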

backends/qualcomm/tests/utils.py

Lines changed: 1 addition & 0 deletions
@@ -183,6 +183,7 @@ class TestQNN(unittest.TestCase):
     executorch_root: str = ""
     artifact_dir: str = ""
     image_dataset: str = ""
+    sentence_dataset: str = ""
     pretrained_weight: str = ""
     enable_profile: bool = False
     op_package_dir: str = ""

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_buffer.glsl

Lines changed: 16 additions & 5 deletions
@@ -12,16 +12,18 @@
 
 #define IN_T ${buffer_scalar_type(IN_DTYPE)}
 
+#define ${MODE}
+
 ${define_active_storage_type("buffer")}
 ${define_required_extensions(IN_DTYPE)}
 
 #extension GL_EXT_control_flow_attributes : require
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
 ${layout_declare_tensor(B, "w", "t_scale", "float", "buffer")}
 ${layout_declare_tensor(B, "w", "t_zero_point", "int", "buffer")}
+${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "buffer")}
 
 $if MODE == "per_tensor":
   layout(push_constant) uniform restrict Block {

@@ -53,11 +55,11 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 shared float shared_min[NWORKERS];
 shared float shared_max[NWORKERS];
 
-void main() {
-  $if MODE == "per_tensor":
+#ifdef per_tensor
+
+void choose_qparams_per_tensor() {
   uint global_id = gl_GlobalInvocationID.x;
   uint local_id = gl_LocalInvocationID.x;
-  uint group_id = gl_WorkGroupID.x;
   uint total_threads = gl_NumWorkGroups.x * gl_WorkGroupSize.x;
 
   uint total_elements = uint(t_in_sizes.x * t_in_sizes.y * t_in_sizes.z * t_in_sizes.w);

@@ -114,8 +116,11 @@ $if MODE == "per_tensor":
     t_scale[0] = scale_val;
     t_zero_point[0] = zero_point_val;
   }
+}
+
+#else
 
-$if MODE == "per_token":
+void choose_qparams_per_token() {
   uint global_id = gl_GlobalInvocationID.x;
   uint local_id = gl_LocalInvocationID.x;
   uint group_id = gl_WorkGroupID.x;

@@ -201,3 +206,9 @@ $if MODE == "per_token":
     barrier();
   }
 }
+
+#endif
+
+void main() {
+  choose_qparams_${MODE}();
+}
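Net effect of these hunks: mode selection moves from codegen-time `$if MODE` blocks to the GLSL preprocessor. The template now emits `#define ${MODE}`, guards the two reduction paths with `#ifdef per_tensor` / `#else` / `#endif`, wraps each path in its own named function, and dispatches from a single `main()` via `choose_qparams_${MODE}()`. The read-only `t_in` binding also moves after the two write-only outputs, and the per_tensor path drops its unused `group_id` variable.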

backends/vulkan/runtime/graph/ops/glsl/choose_qparams_texture.glsl

Lines changed: 18 additions & 6 deletions
@@ -13,16 +13,18 @@
 #define IN_T ${buffer_scalar_type(IN_DTYPE)}
 #define FVEC4_T ${texel_load_type(IN_DTYPE, "texture3d")}
 
+#define ${MODE}
+
 ${define_active_storage_type("texture3d")}
 ${define_required_extensions(IN_DTYPE)}
 
 #extension GL_EXT_control_flow_attributes : require
 
 layout(std430) buffer;
 
-${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
 ${layout_declare_tensor(B, "w", "t_scale", "float", "texture3d")}
 ${layout_declare_tensor(B, "w", "t_zero_point", "int", "texture3d")}
+${layout_declare_tensor(B, "r", "t_in", IN_DTYPE, "texture3d")}
 
 $if MODE == "per_tensor":
   layout(push_constant) uniform restrict Block {

@@ -51,8 +53,9 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 shared float shared_min[NWORKERS];
 shared float shared_max[NWORKERS];
 
-void main() {
-  $if MODE == "per_tensor":
+#ifdef per_tensor
+
+void choose_qparams_per_tensor() {
   uint global_id = gl_GlobalInvocationID.x;
   uint local_id = gl_LocalInvocationID.x;
   uint group_id = gl_WorkGroupID.x;

@@ -85,7 +88,7 @@ $if MODE == "per_tensor":
   // Calculate total tensor elements to determine padding
   int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4;
   int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x +
-      tensor_coord.z * sizes.x * sizes.y;
+                          tensor_coord.z * sizes.x * sizes.y;
   int remaining_elements = total_elements - (linear_tensor_idx);
   int valid_elements = min(4, remaining_elements);

@@ -168,8 +171,11 @@ $if MODE == "per_tensor":
     write_texel(t_scale, ivec3(0, 0, 0), vec4(scale_val, 0.0, 0.0, 0.0));
     write_texel(t_zero_point, ivec3(0, 0, 0), ivec4(zero_point_val, 0, 0, 0));
   }
+}
+
+#else
 
-$if MODE == "per_token":
+void choose_qparams_per_token() {
   // Each token is processed by multiple workgroups for parallel reduction
   uint local_id = gl_LocalInvocationID.x;
   uint group_id = gl_WorkGroupID.x;

@@ -219,7 +225,7 @@ $if MODE == "per_token":
   // Calculate total tensor elements to determine padding
   int total_elements = t_in_limits.x * t_in_limits.y * t_in_limits.z * 4;
   int linear_tensor_idx = tensor_coord.x + tensor_coord.y * sizes.x +
-      tensor_coord.z * sizes.x * sizes.y;
+                          tensor_coord.z * sizes.x * sizes.y;
   int remaining_elements = total_elements - (linear_tensor_idx);
   int valid_elements = min(4, remaining_elements);

@@ -316,3 +322,9 @@ $if MODE == "per_token":
     barrier();
   }
 }
+
+#endif
+
+void main() {
+  choose_qparams_${MODE}();
+}
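The texture3d variant gets the identical `#ifdef` restructuring; its two remaining hunks only re-align the continuation line of the `linear_tensor_idx` computation.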
