Skip to content

Commit 90677e8

Browse files
authored
Merge branch 'master' into impl_insert
2 parents 23bf028 + 9086f45 commit 90677e8

File tree

3 files changed

+62
-2
lines changed

3 files changed

+62
-2
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ class InsertDeleteParams(NamedTuple):
8080
"atleast_2d",
8181
"atleast_3d",
8282
"broadcast_arrays",
83+
"broadcast_shapes",
8384
"broadcast_to",
8485
"can_cast",
8586
"column_stack",
@@ -1098,6 +1099,41 @@ def broadcast_arrays(*args, subok=False):
10981099
return [dpnp_array._create_from_usm_ndarray(a) for a in usm_arrays]
10991100

11001101

1102+
def broadcast_shapes(*args):
1103+
"""
1104+
Broadcast the input shapes into a single shape.
1105+
1106+
For full documentation refer to :obj:`numpy.broadcast_shapes`.
1107+
1108+
Parameters
1109+
----------
1110+
*args : tuples of ints, or ints
1111+
The shapes to be broadcast against each other.
1112+
1113+
Returns
1114+
-------
1115+
tuple
1116+
Broadcasted shape.
1117+
1118+
See Also
1119+
--------
1120+
:obj:`dpnp.broadcast_arrays` : Broadcast any number of arrays against
1121+
each other.
1122+
:obj:`dpnp.broadcast_to` : Broadcast an array to a new shape.
1123+
1124+
Examples
1125+
--------
1126+
>>> import dpnp as np
1127+
>>> np.broadcast_shapes((1, 2), (3, 1), (3, 2))
1128+
(3, 2)
1129+
>>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
1130+
(5, 6, 7)
1131+
1132+
"""
1133+
1134+
return numpy.broadcast_shapes(*args)
1135+
1136+
11011137
# pylint: disable=redefined-outer-name
11021138
def broadcast_to(array, /, shape, subok=False):
11031139
"""

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -994,8 +994,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
994994
a_shape = a.shape
995995
b_shape = b.shape
996996

997-
# TODO: replace with dpnp.broadcast_shapes once implemented
998-
res_shape = numpy.broadcast_shapes(a_shape[:-1], b_shape[:-1])
997+
res_shape = dpnp.broadcast_shapes(a_shape[:-1], b_shape[:-1])
999998
if a_shape[:-1] != res_shape:
1000999
a = dpnp.broadcast_to(a, res_shape + (a_shape[-1],))
10011000
a_shape = a.shape

tests/test_manipulation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,31 @@ def test_no_copy(self):
332332
assert_array_equal(b, a)
333333

334334

335+
class TestBroadcast:
336+
@pytest.mark.parametrize(
337+
"shape",
338+
[
339+
[(1,), (3,)],
340+
[(1, 3), (3, 3)],
341+
[(3, 1), (3, 3)],
342+
[(1, 3), (3, 1)],
343+
[(1, 1), (3, 3)],
344+
[(1, 1), (1, 3)],
345+
[(1, 1), (3, 1)],
346+
[(1, 0), (0, 0)],
347+
[(0, 1), (0, 0)],
348+
[(1, 0), (0, 1)],
349+
[(1, 1), (0, 0)],
350+
[(1, 1), (1, 0)],
351+
[(1, 1), (0, 1)],
352+
],
353+
)
354+
def test_broadcast_shapes(self, shape):
355+
expected = numpy.broadcast_shapes(*shape)
356+
result = dpnp.broadcast_shapes(*shape)
357+
assert_equal(result, expected)
358+
359+
335360
class TestDelete:
336361
@pytest.mark.parametrize(
337362
"obj", [slice(0, 4, 2), 3, [2, 3]], ids=["slice", "int", "list"]

0 commit comments

Comments
 (0)