
Commit 6887ae9

pytorchbot and SS-JIA authored
[ET-VK] Update partitioner to account for custom packed arguments (#6763)
## Problem

Convolution operators, especially pointwise convolutions, may have weight tensors with sizes like

```
W=1, H=1, C=320, N=1280
```

When represented as a texture, this tensor would normally require a texture with extents

```
(1, 1, 320 / 4 * 1280 = 102400)
```

which would exceed typical texture limits. The new partitioner system detects this and prevents nodes with such weights from being lowered to Vulkan. However, the partitioner does not account for the fact that the operator implementation uses a specialized prepacking algorithm that produces packed weights which do fit within texture limits.

## Changes

* Add a field to the `OpFeatures` class to annotate that some arguments of an op should be skipped when checking against texture limits.
* Update the metadata tagging pass to skip annotating constant tensor nodes so that they don't influence memory layout and storage type proposals. Without this change, the tagging pass would try to use buffer storage for the pointwise convolution, since the weight can only be represented as a buffer under normal circumstances.

Differential Revision: [D65759236](https://our.internmc.facebook.com/intern/diff/D65759236/)

ghstack-source-id: 252885980

Pull Request resolved: #6753

Co-authored-by: Stephen Jia <[email protected]>
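To make the extents arithmetic in the Problem section concrete, here is a minimal sketch (not part of this commit) of how channels-packed texture extents relate to tensor sizes; `channels_packed_extents` is a hypothetical helper, and the device limits mentioned in the comments are typical values rather than figures from this PR:

```python
# Rough sketch: channels-packed texture extents for an (N, C, H, W) tensor.
# Channels are folded into texels of 4, and batches are stacked along the
# depth axis, matching the arithmetic in the commit message.
def channels_packed_extents(N: int, C: int, H: int, W: int) -> tuple:
    depth = ((C + 3) // 4) * N  # ceil(C / 4) texels per image, stacked over N
    return (W, H, depth)

# The pointwise convolution weight from the example above:
print(channels_packed_extents(N=1280, C=320, H=1, W=1))  # (1, 1, 102400)

# Mobile GPUs commonly cap 3D image extents at values like 2048 or 16384, so
# a depth of 102400 cannot be represented as an ordinary channels-packed
# texture.
```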
1 parent b8b5146 · commit 6887ae9

File tree: 4 files changed (+73, −23 lines)

backends/vulkan/_passes/insert_prepack_nodes.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -35,6 +35,10 @@ def prepack_not_required(node: torch.fx.Node) -> bool:
         if not is_param_node(program, node):
             return True
 
+        # Annotate that this node is going to be represented as a tensorref in the
+        # Vulkan compute graph. This will be useful for later graph passes.
+        node.meta["vkdg_tensorref"] = True
+
         for user in node.users:
             if user.op == "call_function" and handles_own_prepacking(
                 # pyre-ignore
```
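Later passes can then key off this annotation when deciding how to treat a node. As a minimal illustration (a hypothetical helper, mirroring the check that `tag_memory_meta_pass.py` performs below):

```python
import torch.fx


def is_tensorref(node: torch.fx.Node) -> bool:
    # Constant/parameter nodes tagged by insert_prepack_nodes are represented
    # as tensorrefs in the Vulkan compute graph, so they carry no storage type
    # or memory layout of their own.
    return node.meta.get("vkdg_tensorref", False)
```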

backends/vulkan/_passes/tag_memory_meta_pass.py

Lines changed: 52 additions & 21 deletions
```diff
@@ -39,6 +39,30 @@ def set_memory_metadata(
     utils.set_node_spec_attr(node, "vk_memory_layout", layout)
 
 
+def insert_transition_node(
+    graph_module: torch.fx.GraphModule,
+    node: torch.fx.Node,
+    arg: torch.fx.Node,
+    storage: VkStorageType,
+    layout: VkMemoryLayout,
+) -> None:
+    """
+    Insert a clone node to copy the original tensor to a tensor with the desired storage
+    type and memory layout.
+    """
+    with graph_module.graph.inserting_before(node):
+        clone_node = graph_module.graph.create_node(
+            "call_function",
+            exir_ops.edge.aten.clone.default,
+            (arg,),
+        )
+        clone_node.meta["val"] = arg.meta["val"]
+        clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
+        clone_node.meta["spec"].const = False
+        set_memory_metadata(clone_node, storage, layout)
+        arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
+
+
 class TagMemoryMetaPass(ExportPass):
     """
     There are a variety of ways that tensors can be represented in Vulkan. The two main
@@ -174,14 +198,33 @@ def propose_node_layout(
         else:
             return next(iter(valid_layouts))
 
+    def should_annotate(self, node) -> bool:
+        if not isinstance(node, torch.fx.Node):
+            return False
+
+        if not isinstance(node.meta["val"], FakeTensor):
+            return False
+
+        # Storage type and memory layout for tensorref will be determined at runtime
+        # so there's no use in setting those attributes ahead of time.
+        if node.meta.get("vkdg_tensorref", False):
+            return False
+
+        return True
+
+    def should_delay_annotation(self, node: torch.fx.Node) -> bool:
+        # For prepack nodes, delay setting the storage type and memory layout as long as
+        # possible. This is to minimize the number of transitions, since it can be
+        # difficult to predict what storage type and memory layout should be used at the
+        # time the prepack node is observed.
+        return node.target == exir_ops.edge.et_vk.prepack.default
+
+    # noqa
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         sorted_nodes: NodeList = topo_sort(list(graph_module.graph.nodes))
 
         for node in sorted_nodes:
-            if not isinstance(node.meta["val"], FakeTensor):
-                continue
-
-            if node.target == exir_ops.edge.et_vk.prepack.default:
+            if not self.should_annotate(node) or self.should_delay_annotation(node):
                 continue
 
             storage = self.propose_node_storage(node)
@@ -191,11 +234,11 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 
             inserting_transitions_for_node = False
             for i, arg in enumerate(node.args):
-                if not isinstance(arg, torch.fx.Node):
-                    continue
-                if not isinstance(arg.meta["val"], FakeTensor):
+                if not self.should_annotate(arg):
                     continue
 
+                assert isinstance(arg, torch.fx.Node)
+
                 arg_storage = utils.get_node_storage_type(arg)
                 arg_layout = utils.get_node_memory_layout(arg)
 
@@ -215,22 +258,10 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
                     f"[Vulkan Delegate] Inserting transition(s) for {node.format_node()}:"
                 )
 
+                insert_transition_node(graph_module, node, arg, storage, layout)
+
                 logger.info(
                     f"  args {i} ({arg}): ({arg_storage}, {arg_layout}) -> ({storage}, {layout})"
                 )
 
-                # Insert a clone node to copy the original tensor to a tensor with the
-                # desired storage type and memory layout.
-                with graph_module.graph.inserting_before(node):
-                    clone_node = graph_module.graph.create_node(
-                        "call_function",
-                        exir_ops.edge.aten.clone.default,
-                        (arg,),
-                    )
-                    clone_node.meta["val"] = arg.meta["val"]
-                    clone_node.meta["spec"] = deepcopy(arg.meta["spec"])
-                    clone_node.meta["spec"].const = False
-                    set_memory_metadata(clone_node, storage, layout)
-                    arg.replace_all_uses_with(clone_node, lambda x, y=node: x == y)
-
         return PassResult(graph_module, True)
```
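The `replace_all_uses_with` call in `insert_transition_node` passes a predicate so that only the consuming node being processed is rerouted through the clone; other users keep reading the original tensor. A self-contained sketch of that pattern on a toy FX graph (the function and node names here are illustrative, not from the commit):

```python
import torch
import torch.fx


def f(x):
    return torch.relu(x) + torch.sigmoid(x)


gm = torch.fx.symbolic_trace(f)
graph = gm.graph
x_node = next(n for n in graph.nodes if n.op == "placeholder")
relu_node = next(n for n in graph.nodes if n.target == torch.relu)

# Insert a clone before relu and reroute only relu's use of x. The predicate
# keeps sigmoid (and the clone itself) reading the original tensor -- the
# same selective rerouting insert_transition_node performs with its lambda.
with graph.inserting_before(relu_node):
    clone_node = graph.create_node("call_function", torch.clone, (x_node,))
x_node.replace_all_uses_with(clone_node, lambda user, y=relu_node: user == y)

graph.lint()
gm.recompile()
print(gm.code)  # relu now consumes the clone; sigmoid still consumes x
```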

backends/vulkan/op_registry.py

Lines changed: 10 additions & 0 deletions
```diff
@@ -90,6 +90,9 @@ class OpFeatures:
         # then the insert_prepack_nodes pass will not insert prepack nodes for the args
         # of the op.
         "handles_own_prepacking",
+        # Optional set of argument indices that should be skipped when checking if the
+        # argument's tensor fits within texture limits.
+        "skip_limits_check",
         # Optional check function used during partitioning to determine if a node's
         # inputs are supported by the operator implementation.
         "check_node_fn",
@@ -103,6 +106,7 @@ def __init__(
         optimal_storage: Optional[VkStorageType] = None,
         optimal_layout: Optional[VkMemoryLayout] = None,
         handles_own_prepacking: bool = False,
+        skip_limits_check: Optional[Set[int]] = None,
         check_node_fn: Optional[Callable] = None,
     ):
         self.texture_impl: Optional[TextureImplFeatures] = texture_impl
@@ -111,6 +115,11 @@ def __init__(
         self.optimal_storage: Optional[VkStorageType] = optimal_storage
         self.optimal_layout: Optional[VkMemoryLayout] = optimal_layout
         self.handles_own_prepacking: bool = handles_own_prepacking
+
+        self.skip_limits_check: Set[int] = set()
+        if skip_limits_check is not None:
+            self.skip_limits_check = skip_limits_check
+
         self.check_node_fn: Callable = allow_node
         if check_node_fn is not None:
             self.check_node_fn = check_node_fn
@@ -433,6 +442,7 @@ def register_convolution_op(features: OpFeatures):
     features.optimal_storage = VkStorageType.TEXTURE_3D
     features.optimal_layout = VkMemoryLayout.TENSOR_CHANNELS_PACKED
     features.handles_own_prepacking = True
+    features.skip_limits_check = {1, 2}
     return features
```
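As a usage sketch of the constructor above (with a hypothetical operator whose implementation prepacks its second argument), registration code could opt that argument out of the limits check like so:

```python
# Hypothetical registration: argument index 1 is prepacked by the operator's
# own implementation, so its unpacked size is not checked against texture
# limits during partitioning.
features = OpFeatures(
    handles_own_prepacking=True,
    skip_limits_check={1},
)
```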

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 7 additions & 2 deletions
```diff
@@ -82,8 +82,13 @@ def op_node_is_compatible(
         valid_texture_layouts = utils.possible_node_memory_layouts(
             node, self.texture_limits
         )
-        for arg in node.args:
-            if isinstance(arg, torch.fx.Node) and utils.is_tensor_node(arg):
+
+        for i, arg in enumerate(node.args):
+            if (
+                isinstance(arg, torch.fx.Node)
+                and utils.is_tensor_node(arg)
+                and i not in features.skip_limits_check
+            ):
                 arg_texture_layouts = utils.possible_node_memory_layouts(
                     arg, self.texture_limits
                 )
```
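For `aten.convolution`, the argument order begins `(input, weight, bias, ...)`, so `skip_limits_check = {1, 2}` means only the input activation is validated against texture limits, while the weight and bias, which the implementation prepacks into validly sized textures, are skipped. A small illustrative walk of the loop's effect (toy data, not from the source):

```python
# Toy walk over a convolution node's argument indices showing which ones the
# partitioner now checks. In the real code, non-tensor args (stride, padding,
# etc.) are already filtered out by utils.is_tensor_node.
skip_limits_check = {1, 2}  # weight and bias, per register_convolution_op
for i, name in enumerate(["input", "weight", "bias"]):
    status = "skipped" if i in skip_limits_check else "checked"
    print(f"arg {i} ({name}): {status}")
# arg 0 (input): checked
# arg 1 (weight): skipped
# arg 2 (bias): skipped
```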
