
Commit 8dd7294

[ET-VK][Test] aten.max_pool2d_with_indices
Pull Request resolved: #2548

Due to the issues below, we only check equality of the output tensor and not the index tensor.

1. We can't verify index tensors, since VK-float16 vs. CPU-float32 deltas can change which index in a pool is the maximum. That can yield completely different integers in the index tensor, so we verify only the output tensor.

2. To actually visualize the index tensor, we need to re-construct the int32 values from the int64 values. The `torch.int64` index tensor is serialized as `int32` in Vulkan: Python expects int64, but C++ writes to the buffer as though it were int32, so we must apply some computation to re-construct the tensor. See below for details. A helper function was included in an earlier version of this change, but was removed for conciseness since we aren't checking the index tensor anyway.

For example, if the first and second elements should be 16 and 17, we instead read this value as the first element:

```
73014444048 = 1000100000000000000000000000000010000
```

We must split this int64 into two int32 values and construct a new tensor accordingly:

```
10001 | 00000000000000000000000000010000
10001 | 10000
   17 | 16
```

ghstack-source-id: 219506266
@exported-using-ghexport

Differential Revision: [D54962492](https://our.internmc.facebook.com/intern/diff/D54962492/)
1 parent 4ec0852 commit 8dd7294
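For point 2, a minimal sketch of the kind of reconstruction described in the commit message. This is not the helper that was removed from the change; `reinterpret_int64_as_int32` is a hypothetical name, and it assumes the int64 tensor's raw bytes are really packed int32 indices in little-endian order.

```python
import torch


def reinterpret_int64_as_int32(idx_int64: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper, not the one removed from this change.

    The Vulkan delegate writes int32 indices into a buffer that Python then
    reads as int64, so each int64 element packs two consecutive int32 indices.
    Reinterpreting the raw bytes as int32 recovers them (assuming the usual
    little-endian layout, so the low word comes first).
    """
    return idx_int64.contiguous().view(torch.int32)


# Worked example from the commit message: 73014444048 packs 16 (low 32 bits)
# and 17 (high 32 bits).
packed = torch.tensor([73014444048], dtype=torch.int64)
print(reinterpret_int64_as_int32(packed))  # tensor([16, 17], dtype=torch.int32)
```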

1 file changed (+62, -6 lines)
backends/vulkan/test/test_vulkan_delegate.py

```diff
@@ -28,7 +28,9 @@
 
 
 class TestBackends(unittest.TestCase):
-    def assert_outputs_equal(self, model_output, ref_output, atol=1e-03, rtol=1e-03):
+    def assert_outputs_equal(
+        self, model_output, ref_output, atol=1e-03, rtol=1e-03, first_output_only=False
+    ):
         """
         Helper testing function that asserts that the model output and the reference output
         are equal with some tolerance. Due to numerical differences between eager mode and
@@ -40,10 +42,17 @@ def assert_outputs_equal(self, model_output, ref_output, atol=1e-03, rtol=1e-03)
         if isinstance(ref_output, tuple) or isinstance(ref_output, list):
             # Multiple outputs executor always returns tuple, even if there is one output
             self.assertTrue(len(ref_output) == len(model_output))
-            for i in range(len(ref_output)):
+            if first_output_only:
                 self.assertTrue(
-                    torch.allclose(model_output[i], ref_output[i], atol=atol, rtol=rtol)
+                    torch.allclose(model_output[0], ref_output[0], atol=atol, rtol=rtol)
                 )
+            else:
+                for i in range(len(ref_output)):
+                    self.assertTrue(
+                        torch.allclose(
+                            model_output[i], ref_output[i], atol=atol, rtol=rtol
+                        )
+                    )
         else:
             # If one output, eager returns tensor while executor tuple of size 1
             self.assertTrue(
@@ -59,6 +68,7 @@ def lower_module_and_test_output(
         dynamic_shapes=None,
         test_inputs=None,
         memory_layouts=None,
+        first_output_only=False,
     ):
         """
         Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
@@ -94,7 +104,13 @@ def run_test(memory_layout):
             )
             ref_output = model(*sample_inputs)
 
-            self.assert_outputs_equal(model_output, ref_output, atol=atol, rtol=rtol)
+            self.assert_outputs_equal(
+                model_output,
+                ref_output,
+                atol=atol,
+                rtol=rtol,
+                first_output_only=first_output_only,
+            )
 
             if test_inputs is not None:
                 for test_input in test_inputs:
@@ -105,7 +121,11 @@ def run_test(memory_layout):
                     ref_output = model(*test_input)
 
                     self.assert_outputs_equal(
-                        model_output, ref_output, atol=atol, rtol=rtol
+                        model_output,
+                        ref_output,
+                        atol=atol,
+                        rtol=rtol,
+                        first_output_only=first_output_only,
                     )
 
         memory_layouts_to_test = [
@@ -120,7 +140,8 @@ def run_test(memory_layout):
             run_test(memory_layout)
 
     def test_vulkan_backend_add(self):
-        # This test is the simplest test by manually lowering some submodules, we can use paritioner for auto detecting lowerable parts
+        # This test is the simplest test by manually lowering some submodules, we can use paritioner
+        # for auto detecting lowerable parts.
         class AddModule(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -323,6 +344,41 @@ def forward(self, x):
 
         self.lower_clamp_module_and_test_output(ReLUModule())
 
+    def test_vulkan_backend_max_pool2d(self):
+        class MaxPool2dModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.max_pool = torch.nn.MaxPool2d(
+                    kernel_size=(2, 3),
+                    stride=(1, 1),
+                    padding=0,
+                    dilation=1,
+                    ceil_mode=False,
+                    return_indices=True,
+                )
+
+            def forward(self, x):
+                return self.max_pool(x)
+
+        max_pool2d_module = MaxPool2dModule()
+        sample_inputs = (torch.randn(5, 13, 55, 68),)
+
+        batch = Dim("batch", max=8)
+        dynamic_shapes = {"x": {0: batch}}
+        test_inputs = [
+            (torch.randn(3, 14, 15, 9),),
+            (torch.randn(1, 1, 4, 6),),
+            (torch.randn(5, 10, 50, 40),),
+        ]
+        self.lower_module_and_test_output(
+            max_pool2d_module,
+            sample_inputs,
+            dynamic_shapes=dynamic_shapes,
+            test_inputs=test_inputs,
+            memory_layouts=[vk_graph_schema.VkMemoryLayout.TENSOR_CHANNELS_PACKED],
+            first_output_only=True,
+        )
+
     def test_vulkan_backend_partial(self):
         class SimpleModel(torch.nn.Module):
             def __init__(self):
```
