Skip to content

Commit 49c953c

Browse files
committed
[ET-VK][14/n] Add operators to Partitioner
1. Register aten operators in the vulkan partitioner. 2. Fix some minor operators name issue due to mismatch between the torch api and actual aten name Note: Permute is not yet registered due to tensor movement issues with the "Partial" model where the `Linear` operator is decomposed into `permute` and `addmm`. Will fix in later diffs. Differential Revision: [D56695929](https://our.internmc.facebook.com/intern/diff/D56695929/) ghstack-source-id: 224361378 Pull Request resolved: #3407
1 parent 23e04e2 commit 49c953c

File tree

4 files changed

+138
-8
lines changed

4 files changed

+138
-8
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5656
exir_ops.edge.aten.select_copy.int,
5757
exir_ops.edge.aten.unsqueeze_copy.default,
5858
exir_ops.edge.aten.view_copy.default,
59+
# Copy-releated operators
60+
exir_ops.edge.aten.clone.default,
61+
exir_ops.edge.aten.cat.default,
62+
exir_ops.edge.aten.split_with_sizes_copy.default,
63+
exir_ops.edge.aten.split.Tensor,
64+
exir_ops.edge.aten.slice_copy.Tensor,
5965
# Other
6066
operator.getitem,
6167
exir_ops.edge.aten.full.default,

backends/vulkan/runtime/graph/ops/impl/Split.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ void add_split_with_sizes_default_node(
106106
add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);
107107
}
108108

109-
void split_with_sizes_default(
109+
void split_with_sizes_copy_default(
110110
ComputeGraph& graph,
111111
const std::vector<ValueRef>& args) {
112112
add_split_with_sizes_default_node(graph, args[0], args[1], args[2], args[3]);
@@ -134,7 +134,8 @@ void split_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
134134
}
135135

136136
REGISTER_OPERATORS {
137-
VK_REGISTER_OP(aten.split_with_sizes.default, split_with_sizes_default);
137+
VK_REGISTER_OP(
138+
aten.split_with_sizes_copy.default, split_with_sizes_copy_default);
138139
VK_REGISTER_OP(aten.split.Tensor, split_tensor);
139140
}
140141

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import operator
88
from types import NoneType
9-
from typing import cast, List, Optional, Union
9+
from typing import cast, List, Optional, Tuple, Union
1010

1111
import executorch.backends.vulkan.serialization.vulkan_graph_schema as vk_graph_schema
1212

@@ -133,9 +133,9 @@ def create_node_value(self, node: Node) -> int:
133133
new_id = self.create_tensor_value(spec, constant_id)
134134
self.node_to_value_ids[node] = new_id
135135
return new_id
136-
elif isinstance(spec, tuple):
137-
# Create a Value for each element in the tuple, wrap Values in a
138-
# ValueList, and map the Node to the ValueList id.
136+
elif isinstance(spec, list) or isinstance(spec, tuple):
137+
# pyre-ignore[6]: pyre having hard time to infer Node type inside
138+
# the container.
139139
new_id = self.create_value_list_value(spec)
140140
self.node_to_value_ids[node] = new_id
141141
return new_id
@@ -202,7 +202,7 @@ def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
202202
)
203203
return new_id
204204

205-
def create_value_list_value(self, arg: List[Node] | tuple) -> int:
205+
def create_value_list_value(self, arg: tuple | list) -> int:
206206
self.values.append(
207207
vk_graph_schema.VkValue(
208208
vk_graph_schema.ValueList(
@@ -242,7 +242,6 @@ def get_or_create_value_for(self, arg: _Argument):
242242
# pyre-ignore[6]
243243
return self.create_scalar_list_value(arg)
244244
elif isinstance(arg, list) and isinstance(arg[0], Node):
245-
# pyre-ignore[6]
246245
return self.create_value_list_value(arg)
247246
elif isinstance(arg, str):
248247
return self.create_string_value(arg)

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ def run_test(memory_layout):
8080
compile_options = {
8181
"memory_layout_override": memory_layout,
8282
}
83+
84+
# At least model should run in eager mode.
85+
model(*sample_inputs)
86+
8387
program: ExportedProgram = export(
8488
model, sample_inputs, dynamic_shapes=dynamic_shapes
8589
)
@@ -798,3 +802,123 @@ def forward(self, x):
798802
sample_inputs,
799803
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
800804
)
805+
806+
def DISABLED_test_vulkan_backend_permute_copy(self):
807+
# aten.permute_copy.default is not enabled yet in partitioner
808+
class PermuteModule(torch.nn.Module):
809+
def __init__(self):
810+
super().__init__()
811+
812+
def forward(self, x):
813+
return torch.permute(x, [3, 0, 2, 1])
814+
815+
sample_inputs = (torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),)
816+
817+
self.lower_module_and_test_output(
818+
PermuteModule(),
819+
sample_inputs,
820+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
821+
)
822+
823+
def test_vulkan_backend_cat(self):
824+
class TestModule(torch.nn.Module):
825+
def __init__(self):
826+
super().__init__()
827+
828+
def forward(self, x, y, z, w):
829+
return torch.cat([x, y, z, w], dim=1)
830+
831+
sample_inputs = (
832+
torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),
833+
torch.randn(size=(3, 1, 2, 7), dtype=torch.float32),
834+
torch.randn(size=(3, 9, 2, 7), dtype=torch.float32),
835+
torch.randn(size=(3, 3, 2, 7), dtype=torch.float32),
836+
)
837+
838+
self.lower_module_and_test_output(
839+
TestModule(),
840+
sample_inputs,
841+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
842+
)
843+
844+
def test_vulkan_backend_slice(self):
845+
class TestModule(torch.nn.Module):
846+
def __init__(self):
847+
super().__init__()
848+
849+
def forward(self, x):
850+
return x[:, 2:9:2, :]
851+
852+
sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)
853+
854+
self.lower_module_and_test_output(
855+
TestModule(),
856+
sample_inputs,
857+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
858+
)
859+
860+
def test_vulkan_backend_split_with_sizes(self):
861+
class TestModule(torch.nn.Module):
862+
def __init__(self):
863+
super().__init__()
864+
865+
def forward(self, x):
866+
return torch.split(x, (3, 6, 1, 3), dim=1)
867+
868+
sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)
869+
870+
self.lower_module_and_test_output(
871+
TestModule(),
872+
sample_inputs,
873+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
874+
)
875+
876+
def test_vulkan_backend_split_tensor(self):
877+
class TestModule(torch.nn.Module):
878+
def __init__(self):
879+
super().__init__()
880+
881+
def forward(self, x):
882+
return torch.tensor_split(x, 2, dim=1)
883+
884+
sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)
885+
886+
self.lower_module_and_test_output(
887+
TestModule(),
888+
sample_inputs,
889+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
890+
)
891+
892+
def test_vulkan_backend_clone(self):
893+
class TestModule(torch.nn.Module):
894+
def __init__(self):
895+
super().__init__()
896+
897+
def forward(self, x):
898+
return torch.clone(x)
899+
900+
sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)
901+
902+
self.lower_module_and_test_output(
903+
TestModule(),
904+
sample_inputs,
905+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
906+
)
907+
908+
def DISABLED_test_vulkan_backend_t_default(self):
909+
# aten.permute_copy.default is not enabled yet in partitioner
910+
class TestModule(torch.nn.Module):
911+
def __init__(self):
912+
super().__init__()
913+
914+
def forward(self, x):
915+
# torch.t is actually exported as aten::permute.
916+
return torch.t(x)
917+
918+
sample_inputs = (torch.randn(size=(3, 14), dtype=torch.float32),)
919+
920+
self.lower_module_and_test_output(
921+
TestModule(),
922+
sample_inputs,
923+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
924+
)

0 commit comments

Comments
 (0)