Skip to content

Commit 2039c9b

Browse files
committed
address comments
1 parent f04bb51 commit 2039c9b

File tree

5 files changed

+20
-16
lines changed

doc/conf.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -55,6 +55,7 @@
5555
"sphinx.ext.autodoc",
5656
"sphinx.ext.autosummary",
5757
"sphinxcontrib.googleanalytics",
58+
"sphinx.ext.mathjax",
5859
]
5960

6061
googleanalytics_id = "G-554F8VNE28"

dpnp/dpnp_iface_linearalgebra.py

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -67,9 +67,9 @@ def dot(a, b, out=None):
6767
6868
Parameters
6969
----------
70-
a : {dpnp_array, usm_ndarray, scalar}
70+
a : {dpnp.ndarray, usm_ndarray, scalar}
7171
First input array. Both inputs `a` and `b` can not be scalars at the same time.
72-
b : {dpnp_array, usm_ndarray, scalar}
72+
b : {dpnp.ndarray, usm_ndarray, scalar}
7373
Second input array. Both inputs `a` and `b` can not be scalars at the same time.
7474
out : {dpnp.ndarray, usm_ndarray}, optional
7575
Alternative output array in which to place the result. It must have
@@ -413,9 +413,9 @@ def tensordot(a, b, axes=2):
413413
414414
Parameters
415415
----------
416-
a : {dpnp_array, usm_ndarray, scalar}
416+
a : {dpnp.ndarray, usm_ndarray, scalar}
417417
First input array. Both inputs `a` and `b` can not be scalars at the same time.
418-
b : {dpnp_array, usm_ndarray, scalar}
418+
b : {dpnp.ndarray, usm_ndarray, scalar}
419419
Second input array. Both inputs `a` and `b` can not be scalars at the same time.
420420
axes : int or (2,) array_like
421421
* integer_like
@@ -502,7 +502,7 @@ def tensordot(a, b, axes=2):
502502
iter(axes)
503503
except Exception:
504504
if not isinstance(axes, int):
505-
raise ValueError("Axes must be an integer.")
505+
raise TypeError("Axes must be an integer.")
506506
axes_a = tuple(range(-axes, 0))
507507
axes_b = tuple(range(0, axes))
508508
else:
@@ -561,11 +561,11 @@ def vdot(a, b):
561561
562562
Parameters
563563
----------
564-
a : {dpnp_array, usm_ndarray, scalar}
564+
a : {dpnp.ndarray, usm_ndarray, scalar}
565565
First input array. Both inputs `a` and `b` can not be
566566
scalars at the same time. If `a` is complex, the complex
567567
conjugate is taken before the calculation of the dot product.
568-
b : {dpnp_array, usm_ndarray, scalar}
568+
b : {dpnp.ndarray, usm_ndarray, scalar}
569569
Second input array. Both inputs `a` and `b` can not be
570570
scalars at the same time.
571571

tests/helper.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,11 @@
88

99

1010
def assert_dtype_allclose(
11-
dpnp_arr, numpy_arr, check_type=True, check_only_type_kind=False
11+
dpnp_arr,
12+
numpy_arr,
13+
check_type=True,
14+
check_only_type_kind=False,
15+
factor=8,
1216
):
1317
"""
1418
Assert DPNP and NumPy array based on maximum dtype resolution of input arrays
@@ -28,6 +32,7 @@ def assert_dtype_allclose(
2832
The 'check_only_type_kind' parameter (False by default) asserts only equal type kinds
2933
for all data types supported by DPNP when set to True.
3034
It is effective only when 'check_type' is also set to True.
35+
The parameter `factor` scales the resolution used for comparing the arrays.
3136
3237
"""
3338

@@ -44,7 +49,7 @@ def assert_dtype_allclose(
4449
if is_inexact(numpy_arr)
4550
else -dpnp.inf
4651
)
47-
tol = 8 * max(tol_dpnp, tol_numpy)
52+
tol = factor * max(tol_dpnp, tol_numpy)
4853
assert_allclose(dpnp_arr.asnumpy(), numpy_arr, atol=tol, rtol=tol)
4954
if check_type:
5055
numpy_arr_dtype = numpy_arr.dtype

tests/test_dot.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -383,7 +383,7 @@ def test_tensordot(self, dtype, axes):
383383

384384
result = dpnp.tensordot(ia, ib, axes=axes)
385385
expected = numpy.tensordot(a, b, axes=axes)
386-
assert_dtype_allclose(result, expected)
386+
assert_dtype_allclose(result, expected, factor=16)
387387

388388
@pytest.mark.parametrize("dtype", get_complex_dtypes())
389389
@pytest.mark.parametrize("axes", [-3, -2, -1, 0, 1, 2])
@@ -399,7 +399,7 @@ def test_tensordot_complex(self, dtype, axes):
399399

400400
result = dpnp.tensordot(ia, ib, axes=axes)
401401
expected = numpy.tensordot(a, b, axes=axes)
402-
assert_dtype_allclose(result, expected)
402+
assert_dtype_allclose(result, expected, factor=16)
403403

404404
@pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True))
405405
@pytest.mark.parametrize(
@@ -424,7 +424,7 @@ def test_tensordot_axes(self, dtype, axes):
424424

425425
result = dpnp.tensordot(ia, ib, axes=axes)
426426
expected = numpy.tensordot(a, b, axes=axes)
427-
assert_dtype_allclose(result, expected)
427+
assert_dtype_allclose(result, expected, factor=16)
428428

429429
@pytest.mark.parametrize("dtype1", get_all_dtypes())
430430
@pytest.mark.parametrize("dtype2", get_all_dtypes())
@@ -440,7 +440,7 @@ def test_tensordot_input_dtype_matrix(self, dtype1, dtype2):
440440

441441
result = dpnp.tensordot(ia, ib)
442442
expected = numpy.tensordot(a, b)
443-
assert_dtype_allclose(result, expected)
443+
assert_dtype_allclose(result, expected, factor=16)
444444

445445
def test_tensordot_strided(self):
446446
for dim in [1, 2, 3, 4]:
@@ -475,7 +475,7 @@ def test_tensordot_error(self):
475475
a = dpnp.arange(24).reshape(2, 3, 4)
476476
b = dpnp.arange(24).reshape(3, 4, 2)
477477
# axes should be an integer
478-
with pytest.raises(ValueError):
478+
with pytest.raises(TypeError):
479479
dpnp.tensordot(a, b, axes=2.0)
480480

481481
# Axes must consist of two sequences

tests/third_party/cupy/linalg_tests/test_product.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -415,8 +415,6 @@ def test_zerodim_kron(self, xp, dtype):
415415
}
416416
)
417417
)
418-
@pytest.mark.usefixtures("allow_fall_back_on_numpy")
419-
@testing.gpu
420418
class TestProductZeroLength(unittest.TestCase):
421419
@testing.for_all_dtypes()
422420
@testing.numpy_cupy_allclose()

0 commit comments

Comments (0)