Commit daacbe2
Update base for Update on "[ET-VK][Ops] aten.convolution (Bias=False)"
The final touches to get ET-VK convolution on par with ATen-VK's convolution.

## Idea

In our shaders, we add the bias to our sum:

```
${VEC4_T[DTYPE]} sum = texelFetch(bias_in, ivec2(pos.z, 0), 0);
```

To keep the shaders as they are, we implement the no-bias case by allocating a buffer of zeros; the shader then adds zero to the sum.

## Issue

If `Bias=False`, the dummy buffer of zeros is not serialized with the graph. The bias ValueRef is deserialized in the runtime as `TypeTag::NONE`, not `TypeTag::TENSORREF`.

## Solution

If `TypeTag::NONE` is given, (1) create the `vTensor` using the `out_channels` value from the weights, and (2) allocate a StagingBuffer of that size. The StagingBuffer is transferred to GPU memory and initialized to zeros.

Differential Revision: [D55814589](https://our.internmc.facebook.com/intern/diff/D55814589/)

[ghstack-poisoned]
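To make the Solution concrete, here is a minimal, self-contained sketch. The helper below is illustrative only, not the actual ExecuTorch API: when the bias deserializes as `TypeTag::NONE`, a zero-filled staging buffer sized to the weight's `out_channels` backs the bias tensor, so the shader's `texelFetch(bias_in, ...)` reads zeros.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for the runtime's StagingBuffer path: build a
// host-side buffer of out_channels zeros, which would then be copied
// into the GPU memory backing the bias vTensor.
std::vector<float> make_zero_bias_staging(std::size_t out_channels) {
  // Zero-filled, so the shader's bias add contributes nothing.
  return std::vector<float>(out_channels, 0.0f);
}

int main() {
  // e.g. a conv2d weight with 64 output channels and Bias=False
  std::vector<float> staging = make_zero_bias_staging(64);
  return staging[0] == 0.0f ? 0 : 1;
}
```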
2 parents a85f33e + 99c4f4e commit daacbe2

52 files changed: +1058 −602 lines. (Large commit: some file contents are hidden by default, so not all changed files appear below.)

.gitignore

Lines changed: 1 addition & 0 deletions
```diff
@@ -2,6 +2,7 @@
 buck-out/
 cmake-out/
 cmake-android-out/
+cmake-out-android/
 cmake-ios-out/
 ethos-u-scratch/
 executorch.egg-info
```

CMakeLists.txt

Lines changed: 13 additions & 9 deletions
```diff
@@ -352,23 +352,27 @@ add_subdirectory(schema)
 # Only contains primitive operators; does not contain portable kernels or other
 # full operators. Does not contain any backends.
 #
-
-add_library(executorch ${_executorch__srcs})
-target_link_libraries(executorch PRIVATE program_schema)
-target_link_options_shared_lib(executorch)
+add_library(executorch_no_prim_ops ${_executorch_no_prim_ops__srcs})
+target_link_libraries(executorch_no_prim_ops PRIVATE program_schema)
 # Check if dl exists for this toolchain and only then link it.
 find_library(DL_LIBRARY_EXISTS NAMES dl)
 # Check if the library was found
 if(DL_LIBRARY_EXISTS)
-  target_link_libraries(executorch PRIVATE dl) # For dladdr()
+  target_link_libraries(executorch_no_prim_ops PRIVATE dl) # For dladdr()
 endif()
-target_include_directories(executorch PUBLIC ${_common_include_directories})
-target_compile_options(executorch PUBLIC ${_common_compile_options})
+target_include_directories(executorch_no_prim_ops PUBLIC ${_common_include_directories})
+target_compile_options(executorch_no_prim_ops PUBLIC ${_common_compile_options})
 if(MAX_KERNEL_NUM)
-  target_compile_definitions(executorch
+  target_compile_definitions(executorch_no_prim_ops
                              PRIVATE MAX_KERNEL_NUM=${MAX_KERNEL_NUM})
 endif()
 
+add_library(executorch ${_executorch__srcs})
+target_link_libraries(executorch PRIVATE executorch_no_prim_ops)
+target_include_directories(executorch PUBLIC ${_common_include_directories})
+target_compile_options(executorch PUBLIC ${_common_compile_options})
+target_link_options_shared_lib(executorch)
+
 #
 # portable_ops_lib: A library to register core ATen ops using portable kernels,
 # see kernels/portable/CMakeLists.txt.
@@ -406,7 +410,7 @@ endif()
 # Install `executorch` library as well as `executorch-config.cmake` under
 # ${CMAKE_INSTALL_PREFIX}/
 install(
-  TARGETS executorch
+  TARGETS executorch executorch_no_prim_ops
   DESTINATION lib
   INCLUDES
   DESTINATION ${_common_include_directories})
```

backends/qualcomm/builders/op_dequantize.py

Lines changed: 10 additions & 14 deletions
```diff
@@ -56,20 +56,16 @@ def define_node(
 
 
 @register_node_visitor
-class PerTensorDequantizeDefault(DequantizeOpBase):
-    target = ["quantized_decomposed.dequantize_per_tensor.default"]
+class PerTensorDequantize(DequantizeOpBase):
+    target = [
+        "quantized_decomposed.dequantize_per_tensor.default",
+        "quantized_decomposed.dequantize_per_tensor.tensor",
+    ]
 
 
 @register_node_visitor
-class PerTensorDequantizeTensor(DequantizeOpBase):
-    target = ["quantized_decomposed.dequantize_per_tensor.tensor"]
-
-
-@register_node_visitor
-class PerChannelDequantizeDefault(DequantizeOpBase):
-    target = ["quantized_decomposed.dequantize_per_channel.default"]
-
-
-@register_node_visitor
-class PerChannelDequantizeTensor(DequantizeOpBase):
-    target = ["quantized_decomposed.dequantize_per_channel.tensor"]
+class PerChannelDequantize(DequantizeOpBase):
+    target = [
+        "quantized_decomposed.dequantize_per_channel.default",
+        "quantized_decomposed.dequantize_per_channel.tensor",
+    ]
```

backends/qualcomm/passes/convert_hardsigmoid.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -25,6 +25,10 @@ def call(self, graph_module: torch.fx.GraphModule):
         partitions = get_source_partitions(graph, [torch.nn.Hardsigmoid])
         for _, src_partitions in partitions.items():
             for src_partition in src_partitions:
+                if exir_ops.edge.aten.hardswish.default in [
+                    node.target for node in src_partition.nodes
+                ]:
+                    continue
                 if self.quantization_capture:
                     # only one hardsigmoid op will be seen
                     input_nodes = src_partition.input_nodes
@@ -34,8 +38,6 @@ def call(self, graph_module: torch.fx.GraphModule):
                 else:
                     in_ops_target = exir_ops.edge.aten.add.Tensor
                     out_ops_target = exir_ops.edge.aten.div.Tensor
-                    # see the reverse engineering logic hardswish
-                    # https://shorturl.at/pACEL
                     input_nodes = [
                         n for n in src_partition.nodes if n.target is in_ops_target
                     ]
```

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 59 additions & 0 deletions
```diff
@@ -6,8 +6,10 @@
 import json
 import subprocess
 import sys
+import tempfile
 import unittest
 from multiprocessing.connection import Listener
+from pathlib import Path
 
 import torch
 from executorch.backends.qualcomm.tests.utils import (
@@ -1102,6 +1104,19 @@ def test_qnn_backend_shared_buffer(self):
             expected_partitions=1,
         )
 
+    def test_qnn_backend_online_prepare(self):
+        backend_options = generate_htp_compiler_spec(use_fp16=True)
+        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=self.arch_table[TestQNN.model],
+            backend_options=backend_options,
+            debug=False,
+            saver=False,
+            online_prepare=True,
+        )
+        module = SimpleModel()  # noqa: F405
+        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
+        self.lower_module_and_test_output(module, sample_input)
+
 
 class TestQNNQuantizedUtils(TestQNN):
     # TODO: refactor to support different backends
@@ -1223,6 +1238,20 @@ def test_qnn_backend_shared_buffer(self):
             expected_partitions=1,
         )
 
+    def test_qnn_backend_online_prepare(self):
+        backend_options = generate_htp_compiler_spec(use_fp16=False)
+        TestQNN.compiler_specs = generate_qnn_executorch_compiler_spec(
+            soc_model=self.arch_table[TestQNN.model],
+            backend_options=backend_options,
+            debug=False,
+            saver=False,
+            online_prepare=True,
+        )
+        module = SimpleModel()  # noqa: F405
+        sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28))
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
 
 class TestExampleOssScript(TestQNN):
     def required_envs(self, conditions=None) -> bool:
@@ -1640,6 +1669,29 @@ def test_ptq_mobilebert(self):
         for k, v in cpu.items():
             self.assertLessEqual(abs(v[0] - htp[k][0]), 5)
 
+    def test_export_example(self):
+        if not self.required_envs([self.model_name]):
+            self.skipTest("missing required envs")
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            cmds = [
+                "python",
+                "qualcomm/scripts/export_example.py",
+                "--model_name",
+                self.model_name,
+                "--output_folder",
+                "{}/".format(tmp_dir),
+                "--generate_etrecord",
+            ]
+
+            p = subprocess.Popen(
+                cmds, stdout=subprocess.DEVNULL, cwd=f"{self.executorch_root}/examples"
+            )
+            p.communicate()
+            self.assertTrue(
+                Path("{0}/{1}.pte".format(tmp_dir, self.model_name)).exists()
+            )
+
 
 def setup_environment():
     parser = setup_common_args_and_variables()
@@ -1669,6 +1721,12 @@ def setup_environment():
         default="",
         type=str,
     )
+    parser.add_argument(
+        "-n",
+        "--model_name",
+        help="Input the model to export",
+        type=str,
+    )
     parser.add_argument(
         "-o",
         "--online_prepare",
@@ -1697,6 +1755,7 @@ def setup_environment():
     TestQNN.artifact_dir = args.artifact_dir
     TestQNN.image_dataset = args.image_dataset
     TestQNN.pretrained_weight = args.pretrained_weight
+    TestQNN.model_name = args.model_name
     TestQNN.online_prepare = args.online_prepare
     TestQNN.enable_profile = args.enable_profile
     TestQNN.error_only = args.error_only
```

backends/qualcomm/utils/utils.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@
     ConvertBinaryOpsWithScalar,
 )
 from executorch.backends.qualcomm.passes.convert_bmm_to_matmul import ConvertBmmToMatmul
+from executorch.backends.qualcomm.passes.convert_hardsigmoid import ConvertHardsigmoid
 from executorch.backends.qualcomm.passes.convert_interpolate_with_upsample2d import (
     ConvertInterpolateWithUpsample2D,
 )
@@ -103,6 +104,7 @@ def _transform(edge_program: ExportedProgram) -> None:
     graph_module = edge_program.graph_module
     RemoveClone()(graph_module)
     ConvertToLinear()(graph_module)
+    ConvertHardsigmoid()(graph_module)
    ConvertBmmToMatmul()(graph_module)
    ConvertInterpolateWithUpsample2D()(graph_module)
    I64toI32(edge_program)(graph_module)
```

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 7 additions & 7 deletions
```diff
@@ -77,26 +77,26 @@ api::StorageType get_storage_type(
     const vkgraph::VkStorageType& vk_storage_type) {
   switch (vk_storage_type) {
     case vkgraph::VkStorageType::BUFFER:
-      return api::StorageType::BUFFER;
+      return api::kBuffer;
     case vkgraph::VkStorageType::TEXTURE_3D:
-      return api::StorageType::TEXTURE_3D;
+      return api::kTexture3D;
     case vkgraph::VkStorageType::TEXTURE_2D:
-      return api::StorageType::TEXTURE_2D;
+      return api::kTexture2D;
     default:
       break;
   }
-  return api::StorageType::UNKNOWN;
+  VK_THROW("Invalid storage type encountered!");
 }
 
 api::GPUMemoryLayout get_memory_layout(
     const vkgraph::VkMemoryLayout& vk_memory_layout) {
   switch (vk_memory_layout) {
     case vkgraph::VkMemoryLayout::TENSOR_WIDTH_PACKED:
-      return api::GPUMemoryLayout::TENSOR_WIDTH_PACKED;
+      return api::kWidthPacked;
     case vkgraph::VkMemoryLayout::TENSOR_HEIGHT_PACKED:
-      return api::GPUMemoryLayout::TENSOR_HEIGHT_PACKED;
+      return api::kHeightPacked;
     case vkgraph::VkMemoryLayout::TENSOR_CHANNELS_PACKED:
-      return api::GPUMemoryLayout::TENSOR_CHANNELS_PACKED;
+      return api::kChannelsPacked;
     default:
       break;
   }
```
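One behavioral note on this diff: `get_storage_type` previously returned an `UNKNOWN` sentinel that every caller had to remember to check; it now fails fast via `VK_THROW` on an unrecognized serialized value. A standalone sketch of the same pattern, with simplified enums and `std::runtime_error` standing in for `VK_THROW`:

```cpp
#include <stdexcept>

// Simplified stand-ins for the serialized and runtime enums.
enum class VkStorageType { BUFFER, TEXTURE_3D, TEXTURE_2D, UNRECOGNIZED };
enum class StorageType { Buffer, Texture3D, Texture2D };

StorageType get_storage_type(VkStorageType s) {
  switch (s) {
    case VkStorageType::BUFFER:
      return StorageType::Buffer;
    case VkStorageType::TEXTURE_3D:
      return StorageType::Texture3D;
    case VkStorageType::TEXTURE_2D:
      return StorageType::Texture2D;
    default:
      break;
  }
  // Fail fast on bad serialized data instead of returning a sentinel
  // that callers might forget to check.
  throw std::runtime_error("Invalid storage type encountered!");
}

int main() {
  return get_storage_type(VkStorageType::BUFFER) == StorageType::Buffer ? 0 : 1;
}
```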

backends/vulkan/runtime/api/Shader.cpp

Lines changed: 2 additions & 21 deletions
```diff
@@ -23,38 +23,19 @@ ShaderInfo::ShaderInfo()
           0u,
       } {}
 
-ShaderInfo::ShaderInfo(
-    std::string name,
-    const uint32_t* const spirv_bin,
-    const uint32_t size,
-    std::vector<VkDescriptorType> layout)
-    : src_code{
-          spirv_bin,
-          size,
-      },
-      kernel_name{std::move(name)},
-      kernel_layout{std::move(layout)} {}
-
 ShaderInfo::ShaderInfo(
     std::string name,
     const uint32_t* const spirv_bin,
     const uint32_t size,
     std::vector<VkDescriptorType> layout,
-    const std::vector<uint32_t>& tile_size,
-    const StorageType bias_storage_type,
-    const StorageType weight_storage_type)
+    const utils::uvec3 tile_size)
     : src_code{
           spirv_bin,
           size,
       },
       kernel_name{std::move(name)},
       kernel_layout{std::move(layout)},
-      tile_size(tile_size),
-      bias_storage_type(bias_storage_type),
-      weight_storage_type(weight_storage_type) {
-  for (uint64_t i = 0; i < tile_size.size(); ++i) {
-    out_tile_size.data[i] = tile_size[i];
-  }
+      out_tile_size(tile_size) {
 }
 
 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2) {
```
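The constructor change above replaces an element-by-element copy out of a `std::vector<uint32_t>` with direct member initialization from a fixed-size `utils::uvec3`. A standalone sketch of the resulting shape, with types simplified from the real headers:

```cpp
#include <cstdint>

// Simplified mirror of utils::uvec3's fixed-size storage.
struct uvec3 {
  uint32_t data[3];
};

struct ShaderInfoSketch {
  uvec3 out_tile_size{1u, 1u, 1u};

  // The tile size now arrives as a uvec3, so the member is
  // copy-initialized directly: no runtime loop, and no possibility of
  // a mismatched vector length.
  explicit ShaderInfoSketch(const uvec3 tile_size)
      : out_tile_size(tile_size) {}
};

int main() {
  ShaderInfoSketch info(uvec3{2u, 2u, 1u});
  return info.out_tile_size.data[0] == 2u ? 0 : 1;
}
```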

backends/vulkan/runtime/api/Shader.h

Lines changed: 2 additions & 13 deletions
```diff
@@ -62,25 +62,14 @@ struct ShaderInfo final {
   // Shader Metadata
   utils::uvec3 out_tile_size{1u, 1u, 1u};
 
-  std::vector<uint32_t> tile_size;
-  StorageType bias_storage_type{StorageType::UNKNOWN};
-  StorageType weight_storage_type{StorageType::UNKNOWN};
-
   explicit ShaderInfo();
-  explicit ShaderInfo(std::string, const char*);
-  explicit ShaderInfo(
-      std::string,
-      const uint32_t*,
-      const uint32_t,
-      std::vector<VkDescriptorType>);
+
   explicit ShaderInfo(
       std::string,
       const uint32_t*,
       const uint32_t,
       std::vector<VkDescriptorType>,
-      const std::vector<uint32_t>& tile_size,
-      const StorageType bias_storage_type,
-      const StorageType weight_storage_type);
+      const utils::uvec3 tile_size);
 };
 
 bool operator==(const ShaderInfo& _1, const ShaderInfo& _2);
```
