Commit d5a0743
[ET-VK] Update partitioner to account for custom packed arguments
Differential Revision: D65759236
Pull Request resolved: #6753
1 parent: 793f17e

Summary: Ops such as convolution prepack their weight and bias arguments into custom-packed textures, so the generic texture limits check does not apply to those arguments. This commit annotates tensorref nodes when inserting prepack nodes, refactors TagMemoryMetaPass around new should_annotate/should_delay_annotation helpers, and adds a skip_limits_check field to OpFeatures that the partitioner consults to exempt specific argument indices from the limits check.

File tree: 4 files changed (+73, -23 lines)

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -35,6 +35,10 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
         if not is_param_node(program, node):
             return True
 
+        # Annotate that this node is going to be represented as a tensorref in the
+        # Vulkan compute graph. This will be useful for later graph passes.
+        node.meta["vkdg_tensorref"] = True
+
         for user in node.users:
             if user.op == "call_function" and handles_own_prepacking(
                 # pyre-ignore
```
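For context, here is a minimal sketch (not part of this commit) of how a later graph pass can consume the annotation; the `should_annotate` check added to `tag_memory_meta_pass.py` below does exactly this:

```python
import torch.fx


def is_tensorref(node: torch.fx.Node) -> bool:
    # Nodes that insert_prepack_nodes never visited have no "vkdg_tensorref"
    # key, so read it with a default instead of indexing node.meta directly.
    return node.meta.get("vkdg_tensorref", False)
```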

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 52 additions & 21 deletions

```diff
@@ -39,6 +39,30 @@ def set_memory_metadata(
     utils.set_node_spec_attr(node, "vk_memory_layout", layout)
 
 
+def insert_transition_node(
+    graph_module: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    arg: torch.fx.Node,
+    storage: VkStorageType,
+    layout: VkMemoryLayout,
+) -> None:
+    """
+    Insert a clone node to copy the original tensor to a tensor with the desired storage
+    type and memory layout.
+    """
+    with graph_module.graph.inserting_before(node):
+        clone_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.aten.clone.default,
+            (arg,),
+        )
+        clone_node.meta["val"] = arg.meta["val"]
+        clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
+        clone_node.meta["spec"].const = False
+        set_memory_metadata(clone_node, storage, layout)
+        arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
+
+
 class TagMemoryMetaPass(ExportPass):
     """
     There are a variety of ways that tensors can be represented in Vulkan. The two main
```
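A note on the last line of the new function: the lambda is passed as `torch.fx`'s `delete_user_cb` filter, so only uses inside `node` are rerouted to the clone; every other consumer keeps reading the original tensor. A self-contained sketch of the same pattern on a toy graph (the function `f` and node names are illustrative, not from the repo):

```python
import operator

import torch
import torch.fx as fx


def f(x):
    a = x + 1
    return a * 2, a - 3  # two users of `a`


gm = fx.symbolic_trace(f)
add_node = next(n for n in gm.graph.nodes if n.target == operator.add)
mul_node = next(n for n in gm.graph.nodes if n.target == operator.mul)

# Insert a clone just before the consumer we want to redirect, then reroute
# only that consumer's use; the subtraction still reads the original node.
with gm.graph.inserting_before(mul_node):
    clone = gm.graph.call_function(torch.clone, (add_node,))
add_node.replace_all_uses_with(clone, delete_user_cb=lambda user: user is mul_node)
gm.recompile()
```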
```diff
@@ -174,14 +198,33 @@ def propose_node_layout(
         else:
             return next(iter(valid_layouts))
 
+    def should_annotate(self, node) -> bool:
+        if not isinstance(node, torch.fx.Node):
+            return False
+
+        if not isinstance(node.meta["val"], FakeTensor):
+            return False
+
+        # Storage type and memory layout for tensorref will be determined at runtime
+        # so there's no use in setting those attributes ahead of time.
+        if node.meta.get("vkdg_tensorref", False):
+            return False
+
+        return True
+
+    def should_delay_annotation(self, node: torch.fx.Node) -> bool:
+        # For prepack nodes, delay setting the storage type and memory layout as long as
+        # possible. This is to minimize the number of transitions, since it can be
+        # difficult to predict what storage type and memory layout should be used at the
+        # time the prepack node is observed.
+        return node.target == exir_ops.edge.et_vk.prepack.default
+
     # noqa
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
 
         for node in sorted_nodes:
-            if not isinstance(node.meta["val"], FakeTensor):
-                continue
-
-            if node.target == exir_ops.edge.et_vk.prepack.default:
+            if not self.should_annotate(node) or self.should_delay_annotation(node):
                 continue
 
             storage = self.propose_node_storage(node)
```
```diff
@@ -191,11 +234,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
             inserting_transitions_for_node = False
             for i, arg in enumerate(node.args):
-                if not isinstance(arg, torch.fx.Node):
-                    continue
-                if not isinstance(arg.meta["val"], FakeTensor):
+                if not self.should_annotate(arg):
                     continue
 
+                assert isinstance(arg, torch.fx.Node)
+
                 arg_storage = utils.get_node_storage_type(arg)
                 arg_layout = utils.get_node_memory_layout(arg)
```
```diff
@@ -215,22 +258,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                         f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
                     )
 
+                insert_transition_node(graph_module, node, arg, storage, layout)
+
                 logger.info(
                     f"  args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
                 )
 
-                # Insert a clone node to copy the original tensor to a tensor with the
-                # desired storage type and memory layout.
-                with graph_module.graph.inserting_before(node):
-                    clone_node = graph_module.graph.create_node(
-                        "call_function",
-                        exir_ops.edge.aten.clone.default,
-                        (arg,),
-                    )
-                    clone_node.meta["val"] = arg.meta["val"]
-                    clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
-                    clone_node.meta["spec"].const = False
-                    set_memory_metadata(clone_node, storage, layout)
-                    arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
-
         return PassResult(graph_module, True)
```
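As a usage note, `call` follows the standard `torch.fx` pass contract: mutate the graph in place, then return a `PassResult` whose second field reports whether the module was modified. A minimal standalone sketch of that contract (the pass body here is illustrative):

```python
import torch.fx as fx
from torch.fx.passes.infra.pass_base import PassResult


def annotate_all(gm: fx.GraphModule) -> PassResult:
    for node in gm.graph.nodes:
        # A real pass computes something per node; TagMemoryMetaPass stores
        # the chosen storage type and memory layout on each node's spec.
        node.meta["visited"] = True
    return PassResult(gm, True)  # True = the module was modified
```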

backends/vulkan/op_registry.py

Lines changed: 10 additions & 0 deletions

```diff
@@ -90,6 +90,9 @@ class OpFeatures:
         # then the insert_prepack_nodes pass will not insert prepack nodes for the args
         # of the op.
         "handles_own_prepacking",
+        # Optional set of argument indices for which the partitioner should skip the
+        # texture limits check, e.g. because the op packs those args itself.
+        "skip_limits_check",
         # Optional check function used during partitioning to determine if a node's
         # inputs are supported by the operator implementation.
         "check_node_fn",
```
```diff
@@ -103,6 +106,7 @@ def __init__(
         optimal_storage: Optional[VkStorageType] = None,
         optimal_layout: Optional[VkMemoryLayout] = None,
         handles_own_prepacking: bool = False,
+        skip_limits_check: Optional[Set[int]] = None,
         check_node_fn: Optional[Callable] = None,
     ):
         self.texture_impl: Optional[TextureImplFeatures] = texture_impl
```
```diff
@@ -111,6 +115,11 @@ def __init__(
         self.optimal_storage: Optional[VkStorageType] = optimal_storage
         self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
         self.handles_own_prepacking: bool = handles_own_prepacking
+
+        self.skip_limits_check: Set[int] = set()
+        if skip_limits_check is not None:
+            self.skip_limits_check = skip_limits_check
+
         self.check_node_fn: Callable = allow_node
         if check_node_fn is not None:
             self.check_node_fn = check_node_fn
```
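The `Optional[Set[int]] = None` parameter with an assign-in-body default is the usual way to sidestep Python's shared-mutable-default pitfall, which a `skip_limits_check: Set[int] = set()` default would hit. A quick illustration (toy function, not from the repo):

```python
def bad(indices=set()):
    # The default set is created once, at function definition time, and is
    # shared by every call that omits the argument.
    indices.add(1)
    return indices


assert bad() is bad()  # both calls mutate and return the very same set
```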
```diff
@@ -433,6 +442,7 @@ def register_convolution_op(features: OpFeatures):
     features.optimal_storage = VkStorageType.TEXTURE_3D
     features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
     features.handles_own_prepacking = True
+    features.skip_limits_check = {1, 2}
     return features
```
backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 7 additions & 2 deletions

```diff
@@ -82,8 +82,13 @@ def op_node_is_compatible(
         valid_texture_layouts = utils.possible_node_memory_layouts(
             node, self.texture_limits
         )
-        for arg in node.args:
-            if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
+
+        for i, arg in enumerate(node.args):
+            if (
+                isinstance(arg, torch.fx.Node)
+                and utils.is_tensor_node(arg)
+                and i not in features.skip_limits_check
+            ):
                 arg_texture_layouts = utils.possible_node_memory_layouts(
                     arg, self.texture_limits
                 )
```
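Distilled, the compatibility rule is now "every tensor argument must fit within the texture limits, unless its index is exempted by the op's `skip_limits_check`". Below is a self-contained sketch of that rule; `fits_texture` is only a stand-in for `utils.possible_node_memory_layouts`, and all shapes and limits are made up for illustration:

```python
TEXTURE_LIMITS = (2048, 2048, 2048)  # hypothetical max image extents


def fits_texture(shape, limits=TEXTURE_LIMITS):
    # Stand-in check: pad to three dims and compare against the extents.
    padded = (1, 1, 1) + tuple(shape)
    return all(dim <= limit for dim, limit in zip(padded[-3:], limits))


def args_compatible(arg_shapes, skip_limits_check=frozenset()):
    return all(
        fits_texture(shape)
        for i, shape in enumerate(arg_shapes)
        if i not in skip_limits_check
    )


# A conv-like node: the oversized weight at index 1 fails the generic check,
# but is exempt once the op declares skip_limits_check = {1, 2}.
shapes = [(1, 8, 128, 128), (8, 4096, 3, 3), (8,)]
assert not args_compatible(shapes)
assert args_compatible(shapes, skip_limits_check={1, 2})
```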
