Skip to content

Commit e3afcfc

Browse files
committed
address comments
1 parent 4ee2ede commit e3afcfc

File tree

8 files changed

+19
-58
lines changed

8 files changed

+19
-58
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ This release achieves 100% compliance with Python Array API specification (revis
2727
* Updated `dpnp.einsum` to add support for `order=None` [#2411](https://github.com/IntelPython/dpnp/pull/2411)
2828
* Updated Python Array API specification version supported to `2024.12` [#2416](https://github.com/IntelPython/dpnp/pull/2416)
2929
* Removed `einsum_call` keyword from `dpnp.einsum_path` signature [#2421](https://github.com/IntelPython/dpnp/pull/2421)
30+
* Updated `dpnp.vdot` to return a 0-D array when one of the inputs is a scalar [#2295](https://github.com/IntelPython/dpnp/pull/2295)
31+
* Updated `dpnp.outer` to return the same dtype as NumPy when multiplying an array with a scalar [#2295](https://github.com/IntelPython/dpnp/pull/2295)
3032

3133
### Fixed
3234

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,17 @@
7171
# TODO: implement a specific scalar-array kernel
7272
def _call_multiply(a, b, out=None, outer_calc=False):
7373
"""
74-
Call multiply function for special cases of scalar-array dots.
75-
76-
if `sc` is an scalar and `a` is an array of type float32, we have
77-
dpnp.multiply(a, sc).dtype == dpnp.float32 and
78-
numpy.multiply(a, sc).dtype == dpnp.float32.
79-
80-
However, for scalar-array dots such as dot function we have
81-
dpnp.dot(a, sc).dtype == dpnp.float32 while
82-
numpy.dot(a, sc).dtype == dpnp.float64.
83-
84-
We need to adjust the behavior of the multiply function when it is
85-
used for special cases of scalar-array dots.
74+
Adjusted multiply function for handling special cases of scalar-array dot
75+
products in linear algebra.
76+
77+
`dpnp.multiply` cannot directly be used for calculating scalar-array dots,
78+
because the output dtype of multiply is not the same as the expected dtype
79+
for scalar-array dots. For example, if `sc` is an scalar and `a` is an
80+
array of type `float32`, then `dpnp.multiply(a, sc).dtype == dpnp.float32`
81+
(similar to NumPy). However, for scalar-array dots, such as the dot
82+
function, we need `dpnp.dot(a, sc).dtype == dpnp.float64` to align with
83+
NumPy. This functions adjusts the behavior of `dpnp.multiply` function to
84+
meet this requirement.
8685
8786
"""
8887

dpnp/tests/helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def _assert_dtype(a_dt, b_dt, check_only_type_kind=False):
1818

1919

2020
def _assert_shape(a, b):
21+
# it is assumed `a` is a `dpnp.ndarray` and so it has shape attribute
2122
if hasattr(b, "shape"):
2223
assert a.shape == b.shape, f"{a.shape} != {b.shape}"
2324
else:

dpnp/tests/test_dlpack.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ def test_dtype_passthrough(self, xp, dt):
4040
x = xp.arange(5).astype(dt)
4141
y = xp.from_dlpack(x)
4242

43-
assert y.dtype == x.dtype
4443
assert_array_equal(x, y)
4544

4645
@pytest.mark.parametrize("xp", [dpnp, numpy])

dpnp/tests/test_linalg.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -422,13 +422,9 @@ def test_det_empty(self):
422422
a = numpy.empty((0, 0, 2, 2), dtype=numpy.float32)
423423
ia = dpnp.array(a)
424424

425-
np_det = numpy.linalg.det(a)
426-
dpnp_det = dpnp.linalg.det(ia)
427-
428-
assert dpnp_det.dtype == np_det.dtype
429-
assert dpnp_det.shape == np_det.shape
430-
431-
assert_allclose(dpnp_det, np_det)
425+
expected = numpy.linalg.det(a)
426+
result = dpnp.linalg.det(ia)
427+
assert_allclose(result, expected)
432428

433429
@pytest.mark.parametrize(
434430
"matrix",
@@ -2851,25 +2847,6 @@ def get_tol(self, dtype):
28512847
tol = 1e-03
28522848
self._tol = tol
28532849

2854-
def check_types_shapes(
2855-
self, dp_u, dp_s, dp_vt, np_u, np_s, np_vt, compute_vt=True
2856-
):
2857-
if has_support_aspect64():
2858-
if compute_vt:
2859-
assert dp_u.dtype == np_u.dtype
2860-
assert dp_vt.dtype == np_vt.dtype
2861-
assert dp_s.dtype == np_s.dtype
2862-
else:
2863-
if compute_vt:
2864-
assert dp_u.dtype.kind == np_u.dtype.kind
2865-
assert dp_vt.dtype.kind == np_vt.dtype.kind
2866-
assert dp_s.dtype.kind == np_s.dtype.kind
2867-
2868-
if compute_vt:
2869-
assert dp_u.shape == np_u.shape
2870-
assert dp_vt.shape == np_vt.shape
2871-
assert dp_s.shape == np_s.shape
2872-
28732850
# Checks the accuracy of singular value decomposition (SVD).
28742851
# Compares the reconstructed matrix from the decomposed components
28752852
# with the original matrix.
@@ -2922,7 +2899,6 @@ def test_svd(self, dtype, shape):
29222899
result = dpnp.linalg.svd(dp_a)
29232900
dp_u, dp_s, dp_vh = result.U, result.S, result.Vh
29242901

2925-
self.check_types_shapes(dp_u, dp_s, dp_vh, np_u, np_s, np_vh)
29262902
self.get_tol(dtype)
29272903
self.check_decomposition(
29282904
dp_a, dp_u, dp_s, dp_vh, np_u, np_s, np_vh, True
@@ -2950,10 +2926,6 @@ def test_svd_hermitian(self, dtype, compute_vt, shape):
29502926
dp_s = dpnp.linalg.svd(dp_a, compute_uv=compute_vt, hermitian=True)
29512927
np_u = np_vh = dp_u = dp_vh = None
29522928

2953-
self.check_types_shapes(
2954-
dp_u, dp_s, dp_vh, np_u, np_s, np_vh, compute_vt
2955-
)
2956-
29572929
self.get_tol(dtype)
29582930

29592931
self.check_decomposition(
@@ -3029,14 +3001,6 @@ def get_tol(self, dtype):
30293001
tol = 1e-03
30303002
self._tol = tol
30313003

3032-
def check_types_shapes(self, dp_B, np_B):
3033-
if has_support_aspect64():
3034-
assert dp_B.dtype == np_B.dtype
3035-
else:
3036-
assert dp_B.dtype.kind == np_B.dtype.kind
3037-
3038-
assert dp_B.shape == np_B.shape
3039-
30403004
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
30413005
@pytest.mark.parametrize(
30423006
"shape",
@@ -3058,7 +3022,6 @@ def test_pinv(self, dtype, shape):
30583022
B = numpy.linalg.pinv(a)
30593023
B_dp = dpnp.linalg.pinv(a_dp)
30603024

3061-
self.check_types_shapes(B_dp, B)
30623025
self.get_tol(dtype)
30633026
tol = self._tol
30643027
assert_allclose(B_dp, B, rtol=tol, atol=tol)
@@ -3083,7 +3046,6 @@ def test_pinv_hermitian(self, dtype, shape):
30833046
B = numpy.linalg.pinv(a, hermitian=True)
30843047
B_dp = dpnp.linalg.pinv(a_dp, hermitian=True)
30853048

3086-
self.check_types_shapes(B_dp, B)
30873049
self.get_tol(dtype)
30883050
tol = self._tol
30893051

dpnp/tests/test_manipulation.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1089,7 +1089,6 @@ def test_non_array_input(self):
10891089
assert expected.flags["C"] == result.flags["C"]
10901090
assert expected.flags["F"] == result.flags["F"]
10911091
assert expected.flags["W"] == result.flags["W"]
1092-
assert expected.dtype == result.dtype
10931092
assert_array_equal(result, expected)
10941093

10951094
def test_C_and_F_simul(self):

dpnp/tests/test_sort.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,6 @@ def test_complex(self, dtype):
407407
result = dpnp.sort_complex(ia)
408408
expected = numpy.sort_complex(a)
409409
assert_equal(result, expected)
410-
assert result.dtype == expected.dtype
411410

412411

413412
@pytest.mark.parametrize("kth", [0, 1])

dpnp/tests/testing/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ def _assert(assert_func, result, expected, *args, **kwargs):
5050
# For numpy < 2.0, some tests will fail for dtype mismatch
5151
dev = dpctl.select_default_device()
5252
if numpy.__version__ >= "2.0.0" and dev.has_aspect_fp64:
53-
kwargs.setdefault("strict", True)
53+
strict = kwargs.setdefault("strict", True)
5454
if flag:
55-
if kwargs.get("strict"):
55+
if strict:
5656
if hasattr(expected, "dtype"):
5757
assert (
5858
result.dtype == expected.dtype

0 commit comments

Comments
 (0)