Skip to content

Commit 24483d2

Browse files
author
Nathanael See
committed
[ET-VK] fix index error bug in ViewCopyToSqueezeUnsqueezePass
See T214560872 #8226 added the pass to the partition preprocess pass list, so now it runs on all exports. This uncovered a bug in the squeeze dims finding function in the mobilenet test case. Differential Revision: [D69254910](https://our.internmc.facebook.com/intern/diff/D69254910/) [ghstack-poisoned]
1 parent dad2ba0 commit 24483d2

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

backends/transforms/view_copy_to_squeeze_unsqueeze.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,18 @@ def find_squeeze_dims(
4747
j = 0
4848
idx = []
4949
while i < len(input_shape):
50-
if input_shape[i] != view_shape[j]:
51-
if input_shape[i] == 1:
52-
idx.append(i)
53-
j -= 1
54-
# continue to check remaining dims are equal
55-
else:
56-
return None
57-
i += 1
58-
j += 1
50+
if j < len(view_shape) and input_shape[i] == view_shape[j]:
51+
i += 1
52+
j += 1
53+
elif input_shape[i] == 1:
54+
# squeeze axis on i and check next dim
55+
idx.append(i)
56+
i += 1
57+
else:
58+
return None
59+
# If there are remaining dimensions in view_shape, shapes do not match
60+
if j < len(view_shape):
61+
return None
5962
return idx
6063

6164
def find_unsqueeze_dim(

0 commit comments

Comments
 (0)