Skip to content

Commit e5dc18a

Browse files
authored
Fix unsqueeze optimize pass
Differential Revision: D69812661 Pull Request resolved: #8564
1 parent b6ffe1a commit e5dc18a

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

backends/transforms/view_copy_to_squeeze_unsqueeze.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ def find_unsqueeze_dim(
7575
j = 0
7676
idx = -1
7777
while j < len(view_shape):
78-
if input_shape[i] != view_shape[j]:
78+
# account for added dim being last dim in view_shape
79+
if i == j and j == len(input_shape):
80+
if view_shape[j] != 1:
81+
return None
82+
elif input_shape[i] != view_shape[j]:
7983
if view_shape[j] == 1:
8084
idx = j
8185
i -= 1

0 commit comments

Comments
 (0)