Skip to content

Commit bced356

Browse files
Yujie Huifacebook-github-bot
authored andcommitted
register nn.Module test for upsample_nearest2d
Summary: Register nn.Module test for ate.upsample_nearest2d.vec. Need to enable `not_decompose` to test this op. Differential Revision: D58907319
1 parent 50f907d commit bced356

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

backends/vulkan/partitioner/supported_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def __contains__(self, op):
125125
exir_ops.edge.aten.ones_like.default,
126126
exir_ops.edge.aten.zeros.default,
127127
exir_ops.edge.aten.zeros_like.default,
128+
exir_ops.edge.aten.upsample_nearest2d.vec,
128129
]
129130

130131

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,6 +1068,24 @@ def forward(self, x):
10681068
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
10691069
)
10701070

1071+
def test_vulkan_backend_upsample_nearest2d(self):
1072+
class UpsampleNearest2d(torch.nn.Module):
1073+
def __init__(self):
1074+
super().__init__()
1075+
self.upsample = torch.nn.UpsamplingNearest2d(scale_factor=2)
1076+
1077+
def forward(self, x):
1078+
return self.upsample(x)
1079+
1080+
sample_inputs = (torch.arange(1, 5, dtype=torch.float32).view(1, 1, 2, 2),)
1081+
1082+
self.lower_module_and_test_output(
1083+
UpsampleNearest2d(),
1084+
sample_inputs,
1085+
memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
1086+
not_decompose=True,
1087+
)
1088+
10711089
def test_vulkan_backend_reshape(self):
10721090
class ReshapeModule(torch.nn.Module):
10731091
def __init__(self):

0 commit comments

Comments
 (0)