Skip to content

Commit f3d5b32

Browse files
committed
[ET-VK] Don't specify memory layouts when testing
As title. Now that we have a memory metadata tagging pass that automatically determines the optimal memory layout to use for operators, there is no need to specify what memory layout to test in the Python export tests. There were some issues with the memory metadata tagging pass when dealing with nodes that contain tensor lists, which have been fixed as part of this diff as well. Differential Revision: [D67180897](https://our.internmc.facebook.com/intern/diff/D67180897/) ghstack-source-id: 258028337 Pull Request resolved: #7322
1 parent 0b1c1e5 commit f3d5b32

File tree

3 files changed

+43
-86
lines changed

3 files changed

+43
-86
lines changed

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,6 @@
2323

2424
from executorch.exir.pass_base import ExportPass, PassResult
2525

26-
from torch._subclasses.fake_tensor import FakeTensor
27-
2826
from torch.fx.passes.tools_common import NodeList
2927
from torch.fx.passes.utils.fuser_utils import topo_sort
3028

@@ -138,9 +136,7 @@ def propose_node_storage(
138136
return storage
139137

140138
for arg in node.args:
141-
if isinstance(arg, torch.fx.Node) and isinstance(
142-
arg.meta["val"], FakeTensor
143-
):
139+
if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
144140
storage = utils.get_node_storage_type(arg)
145141
if storage is not None and storage in valid_storage_types:
146142
return storage
@@ -178,9 +174,7 @@ def propose_node_layout(
178174
return layout
179175

180176
for arg in node.args:
181-
if isinstance(arg, torch.fx.Node) and isinstance(
182-
arg.meta["val"], FakeTensor
183-
):
177+
if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
184178
layout = utils.get_node_memory_layout(arg)
185179
if layout is not None and layout in valid_layouts:
186180
return layout
@@ -202,14 +196,19 @@ def should_annotate(self, node) -> bool:
202196
if not isinstance(node, torch.fx.Node):
203197
return False
204198

205-
if not isinstance(node.meta["val"], FakeTensor):
199+
if not utils.is_tensor_node(node):
206200
return False
207201

208202
# Storage type and memory layout for tensorref will be determined at runtime
209203
# so there's no use in setting those attributes ahead of time.
210204
if node.meta.get("vkdg_tensorref", False):
211205
return False
212206

207+
# Skip annotating output node. The output tensors should be annotated by the
208+
# time the output node is observed.
209+
if node.op == "output":
210+
return False
211+
213212
return True
214213

215214
def should_delay_annotation(self, node: torch.fx.Node) -> bool:

0 commit comments

Comments
 (0)