Skip to content

Commit b17589e

Browse files
Merge pull request #1047 from IntelPython/chore-factor-out-broadcast_impl
CHORE: Factored out _broadcast_shape implementation
2 parents 949711e + f9413c5 commit b17589e

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,18 +48,14 @@ def _broadcast_strides(X_shape, X_strides, res_ndim):
4848
return tuple(out_strides)
4949

5050

51-
def _broadcast_shapes(*args):
52-
"""
53-
Broadcast the input shapes into a single shape;
54-
returns tuple broadcasted shape.
55-
"""
56-
shapes = [array.shape for array in args]
51+
def _broadcast_shape_impl(shapes):
5752
if len(set(shapes)) == 1:
5853
return shapes[0]
5954
mutable_shapes = False
6055
nds = [len(s) for s in shapes]
6156
biggest = max(nds)
62-
for i in range(len(args)):
57+
sh_len = len(shapes)
58+
for i in range(sh_len):
6359
diff = biggest - nds[i]
6460
if diff > 0:
6561
ty = type(shapes[i])
@@ -77,7 +73,7 @@ def _broadcast_shapes(*args):
7773
unique.remove(1)
7874
new_length = unique.pop()
7975
common_shape.append(new_length)
80-
for i in range(len(args)):
76+
for i in range(sh_len):
8177
if shapes[i][axis] == 1:
8278
if not mutable_shapes:
8379
shapes = [list(s) for s in shapes]
@@ -89,6 +85,15 @@ def _broadcast_shapes(*args):
8985
return tuple(common_shape)
9086

9187

88+
def _broadcast_shapes(*args):
89+
"""
90+
Broadcast the input shapes into a single shape;
91+
returns tuple broadcasted shape.
92+
"""
93+
array_shapes = [array.shape for array in args]
94+
return _broadcast_shape_impl(array_shapes)
95+
96+
9297
def permute_dims(X, axes):
9398
"""
9499
permute_dims(X: usm_ndarray, axes: tuple or list) -> usm_ndarray

0 commit comments

Comments
 (0)