Commit 566c7bb

[ET-VK][ez] Add back tensor dim check
## Context

Vulkan cannot represent higher dimensional tensors (tensors with dim > 4) at the moment, but due to some refactors implemented last year, the partitioner check that prevents lowering ops involving high dimensional tensors was accidentally removed. This diff adds the check back, along with a test to verify that high dimensional tensors do not get lowered.

Differential Revision: [D68630966](https://our.internmc.facebook.com/intern/diff/D68630966/)

[ghstack-poisoned]
1 parent 57ef834 commit 566c7bb
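
For reference, a minimal standalone illustration (not part of this commit) of what "high dimensional" means here: any tensor with rank greater than 4, such as the result of unsqueezing a 4-dim tensor.

```python
# Illustration only (not part of this commit): "high dimensional" means rank > 4.
import torch

x = torch.rand(1, 5, 2, 3)   # rank 4: representable by the Vulkan backend
y = torch.unsqueeze(x, 2)    # rank 5: exceeds the 4-dim limit
print(x.dim(), y.dim())      # 4 5
```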

3 files changed: +58 −4 lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -83,6 +83,10 @@ def op_node_is_compatible(
             return False, "no operator implementation"
         features = get_op_features(target)
 
+        # Check for high dimensional tensors
+        if utils.tensor_node_is_high_dim(node):
+            return False, "contains high dim tensor"
+
         valid_texture_layouts = utils.possible_node_memory_layouts(
             node, self.texture_limits
         )
@@ -94,6 +98,10 @@ def op_node_is_compatible(
                 and utils.is_tensor_node(arg)
                 and i not in features.skip_limits_check
             ):
+                # Check for high dimensional tensors
+                if utils.tensor_node_is_high_dim(arg):
+                    return False, "contains high dim tensor"
+
                 arg_texture_layouts = utils.possible_node_memory_layouts(
                     arg, self.texture_limits
                 )
```
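
A self-contained sketch of the guard pattern these hunks add, using a hypothetical stand-in for `torch.fx.Node` since the surrounding partitioner class is not shown in full; only the rank > 4 rule and the `(bool, reason)` return contract are taken from the diff.

```python
# Sketch of the guard pattern above; StubNode is a hypothetical stand-in for
# torch.fx.Node. Only the rank > 4 rule and the (bool, reason) contract come
# from the diff itself.
from dataclasses import dataclass
from typing import Tuple

@dataclass
class StubNode:
    shape: Tuple[int, ...]  # stand-in for node.meta["val"].shape

def tensor_node_is_high_dim(node: StubNode) -> bool:
    return len(node.shape) > 4

def op_node_is_compatible(node: StubNode) -> Tuple[bool, str]:
    # Reject early, before any texture-layout checks would run.
    if tensor_node_is_high_dim(node):
        return False, "contains high dim tensor"
    return True, "compatible"

print(op_node_is_compatible(StubNode(shape=(5, 4, 1, 2, 6))))  # (False, 'contains high dim tensor')
print(op_node_is_compatible(StubNode(shape=(1, 5, 2, 3))))     # (True, 'compatible')
```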

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 32 additions & 4 deletions
```diff
@@ -97,6 +97,7 @@ def lower_module_and_test_output(
         dynamic_shapes=None,
         test_inputs=None,
         first_output_only=False,
+        expect_no_delegates=False,
     ):
         """
         Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
@@ -125,10 +126,23 @@ def run_test():
             )
             executorch_program = edge_program.to_executorch()
 
-            self.assertEqual(
-                executorch_program.executorch_program.execution_plan[0].delegates[0].id,
-                VulkanBackend.__name__,
-            )
+            if expect_no_delegates:
+                self.assertEqual(
+                    len(
+                        executorch_program.executorch_program.execution_plan[
+                            0
+                        ].delegates
+                    ),
+                    0,
+                )
+                return
+            else:
+                self.assertEqual(
+                    executorch_program.executorch_program.execution_plan[0]
+                    .delegates[0]
+                    .id,
+                    VulkanBackend.__name__,
+                )
 
             executorch_module = _load_for_executorch_from_buffer(
                 executorch_program.buffer
@@ -1683,3 +1697,17 @@ def forward(self, x):
             GridPriorsModule(),
             (torch.rand(size=[1, 5, 2, 3]),),
         )
+
+    def test_vulkan_backend_high_dim_tensors_fail(self):
+        class UnsqueezeHigherDim(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.unsqueeze(x, 2)
+
+        self.lower_module_and_test_output(
+            UnsqueezeHigherDim(),
+            (torch.ones(size=[5, 4, 1, 2, 6]),),
+            expect_no_delegates=True,
+        )
```

backends/vulkan/utils.py

Lines changed: 18 additions & 0 deletions
```diff
@@ -130,6 +130,24 @@ def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> int:
     raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")
 
 
+def tensor_node_is_high_dim(node: torch.fx.Node) -> bool:
+    """
+    If the node does not contain a tensor or a collection of tensors, return False.
+    Otherwise, return True if the tensor is high dimensional (i.e. rank > 4).
+    """
+    if is_tensor_node(node):
+        if isinstance(node.meta["val"], FakeTensor):
+            return len(node.meta["val"].shape) > 4
+        if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
+            for fake_tensor in node.meta["val"]:
+                if isinstance(fake_tensor, FakeTensor):
+                    if len(fake_tensor.shape) > 4:
+                        return True
+            return False
+    else:
+        return False
+
+
 def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
     """
     Calculate the image extents that will be used to represent a tensor with the given sizes
```
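
To see the `FakeTensor` values that `tensor_node_is_high_dim` inspects, one can export a small module and walk `node.meta["val"]`. A hedged sketch, assuming `torch.export` is available; the module and shapes are illustrative, not taken from the commit.

```python
# Hedged sketch: inspect the node.meta["val"] FakeTensors that
# tensor_node_is_high_dim examines. Module and shapes are illustrative.
import torch

class UnsqueezeHigherDim(torch.nn.Module):
    def forward(self, x):
        return torch.unsqueeze(x, 2)  # rank 4 input -> rank 5 output

ep = torch.export.export(UnsqueezeHigherDim(), (torch.ones(5, 4, 1, 2),))
for node in ep.graph.nodes:
    val = node.meta.get("val")
    if hasattr(val, "shape"):  # FakeTensor for tensor-producing nodes
        flag = "high dim" if len(val.shape) > 4 else "ok"
        print(f"{node.name}: {tuple(val.shape)} -> {flag}")
```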
