Skip to content

[ET-VK][Test] aten.max_pool2d_with_indices #2548

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 62 additions & 6 deletions backends/vulkan/test/test_vulkan_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@


class TestBackends(unittest.TestCase):
def assert_outputs_equal(self, model_output, ref_output, atol=1e-03, rtol=1e-03):
def assert_outputs_equal(
self, model_output, ref_output, atol=1e-03, rtol=1e-03, first_output_only=False
):
"""
Helper testing function that asserts that the model output and the reference output
are equal with some tolerance. Due to numerical differences between eager mode and
Expand All @@ -40,10 +42,17 @@ def assert_outputs_equal(self, model_output, ref_output, atol=1e-03, rtol=1e-03)
if isinstance(ref_output, tuple) or isinstance(ref_output, list):
# Multiple outputs executor always returns tuple, even if there is one output
self.assertTrue(len(ref_output) == len(model_output))
for i in range(len(ref_output)):
if first_output_only:
self.assertTrue(
torch.allclose(model_output[i], ref_output[i], atol=atol, rtol=rtol)
torch.allclose(model_output[0], ref_output[0], atol=atol, rtol=rtol)
)
else:
for i in range(len(ref_output)):
self.assertTrue(
torch.allclose(
model_output[i], ref_output[i], atol=atol, rtol=rtol
)
)
else:
# If one output, eager returns tensor while executor tuple of size 1
self.assertTrue(
Expand All @@ -59,6 +68,7 @@ def lower_module_and_test_output(
dynamic_shapes=None,
test_inputs=None,
memory_layouts=None,
first_output_only=False,
):
"""
Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
Expand Down Expand Up @@ -94,7 +104,13 @@ def run_test(memory_layout):
)
ref_output = model(*sample_inputs)

self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol)
self.assert_outputs_equal(
model_output,
ref_output,
atol=atol,
rtol=rtol,
first_output_only=first_output_only,
)

if test_inputs is not None:
for test_input in test_inputs:
Expand All @@ -105,7 +121,11 @@ def run_test(memory_layout):
ref_output = model(*test_input)

self.assert_outputs_equal(
model_output, ref_output, atol=atol, rtol=rtol
model_output,
ref_output,
atol=atol,
rtol=rtol,
first_output_only=first_output_only,
)

memory_layouts_to_test = [
Expand All @@ -120,7 +140,8 @@ def run_test(memory_layout):
run_test(memory_layout)

def test_vulkan_backend_add(self):
# This test is the simplest test by manually lowering some submodules, we can use paritioner for auto detecting lowerable parts
# This test is the simplest: it manually lowers some submodules; we could use
# the partitioner to auto-detect lowerable parts.
class AddModule(torch.nn.Module):
def __init__(self):
super().__init__()
Expand Down Expand Up @@ -323,6 +344,41 @@ def forward(self, x):

self.lower_clamp_module_and_test_output(ReLUModule())

def test_vulkan_backend_max_pool2d(self):
    """Lower aten.max_pool2d_with_indices to the Vulkan delegate and check it.

    Only the pooled values (first output) are compared against eager mode,
    since the delegate's index layout may differ from eager's.
    """

    class MaxPool2dModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # return_indices=True makes forward() yield (values, indices),
            # i.e. the aten.max_pool2d_with_indices overload.
            self.max_pool = torch.nn.MaxPool2d(
                kernel_size=(2, 3),
                stride=(1, 1),
                padding=0,
                dilation=1,
                ceil_mode=False,
                return_indices=True,
            )

        def forward(self, x):
            return self.max_pool(x)

    sample_inputs = (torch.randn(5, 13, 55, 68),)

    # The batch dimension is dynamic, bounded above by 8.
    dynamic_shapes = {"x": {0: Dim("batch", max=8)}}

    # Extra shapes exercised after lowering, within the dynamic-shape bounds.
    test_inputs = [
        (torch.randn(3, 14, 15, 9),),
        (torch.randn(1, 1, 4, 6),),
        (torch.randn(5, 10, 50, 40),),
    ]

    self.lower_module_and_test_output(
        MaxPool2dModule(),
        sample_inputs,
        dynamic_shapes=dynamic_shapes,
        test_inputs=test_inputs,
        memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
        # Compare only the values tensor, not the indices tensor.
        first_output_only=True,
    )

def test_vulkan_backend_partial(self):
class SimpleModel(torch.nn.Module):
def __init__(self):
Expand Down