We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f438da8 commit 85231f5Copy full SHA for 85231f5
backends/transforms/view_copy_to_squeeze_unsqueeze.py
@@ -46,16 +46,19 @@ def find_squeeze_dims(
46
i = 0
47
j = 0
48
idx = []
49
- while i < len(input_shape):
50
- if input_shape[i] != view_shape[j]:
51
- if input_shape[i] == 1:
52
- idx.append(i)
53
- j -= 1
54
- # continue to check remaining dims are equal
55
- else:
56
- return None
57
- i += 1
58
- j += 1
+ while i < len(input_shape) and j < len(view_shape):
+ if input_shape[i] == view_shape[j]:
+ i += 1
+ j += 1
+ elif input_shape[i] == 1:
+ # squeeze axis on i and check next dim
+ idx.append(i)
+ else:
+ return None
59
+ # If there are remaining dimensions, shapes do not match
60
+ if i < len(input_shape) or j < len(view_shape):
61
62
return idx
63
64
def find_unsqueeze_dim(
0 commit comments