Skip to content

Commit 9fdd5a3

Browse files
Yujie Huifacebook-github-bot
authored andcommitted
register nn.Module test for upsample_nearest2d (#4055)
Summary: Pull Request resolved: #4055 Register nn.Module test for ate.upsample_nearest2d.vec Reviewed By: jorgep31415 Differential Revision: D58907319
1 parent f538eae commit 9fdd5a3

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __contains__(self, op):
125125
exir_ops.edge.aten.ones_like.default,
126126
exir_ops.edge.aten.zeros.default,
127127
exir_ops.edge.aten.zeros_like.default,
128+
exir_ops.edge.aten.upsample_nearest2d.vec,
128129
]
129130

130131

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,6 +1057,23 @@ def forward(self, x):
10571057
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
10581058
)
10591059

1060+
def test_vulkan_backend_upsample_nearest2d(self):
1061+
class UpsampleNearest2d(torch.nn.Module):
1062+
def __init__(self):
1063+
super().__init__()
1064+
self.upsample = torch.nn.Upsample(scale_factor=2, mode="nearest")
1065+
1066+
def forward(self, x):
1067+
return self.upsample(x)
1068+
1069+
sample_inputs = (torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2),)
1070+
1071+
self.lower_module_and_test_output(
1072+
UpsampleNearest2d(),
1073+
sample_inputs,
1074+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
1075+
)
1076+
10601077
def test_vulkan_backend_reshape(self):
10611078
class ReshapeModule(torch.nn.Module):
10621079
def __init__(self):

0 commit comments

Comments
 (0)