Skip to content

Commit c41ef8d

Browse files
authored
Merge f1bf82b into 687b8ea
2 parents 687b8ea + f1bf82b commit c41ef8d

File tree

3 files changed

+49
-2
lines changed

3 files changed

+49
-2
lines changed

dpnp/dpnp_iface_manipulation.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
import dpctl.tensor as dpt
4646
import numpy
47+
from dpctl.tensor._manipulation_functions import _broadcast_shapes
4748
from dpctl.tensor._numpy_helper import AxisError, normalize_axis_index
4849

4950
import dpnp
@@ -63,6 +64,7 @@
6364
"atleast_2d",
6465
"atleast_3d",
6566
"broadcast_arrays",
67+
"broadcast_shapes",
6668
"broadcast_to",
6769
"can_cast",
6870
"column_stack",
@@ -967,6 +969,44 @@ def broadcast_arrays(*args, subok=False):
967969
return [dpnp_array._create_from_usm_ndarray(a) for a in usm_arrays]
968970

969971

972+
def broadcast_shapes(*args):
973+
"""
974+
Broadcast the input shapes into a single shape.
975+
976+
For full documentation refer to :obj:`numpy.broadcast_shapes`.
977+
978+
Parameters
979+
----------
980+
*args : tuples of ints, or ints
981+
The shapes to be broadcast against each other.
982+
983+
Returns
984+
-------
985+
tuple
986+
Broadcasted shape.
987+
988+
See Also
989+
--------
990+
:obj:`dpnp.broadcast_arrays` : Broadcast any number of arrays against
991+
each other.
992+
:obj:`dpnp.broadcast_to` : Broadcast an array to a new shape.
993+
994+
Examples
995+
--------
996+
>>> import dpnp as np
997+
>>> np.broadcast_shapes((1, 2), (3, 1), (3, 2))
998+
(3, 2)
999+
>>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
1000+
(5, 6, 7)
1001+
1002+
"""
1003+
1004+
if hasattr(numpy, "broadcast_shapes"):
1005+
return numpy.broadcast_shapes(*args)
1006+
1007+
return _broadcast_shapes(*args)
1008+
1009+
9701010
# pylint: disable=redefined-outer-name
9711011
def broadcast_to(array, /, shape, subok=False):
9721012
"""

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -994,8 +994,7 @@ def cross(a, b, axisa=-1, axisb=-1, axisc=-1, axis=None):
994994
a_shape = a.shape
995995
b_shape = b.shape
996996

997-
# TODO: replace with dpnp.broadcast_shapes once implemented
998-
res_shape = numpy.broadcast_shapes(a_shape[:-1], b_shape[:-1])
997+
res_shape = dpnp.broadcast_shapes(a_shape[:-1], b_shape[:-1])
999998
if a_shape[:-1] != res_shape:
1000999
a = dpnp.broadcast_to(a, res_shape + (a_shape[-1],))
10011000
a_shape = a.shape

tests/test_manipulation.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,14 @@ def test_no_copy(self):
332332
assert_array_equal(b, a)
333333

334334

335+
class TestBroadcast:
336+
@pytest.mark.parametrize("shape", [(1, 1), (0, 1)])
337+
def test_broadcast_shapes(self, shape):
338+
expected = numpy.broadcast_shapes(*shape)
339+
result = dpnp.broadcast_shapes(*shape)
340+
assert_equal(result, expected)
341+
342+
335343
class TestDelete:
336344
@pytest.mark.parametrize(
337345
"obj", [slice(0, 4, 2), 3, [2, 3]], ids=["slice", "int", "list"]

0 commit comments

Comments
 (0)