Skip to content

Commit 5670a80

Browse files
nathanaelsee authored and facebook-github-bot committed
check to_copy args in vulkan_partitioner
Summary: In the exir dialect, to_copy doesn't have a dtype arg; the dtype is inferred from the dtype of the output tensor. The args will be of length 1, with the sole arg being the input tensor. Thus the previous check always returned False, since len(args) is never > 1. Differential Revision: D64267104
1 parent bff26f3 commit 5670a80

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

backends/vulkan/partitioner/vulkan_partitioner.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,24 @@ def is_in_local_scalar_dense_chain(self, node: torch.fx.Node) -> bool:
144144

145145
return False
146146

147-
def is_valid_to_copy(self, node: torch.fx.node) -> bool: # pyre-ignore[11]
148-
# lower only if floating point dtype conversion
149-
return len(node.args) > 1 and node.args[1] in (torch.float32, torch.float16)
147+
def is_valid_to_copy(self, node: torch.fx.Node) -> bool:
    """Return True if a ``_to_copy`` node is a float-to-float dtype cast.

    In the exir dialect ``to_copy`` carries no explicit dtype argument: the
    node has exactly one arg (the input tensor) and the target dtype is
    inferred from the output tensor's ``meta["val"]``. Only fp16/fp32 <->
    fp16/fp32 conversions are lowered to Vulkan.
    """
    # Anything other than a single fx.Node input is not the exir-dialect
    # to_copy shape we know how to lower.
    if len(node.args) != 1 or not isinstance(node.args[0], torch.fx.Node):
        return False

    # Dtypes live on the FakeTensor values recorded during tracing; if
    # either side is missing its meta, conservatively reject the node.
    src_val = node.args[0].meta.get("val", None)
    dst_val = node.meta.get("val", None)
    if not (isinstance(src_val, FakeTensor) and isinstance(dst_val, FakeTensor)):
        return False

    # Lower only floating-point -> floating-point conversions.
    float_dtypes = (torch.float16, torch.float32)
    return src_val.dtype in float_dtypes and dst_val.dtype in float_dtypes
150165

151166
def is_node_supported(
152167
self, submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node
@@ -174,13 +189,13 @@ def _is_node_supported(
174189
if target not in VulkanSupportedOperators._ops:
175190
return False
176191

177-
features = VulkanSupportedOperators._ops[target]
178-
179192
if target == exir_ops.edge.aten._to_copy.default and not self.is_valid_to_copy(
180193
node
181194
):
182195
return False
183196

197+
features = VulkanSupportedOperators._ops[target]
198+
184199
if self.require_dynamic_shapes and not features.supports_dynamic_shape:
185200
return False
186201

0 commit comments

Comments (0)