Skip to content

Commit b317063

Browse files
dulinrileyfacebook-github-bot
authored andcommitted
Fix list assumption about permute arguments
Summary: Permute nodes can have either a list or tuple as the second argument, and the debug handle may or may not exist on dq nodes. Make these passes a bit more robust in what they accept. Reviewed By: Vysarat Differential Revision: D67765363
1 parent 3ef78ee commit b317063

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

backends/cadence/aot/compiler_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def get_transposed_dims(node: torch.fx.Node, dims: List[int]) -> List[int]:
129129

130130

131131
# Capture the effect of permute op on incoming dimension order
132-
def get_permuted_dims(node: torch.fx.Node, dims: Optional[List[int]]) -> List[int]:
132+
def get_permuted_dims(node: torch.fx.Node, dims: Optional[Sequence[int]]) -> List[int]:
133133
"""
134134
Given a permute node, and the incoming dimension ordering of the input
135135
tensor to the permute node, return the net effect of permute op on the
@@ -138,7 +138,7 @@ def get_permuted_dims(node: torch.fx.Node, dims: Optional[List[int]]) -> List[in
138138
assert node.target == exir_ops.edge.aten.permute_copy.default
139139
# Permute each index of the dimension ordering (dims)
140140
permute_dims = node.args[1]
141-
assert isinstance(permute_dims, List)
141+
assert isinstance(permute_dims, (tuple, list))
142142
assert all(isinstance(x, int) for x in permute_dims)
143143
# If the dims is empty, we can simply return the permute order
144144
if not dims:

backends/cadence/aot/reorder_ops.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,9 +438,9 @@ def postpone_dequantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
438438
args=(user, *node.args[1:]),
439439
)
440440
dequant_node.meta = user.meta.copy()
441-
# Remove meta["debug_handle"] on new node. Reassign it at the
442-
# caller level by calling generate_missing_debug_handles
443-
dequant_node.meta.pop("debug_handle")
441+
# Remove meta["debug_handle"] on new node if it exists.
442+
# Reassign it at the caller level by calling generate_missing_debug_handles
443+
dequant_node.meta.pop("debug_handle", None)
444444
user.replace_all_uses_with(dequant_node)
445445
dequant_node.args = (user, *node.args[1:])
446446

0 commit comments

Comments
 (0)