@@ -48,18 +48,14 @@ def _broadcast_strides(X_shape, X_strides, res_ndim):
48
48
return tuple (out_strides )
49
49
50
50
51
- def _broadcast_shapes (* args ):
52
- """
53
- Broadcast the input shapes into a single shape;
54
- returns tuple broadcasted shape.
55
- """
56
- shapes = [array .shape for array in args ]
51
+ def _broadcast_shape_impl (shapes ):
57
52
if len (set (shapes )) == 1 :
58
53
return shapes [0 ]
59
54
mutable_shapes = False
60
55
nds = [len (s ) for s in shapes ]
61
56
biggest = max (nds )
62
- for i in range (len (args )):
57
+ sh_len = len (shapes )
58
+ for i in range (sh_len ):
63
59
diff = biggest - nds [i ]
64
60
if diff > 0 :
65
61
ty = type (shapes [i ])
@@ -77,7 +73,7 @@ def _broadcast_shapes(*args):
77
73
unique .remove (1 )
78
74
new_length = unique .pop ()
79
75
common_shape .append (new_length )
80
- for i in range (len ( args ) ):
76
+ for i in range (sh_len ):
81
77
if shapes [i ][axis ] == 1 :
82
78
if not mutable_shapes :
83
79
shapes = [list (s ) for s in shapes ]
@@ -89,6 +85,15 @@ def _broadcast_shapes(*args):
89
85
return tuple (common_shape )
90
86
91
87
88
+ def _broadcast_shapes (* args ):
89
+ """
90
+ Broadcast the input shapes into a single shape;
91
+ returns tuple broadcasted shape.
92
+ """
93
+ array_shapes = [array .shape for array in args ]
94
+ return _broadcast_shape_impl (array_shapes )
95
+
96
+
92
97
def permute_dims (X , axes ):
93
98
"""
94
99
permute_dims(X: usm_ndarray, axes: tuple or list) -> usm_ndarray
0 commit comments