Skip to content

Commit 48e47e7

Browse files
committed
Add dpnp.broadcast_shapes implementation
1 parent 687b8ea commit 48e47e7

File tree

3 files changed

+65
-2
lines changed

3 files changed

+65
-2
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
"atleast_2d",
6464
"atleast_3d",
6565
"broadcast_arrays",
66+
"broadcast_shapes",
6667
"broadcast_to",
6768
"can_cast",
6869
"column_stack",
@@ -967,6 +968,44 @@ def broadcast_arrays(*args, subok=False):
967968
return [dpnp_array._create_from_usm_ndarray(a) for a in usm_arrays]
968969

969970

971+
def broadcast_shapes(*args):
972+
"""
973+
Broadcast the input shapes into a single shape.
974+
975+
For full documentation refer to :obj:`numpy.broadcast_shapes`.
976+
977+
Parameters
978+
----------
979+
*args : tuples of ints, or ints
980+
The shapes to be broadcast against each other.
981+
982+
Returns
983+
-------
984+
tuple
985+
Broadcasted shape.
986+
987+
See Also
988+
--------
989+
:obj:`dpnp.broadcast_arrays` : Broadcast any number of arrays against
990+
each other.
991+
:obj:`dpnp.broadcast_to` : Broadcast an array to a new shape.
992+
993+
Examples
994+
--------
995+
>>> import dpnp as np
996+
>>> np.broadcast_shapes((1, 2), (3, 1), (3, 2))
997+
(3, 2)
998+
>>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
999+
(5, 6, 7)
1000+
1001+
"""
1002+
1003+
if hasattr(numpy, "broadcast_shapes"):
1004+
return numpy.broadcast_shapes(*args)
1005+
1006+
return dpt._broadcast_shapes(*args)
1007+
1008+
9701009
# pylint: disable=redefined-outer-name
9711010
def broadcast_to(array, /, shape, subok=False):
9721011
"""

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -994,8 +994,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
994994
a_shape = a.shape
995995
b_shape = b.shape
996996

997-
# TODO: replace with dpnp.broadcast_shapes once implemented
998-
res_shape = numpy.broadcast_shapes(a_shape[:-1], b_shape[:-1])
997+
res_shape = dpnp.broadcast_shapes(a_shape[:-1], b_shape[:-1])
999998
if a_shape[:-1] != res_shape:
1000999
a = dpnp.broadcast_to(a, res_shape + (a_shape[-1],))
10011000
a_shape = a.shape

tests/test_manipulation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,31 @@ def test_no_copy(self):
332332
assert_array_equal(b, a)
333333

334334

335+
class TestBroadcast:
336+
@pytest.mark.parametrize(
337+
"shapes",
338+
[
339+
[[(1,), (3,)]],
340+
[[(1, 3), (3, 3)]],
341+
[[(3, 1), (3, 3)]],
342+
[[(1, 3), (3, 1)]],
343+
[[(1, 1), (3, 3)]],
344+
[[(1, 1), (1, 3)]],
345+
[[(1, 1), (3, 1)]],
346+
[[(1, 0), (0, 0)]],
347+
[[(0, 1), (0, 0)]],
348+
[[(1, 0), (0, 1)]],
349+
[[(1, 1), (0, 0)]],
350+
[[(1, 1), (1, 0)]],
351+
[[(1, 1), (0, 1)]],
352+
],
353+
)
354+
def test_broadcast_shapes(self, shapes):
355+
expected = numpy.broadcast_shapes(*shapes)
356+
result = dpnp.broadcast_shapes(*shapes)
357+
assert_equal(result, expected)
358+
359+
335360
class TestDelete:
336361
@pytest.mark.parametrize(
337362
"obj", [slice(0, 4, 2), 3, [2, 3]], ids=["slice", "int", "list"]

0 commit comments

Comments
 (0)