Commit b4f7994

[ET-VK][ez] Add back tensor dim check
Differential Revision: D68630966
Pull Request resolved: #7938
1 parent 57ef834 commit b4f7994

File tree

3 files changed: +54 -4 lines changed
backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 8 additions & 0 deletions
@@ -83,6 +83,10 @@ def op_node_is_compatible(
             return False, "no operator implementation"
         features = get_op_features(target)
 
+        # Check for high dimensional tensors
+        if utils.is_tensor_node(node) and utils.tensor_node_is_high_dim(node):
+            return False, "contains high dim tensor"
+
         valid_texture_layouts = utils.possible_node_memory_layouts(
             node, self.texture_limits
         )
@@ -94,6 +98,10 @@ def op_node_is_compatible(
                 and utils.is_tensor_node(arg)
                 and i not in features.skip_limits_check
             ):
+                # Check for high dimensional tensors
+                if utils.is_tensor_node(arg) and utils.tensor_node_is_high_dim(arg):
+                    return False, "contains high dim tensor"
+
                 arg_texture_layouts = utils.possible_node_memory_layouts(
                     arg, self.texture_limits
                 )
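For orientation, here is a condensed sketch of the control flow the two added guards create in op_node_is_compatible. This is a hypothetical stand-in, not the partitioner's actual code; the two predicate callables correspond to utils.is_tensor_node and utils.tensor_node_is_high_dim from the real file.

from typing import Callable, Iterable, Tuple

def compatibility_sketch(
    node,
    tensor_args: Iterable,
    is_tensor_node: Callable[[object], bool],
    is_high_dim: Callable[[object], bool],
) -> Tuple[bool, str]:
    # Guard 1 (first hunk): reject the op if its own output tensor exceeds 4 dims.
    if is_tensor_node(node) and is_high_dim(node):
        return False, "contains high dim tensor"
    for arg in tensor_args:
        # Guard 2 (second hunk): reject the op if any tensor argument exceeds 4 dims.
        if is_tensor_node(arg) and is_high_dim(arg):
            return False, "contains high dim tensor"
    # (The real op_node_is_compatible goes on to check texture memory layouts here.)
    return True, "valid"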

backends/vulkan/test/test_vulkan_delegate.py

Lines changed: 32 additions & 4 deletions
@@ -97,6 +97,7 @@ def lower_module_and_test_output(
         dynamic_shapes=None,
         test_inputs=None,
         first_output_only=False,
+        expect_no_delegates=False,
     ):
         """
         Helper testing function that takes a torch.nn.Module and lowers it to Vulkan with
@@ -125,10 +126,23 @@ def run_test():
             )
             executorch_program = edge_program.to_executorch()
 
-            self.assertEqual(
-                executorch_program.executorch_program.execution_plan[0].delegates[0].id,
-                VulkanBackend.__name__,
-            )
+            if expect_no_delegates:
+                self.assertEqual(
+                    len(
+                        executorch_program.executorch_program.execution_plan[
+                            0
+                        ].delegates
+                    ),
+                    0,
+                )
+                return
+            else:
+                self.assertEqual(
+                    executorch_program.executorch_program.execution_plan[0]
+                    .delegates[0]
+                    .id,
+                    VulkanBackend.__name__,
+                )
 
             executorch_module = _load_for_executorch_from_buffer(
                 executorch_program.buffer
@@ -1683,3 +1697,17 @@ def forward(self, x):
             GridPriorsModule(),
             (torch.rand(size=[1, 5, 2, 3]),),
         )
+
+    def test_vulkan_backend_high_dim_tensors_fail(self):
+        class UnsqueezeHigherDim(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x):
+                return torch.unsqueeze(x, 2)
+
+        self.lower_module_and_test_output(
+            UnsqueezeHigherDim(),
+            (torch.ones(size=[5, 4, 1, 2, 6]),),
+            expect_no_delegates=True,
+        )
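A quick standalone illustration of why this test expects zero delegates: the test input is already 5-D, and unsqueezing at dim 2 yields a 6-D output, so both the argument and the result exceed the 4-dimension limit enforced by the new partitioner check. The snippet below uses plain PyTorch only and is not part of the commit.

import torch

x = torch.ones(size=[5, 4, 1, 2, 6])  # same shape as the test input: 5 dims
y = torch.unsqueeze(x, 2)             # same op as UnsqueezeHigherDim.forward: 6 dims

# Both exceed the 4-dim limit, so nothing is handed to the Vulkan backend.
print(x.dim(), y.dim())          # 5 6
print(x.dim() > 4, y.dim() > 4)  # True True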

backends/vulkan/utils.py

Lines changed: 14 additions & 0 deletions
@@ -130,6 +130,20 @@ def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> int:
     raise RuntimeError(f"Cannot get numel for val of type {type(node.meta['val'])}")
 
 
+def tensor_node_is_high_dim(node: torch.fx.Node) -> bool:
+    """
+    Returns true if a given node contains a tensor with more than 4 dimensions
+    """
+    if isinstance(node.meta["val"], FakeTensor):
+        return len(node.meta["val"].shape) > 4
+    if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
+        for fake_tensor in node.meta["val"]:
+            if isinstance(fake_tensor, FakeTensor):
+                if len(fake_tensor.shape) > 4:
+                    return True
+    return False
+
+
 def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
     """
     Calculate the image extents that will be used to represent a tensor with the given sizes
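To make the rule concrete, here is a small self-contained example (not from the commit) that applies the same "more than 4 dims" test to FakeTensor values, the kind of payload found in node.meta["val"] during export. val_is_high_dim is a hypothetical helper that mirrors tensor_node_is_high_dim but takes the value directly instead of an fx node.

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode

def val_is_high_dim(val) -> bool:
    # Mirrors tensor_node_is_high_dim: a single FakeTensor, or a list/tuple of them.
    if isinstance(val, FakeTensor):
        return len(val.shape) > 4
    if isinstance(val, (list, tuple)):
        return any(isinstance(t, FakeTensor) and len(t.shape) > 4 for t in val)
    return False

mode = FakeTensorMode()
four_d = mode.from_tensor(torch.ones(1, 5, 2, 3))     # 4-D: still delegable
five_d = mode.from_tensor(torch.ones(5, 4, 1, 2, 6))  # 5-D: "high dim"

print(val_is_high_dim(four_d))            # False
print(val_is_high_dim(five_d))            # True
print(val_is_high_dim([four_d, five_d]))  # True (any element over 4 dims)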
