Skip to content

Commit 85231f5

Browse files
authored
fix index error bug in ViewCopyToSqueezeUnsqueezePass
Differential Revision: D69477266 Pull Request resolved: #8386
1 parent f438da8 commit 85231f5

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

backends/transforms/view_copy_to_squeeze_unsqueeze.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -46,16 +46,19 @@ def find_squeeze_dims(
4646
i = 0
4747
j = 0
4848
idx = []
49-
while i < len(input_shape):
50-
if input_shape[i] != view_shape[j]:
51-
if input_shape[i] == 1:
52-
idx.append(i)
53-
j -= 1
54-
# continue to check remaining dims are equal
55-
else:
56-
return None
57-
i += 1
58-
j += 1
49+
while i < len(input_shape) and j < len(view_shape):
50+
if input_shape[i] == view_shape[j]:
51+
i += 1
52+
j += 1
53+
elif input_shape[i] == 1:
54+
# squeeze axis on i and check next dim
55+
idx.append(i)
56+
i += 1
57+
else:
58+
return None
59+
# If there are remaining dimensions, shapes do not match
60+
if i < len(input_shape) or j < len(view_shape):
61+
return None
5962
return idx
6063

6164
def find_unsqueeze_dim(

0 commit comments

Comments
 (0)