Skip to content

Commit b2a7243

Browse files
copyrightlyfacebook-github-bot
authored andcommitted
register view, reshape and select
Summary: - We register `select`, `unsqueeze` and `view` in `vulkan_partitioner.py` in order to run vulkan_delegate test (Python e2e test). The latter two might be used to implement `bmm` and `addmm`, so I want to make sure they work. - We register `reshape` in `View.cpp` explicitly. `reshape` is implemented through `_reshape_alias` (see [this](https://www.internalfb.com/code/fbsource/[a3dd6401f00d73f09bbdea63887fef54ea2c6dd2]/fbcode/caffe2/aten/src/ATen/native/native_functions.yaml?lines=4872-4881)) which is [decomposed as `view`](https://www.internalfb.com/code/fbsource/[bbb783ae1cff98b3b549da3edd845dde946d3da8]/xplat/caffe2/torch/_decomp/decompositions.py?lines=3669-3672). For codegen test, we still need to register the op, otherwise there is error ``` C++ exception with description "Exception raised from get_op_fn at xplat/executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.cpp:20: (it != table_.end()) is false! Could not find operator with name aten.reshape.default" thrown in the test body. ``` Reviewed By: yipjustin, liuk22 Differential Revision: D56454941 fbshipit-source-id: c83e6fb97d9cf9019cc6e786508f353a22236931
1 parent 590cbce commit b2a7243

File tree

2 files changed

+70
-0
lines changed

2 files changed

+70
-0
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,10 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
5252
exir_ops.edge.aten.convolution.default,
5353
# Normalization
5454
exir_ops.edge.aten.native_layer_norm.default,
55+
# Shape-related operators
56+
exir_ops.edge.aten.select_copy.int,
57+
exir_ops.edge.aten.unsqueeze_copy.default,
58+
exir_ops.edge.aten.view_copy.default,
5559
# Other
5660
operator.getitem,
5761
exir_ops.edge.aten.full.default,

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -732,3 +732,69 @@ def forward(self, x):
732732
sample_inputs,
733733
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
734734
)
735+
736+
def test_vulkan_backend_reshape(self):
737+
class ReshapeModule(torch.nn.Module):
738+
def __init__(self):
739+
super().__init__()
740+
741+
def forward(self, x):
742+
return torch.reshape(x, [-1, x.size(-1)])
743+
744+
sample_inputs = (torch.randn(size=(5, 3, 4), dtype=torch.float32),)
745+
746+
self.lower_module_and_test_output(
747+
ReshapeModule(),
748+
sample_inputs,
749+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
750+
)
751+
752+
def test_vulkan_backend_view(self):
753+
class ViewModule(torch.nn.Module):
754+
def __init__(self):
755+
super().__init__()
756+
757+
def forward(self, x):
758+
return x.view([-1, x.size(-1)])
759+
760+
sample_inputs = (torch.randn(size=(3, 2, 3, 4), dtype=torch.float32),)
761+
762+
self.lower_module_and_test_output(
763+
ViewModule(),
764+
sample_inputs,
765+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
766+
)
767+
768+
def test_vulkan_backend_unsqueeze(self):
769+
class UnsqueezeModule(torch.nn.Module):
770+
def __init__(self):
771+
super().__init__()
772+
773+
def forward(self, x):
774+
x = torch.unsqueeze(x, 1)
775+
x = torch.unsqueeze(x, 0)
776+
return x
777+
778+
sample_inputs = (torch.randn(size=(3,), dtype=torch.float32),)
779+
780+
self.lower_module_and_test_output(
781+
UnsqueezeModule(),
782+
sample_inputs,
783+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
784+
)
785+
786+
def test_vulkan_backend_select(self):
787+
class SelectModule(torch.nn.Module):
788+
def __init__(self):
789+
super().__init__()
790+
791+
def forward(self, x):
792+
return x[0][3]
793+
794+
sample_inputs = (torch.randn(size=(3, 6, 2, 7), dtype=torch.float32),)
795+
796+
self.lower_module_and_test_output(
797+
SelectModule(),
798+
sample_inputs,
799+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
800+
)

0 commit comments

Comments
 (0)