Commit ab55b29

Update on "[ET-VK][ez] Add back tensor dim check"

## Context

Vulkan cannot currently represent high-dimensional tensors (tensors with dim > 4), but a refactor implemented last year accidentally removed the partitioner check that prevents ops involving high-dimensional tensors from being lowered. This diff adds the check back, along with a test verifying that high-dimensional tensors do not get lowered.

Differential Revision: [D68630966](https://our.internmc.facebook.com/intern/diff/D68630966/)

[ghstack-poisoned]

1 parent 566c7bb commit ab55b29
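As a quick, hedged illustration of the rank rule described in the context above (the tensors themselves are arbitrary examples, not from the commit's test):

```python
# Illustration only: tensors with more than 4 dimensions count as "high dim"
# and must not be lowered to the Vulkan backend.
import torch

ok = torch.zeros(2, 3, 4, 5)            # rank 4: representable
too_big = torch.zeros(2, 3, 4, 5, 6)    # rank 5: high dim, must stay un-lowered
print(ok.dim() > 4, too_big.dim() > 4)  # False True
```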

File tree

2 files changed (+12 −16 lines)

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -84,7 +84,7 @@ def op_node_is_compatible(
         features = get_op_features(target)
 
         # Check for high dimensional tensors
-        if utils.tensor_node_is_high_dim(node):
+        if utils.is_tensor_node(node) and utils.tensor_node_is_high_dim(node):
             return False, "contains high dim tensor"
 
         valid_texture_layouts = utils.possible_node_memory_layouts(
@@ -99,7 +99,7 @@ def op_node_is_compatible(
                 and i not in features.skip_limits_check
             ):
                 # Check for high dimensional tensors
-                if utils.tensor_node_is_high_dim(arg):
+                if utils.is_tensor_node(arg) and utils.tensor_node_is_high_dim(arg):
                     return False, "contains high dim tensor"
 
                 arg_texture_layouts = utils.possible_node_memory_layouts(
```
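The guard order matters here: `is_tensor_node` must run first, because the simplified `tensor_node_is_high_dim` (see the utils.py diff below) no longer screens out non-tensor nodes itself. A minimal, self-contained sketch of the short-circuit pattern this relies on; the names below are illustrative stand-ins, not ExecuTorch APIs:

```python
# Sketch only: the cheap type guard runs first, so the rank check (which
# assumes the value has a .shape) is never reached for non-tensor values.
class Shaped:
    def __init__(self, *shape: int) -> None:
        self.shape = shape

def is_tensor_like(val: object) -> bool:
    return hasattr(val, "shape")

def is_high_dim(val: object) -> bool:
    # Would raise AttributeError if called unguarded on a non-tensor value.
    return len(val.shape) > 4

for val in (Shaped(2, 3, 4, 5, 6), Shaped(2, 3), 42):
    # `and` short-circuits: is_high_dim(42) is never evaluated.
    verdict = "rejected" if is_tensor_like(val) and is_high_dim(val) else "kept"
    print(val, verdict)
```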

backends/vulkan/utils.py

Lines changed: 10 additions & 14 deletions

```diff
@@ -132,20 +132,16 @@ def within_buffer_limit(node: torch.fx.Node, buffer_limit: int) -> int:
 
 
 def tensor_node_is_high_dim(node: torch.fx.Node) -> bool:
     """
-    If the node does not contain a tensor or a collection of tensors, return False.
-    Otherwise, return True if the tensor is high dimensional (i.e. rank > 4).
-    """
-    if is_tensor_node(node):
-        if isinstance(node.meta["val"], FakeTensor):
-            return len(node.meta["val"].shape) > 4
-        if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
-            for fake_tensor in node.meta["val"]:
-                if isinstance(fake_tensor, FakeTensor):
-                    if len(fake_tensor.shape) > 4:
-                        return True
-        return False
-    else:
-        return False
+    Returns true if a given node contains a tensor with more than 4 dimensions
+    """
+    if isinstance(node.meta["val"], FakeTensor):
+        return len(node.meta["val"].shape) > 4
+    if isinstance(node.meta["val"], list) or isinstance(node.meta["val"], tuple):
+        for fake_tensor in node.meta["val"]:
+            if isinstance(fake_tensor, FakeTensor):
+                if len(fake_tensor.shape) > 4:
+                    return True
+    return False
 
 
 def required_image_extents(sizes: torch.Size, layout: VkMemoryLayout) -> ImageExtents:
```
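To see how the simplified helper and the new partitioner guard divide the work, here is a hedged, self-contained sketch; `StubNode` and `StubTensor` are hypothetical stand-ins for `torch.fx.Node` and `FakeTensor` so the example runs without an FX trace:

```python
# Sketch only: StubNode/StubTensor are illustrative stand-ins, not ExecuTorch
# or PyTorch APIs. The two functions mirror the post-diff division of labor:
# is_tensor_node does the type screening, tensor_node_is_high_dim only checks rank.
from typing import Any

class StubTensor:
    def __init__(self, *shape: int) -> None:
        self.shape = shape

class StubNode:
    def __init__(self, val: Any) -> None:
        self.meta = {"val": val}

def is_tensor_node(node: StubNode) -> bool:
    # Tensor node: "val" is a tensor or a list/tuple of tensors.
    val = node.meta["val"]
    if isinstance(val, StubTensor):
        return True
    if isinstance(val, (list, tuple)):
        return all(isinstance(v, StubTensor) for v in val)
    return False

def tensor_node_is_high_dim(node: StubNode) -> bool:
    # Mirrors the simplified helper: True if any contained tensor has rank > 4.
    val = node.meta["val"]
    if isinstance(val, StubTensor):
        return len(val.shape) > 4
    if isinstance(val, (list, tuple)):
        for t in val:
            if isinstance(t, StubTensor) and len(t.shape) > 4:
                return True
    return False

nodes = {
    "rank4": StubNode(StubTensor(2, 3, 4, 5)),
    "rank5": StubNode(StubTensor(2, 3, 4, 5, 6)),
    "list_with_rank5": StubNode([StubTensor(2, 3), StubTensor(1, 2, 3, 4, 5)]),
    "non_tensor": StubNode(42),  # e.g. a node carrying a SymInt-like value
}
for name, node in nodes.items():
    rejected = is_tensor_node(node) and tensor_node_is_high_dim(node)
    print(name, "rejected" if rejected else "kept")
```

The non-tensor case is exactly what the partitioner fix addresses: the guard short-circuits before the rank check ever inspects `meta["val"]` as a tensor.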
