
Commit 189b04f

[ET-VK][14/n] Add operators to Partitioner
Pull Request resolved: #3407

1. Register aten operators in the Vulkan partitioner.
2. Fix some minor operator name issues caused by mismatches between the torch API and the actual aten names.

Note: `permute` is not yet registered due to tensor movement issues with the "Partial" model, where the `Linear` operator is decomposed into `permute` and `addmm`. This will be fixed in later diffs.

Differential Revision: [D56695929](https://our.internmc.facebook.com/intern/diff/D56695929/)
ghstack-source-id: 224365109
1 parent 23e04e2 · commit 189b04f
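
For context on the note above, here is a minimal sketch (not part of this commit) of how a `Linear` decomposes into `permute` and `addmm` at the edge dialect. It assumes an ExecuTorch install with `to_edge` available; the `Partial` module name here is a hypothetical stand-in for the model mentioned in the message.

import torch
from torch.export import export
from executorch.exir import to_edge

class Partial(torch.nn.Module):  # hypothetical stand-in
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)

edge = to_edge(export(Partial(), (torch.randn(3, 4),)))
# Per the commit note, expect a permute op (for the weight transpose)
# followed by addmm in the printed graph.
print(edge.exported_program().graph)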

File tree: 4 files changed, +137 −7 lines

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 6 additions & 0 deletions
@@ -56,6 +56,12 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
             exir_ops.edge.aten.select_copy.int,
             exir_ops.edge.aten.unsqueeze_copy.default,
             exir_ops.edge.aten.view_copy.default,
+            # Copy-related operators
+            exir_ops.edge.aten.clone.default,
+            exir_ops.edge.aten.cat.default,
+            exir_ops.edge.aten.split_with_sizes_copy.default,
+            exir_ops.edge.aten.split.Tensor,
+            exir_ops.edge.aten.slice_copy.Tensor,
             # Other
             operator.getitem,
             exir_ops.edge.aten.full.default,
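
The list above is what `is_node_supported` checks a node's target against. A simplified standalone sketch of that allow-list check (the real method also receives `submodules` and covers the full operator list):

import operator
from executorch.exir.dialects._ops import ops as exir_ops

SUPPORTED_OPS = {
    # Copy-related operators added in this diff
    exir_ops.edge.aten.clone.default,
    exir_ops.edge.aten.cat.default,
    exir_ops.edge.aten.split_with_sizes_copy.default,
    exir_ops.edge.aten.split.Tensor,
    exir_ops.edge.aten.slice_copy.Tensor,
    # Other
    operator.getitem,
    exir_ops.edge.aten.full.default,
}

def is_node_supported(node) -> bool:
    # A node is partitioned to the Vulkan backend only if its target
    # op appears in the allow-list.
    return node.target in SUPPORTED_OPS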

backends/vulkan/runtime/graph/ops/impl/Split.cpp

Lines changed: 3 additions & 2 deletions
@@ -106,7 +106,7 @@ void add_split_with_sizes_default_node(
   add_split_with_sizes_default_node(graph, in, split_sizes, dim, out);
 }
 
-void split_with_sizes_default(
+void split_with_sizes_copy_default(
     ComputeGraph& graph,
     const std::vector<ValueRef>& args) {
   add_split_with_sizes_default_node(graph, args[0], args[1], args[2], args[3]);
@@ -134,7 +134,8 @@ void split_tensor(ComputeGraph& graph, const std::vector<ValueRef>& args) {
 }
 
 REGISTER_OPERATORS {
-  VK_REGISTER_OP(aten.split_with_sizes.default, split_with_sizes_default);
+  VK_REGISTER_OP(
+      aten.split_with_sizes_copy.default, split_with_sizes_copy_default);
   VK_REGISTER_OP(aten.split.Tensor, split_tensor);
 }
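
The rename above tracks the functionalized "copy" variant of the op that actually appears at the edge dialect. A quick way to confirm which name shows up, assuming an ExecuTorch install:

import torch
from torch.export import export
from executorch.exir import to_edge

class M(torch.nn.Module):
    def forward(self, x):
        return torch.split(x, (3, 6, 1, 3), dim=1)

edge = to_edge(export(M(), (torch.randn(3, 13, 7, 3),)))
# Expect aten.split_with_sizes_copy.default in the printed graph, matching
# the VK_REGISTER_OP name above.
print(edge.exported_program().graph)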

backends/vulkan/serialization/vulkan_graph_builder.py

Lines changed: 4 additions & 5 deletions
@@ -133,9 +133,9 @@ def create_node_value(self, node: Node) -> int:
             new_id = self.create_tensor_value(spec, constant_id)
             self.node_to_value_ids[node] = new_id
             return new_id
-        elif isinstance(spec, tuple):
-            # Create a Value for each element in the tuple, wrap Values in a
-            # ValueList, and map the Node to the ValueList id.
+        elif isinstance(spec, list) or isinstance(spec, tuple):
+            # pyre-ignore[6]: pyre has a hard time inferring the Node type
+            # inside the container.
             new_id = self.create_value_list_value(spec)
             self.node_to_value_ids[node] = new_id
             return new_id
@@ -202,7 +202,7 @@ def create_scalar_list_value(self, arg: List[_ScalarType]) -> int:
         )
         return new_id
 
-    def create_value_list_value(self, arg: List[Node] | tuple) -> int:
+    def create_value_list_value(self, arg: tuple | list) -> int:
         self.values.append(
             vk_graph_schema.VkValue(
                 vk_graph_schema.ValueList(
@@ -242,7 +242,6 @@ def get_or_create_value_for(self, arg: _Argument):
             # pyre-ignore[6]
             return self.create_scalar_list_value(arg)
         elif isinstance(arg, list) and isinstance(arg[0], Node):
-            # pyre-ignore[6]
             return self.create_value_list_value(arg)
         elif isinstance(arg, str):
             return self.create_string_value(arg)
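
The builder change widens the tuple-only branch so that ops returning a list of output specs (such as split) also get a ValueList. A distilled, standalone sketch of the dispatch, with the VkValue/ValueList bookkeeping stubbed out:

def create_value_list_value(arg) -> int:
    # Stub: the real method appends a VkValue wrapping a ValueList of ids.
    return len(arg)

def create_node_value(spec) -> int:
    # Accept both containers: multi-output ops may hand back a list of
    # TensorSpecs rather than a tuple.
    if isinstance(spec, (list, tuple)):
        return create_value_list_value(spec)
    raise TypeError(f"unsupported spec type: {type(spec)}")

print(create_node_value(["a", "b", "c"]))  # -> 3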

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 124 additions & 0 deletions
@@ -80,6 +80,10 @@ def run_test(memory_layout):
             compile_options = {
                 "memory_layout_override": memory_layout,
             }
+
+            # At the very least, the model should run in eager mode.
+            model(*sample_inputs)
+
             program: ExportedProgram = export(
                 model, sample_inputs, dynamic_shapes=dynamic_shapes
             )
@@ -798,3 +802,123 @@ def forward(self, x):
             sample_inputs,
             memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
         )
+
+    def DISABLED_test_vulkan_backend_permute_copy(self):
+        # aten.permute_copy.default is not yet enabled in the partitioner.
+        class PermuteModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.permute(x, [3, 0, 2, 1])
+
+        sample_inputs = (torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            PermuteModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_cat(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x, y, z, w):
+                return torch.cat([x, y, z, w], dim=1)
+
+        sample_inputs = (
+            torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),
+            torch.randn(size=(3, 1, 2, 7), dtype=torch.float32),
+            torch.randn(size=(3, 9, 2, 7), dtype=torch.float32),
+            torch.randn(size=(3, 3, 2, 7), dtype=torch.float32),
+        )
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_slice(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return x[:, 2:9:2, :]
+
+        sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_split_with_sizes(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.split(x, (3, 6, 1, 3), dim=1)
+
+        sample_inputs = (torch.randn(size=(3, 13, 7, 3), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_split_tensor(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.tensor_split(x, 2, dim=1)
+
+        sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def test_vulkan_backend_clone(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.clone(x)
+
+        sample_inputs = (torch.randn(size=(3, 14, 7, 3), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
+
+    def DISABLED_test_vulkan_backend_t_default(self):
+        # aten.permute_copy.default is not yet enabled in the partitioner.
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                # torch.t is actually exported as aten::permute.
+                return torch.t(x)
+
+        sample_inputs = (torch.randn(size=(3, 14), dtype=torch.float32),)
+
+        self.lower_module_and_test_output(
+            TestModule(),
+            sample_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+        )
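
The comment inside DISABLED_test_vulkan_backend_t_default ("torch.t is actually exported as aten::permute") can be checked directly; a short sketch, assuming a PyTorch build with torch.export:

import torch
from torch.export import export

class T(torch.nn.Module):
    def forward(self, x):
        return torch.t(x)

# Expect a permute-family op rather than aten::t in the printed graph,
# which is why this test is gated on permute support in the partitioner.
print(export(T(), (torch.randn(3, 14),)).graph)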
