Skip to content

Commit 420dc9d

Browse files
committed
allow using scalar when possible
1 parent bcd9e21 commit 420dc9d

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def _append_to_diff_array(a, axis, combined, values):
145145
146146
"""
147147

148-
dpnp.check_supported_arrays_type(values, scalar_type=True)
148+
dpnp.check_supported_arrays_type(values, scalar_type=True, all_scalars=True)
149149
if dpnp.isscalar(values):
150150
values = dpnp.asarray(
151151
values, sycl_queue=a.sycl_queue, usm_type=a.usm_type

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -705,7 +705,9 @@ def matrix_rank(A, tol=None, hermitian=False):
705705

706706
dpnp.check_supported_arrays_type(A)
707707
if tol is not None:
708-
dpnp.check_supported_arrays_type(tol, scalar_type=True)
708+
dpnp.check_supported_arrays_type(
709+
tol, scalar_type=True, all_scalars=True
710+
)
709711

710712
return dpnp_matrix_rank(A, tol=tol, hermitian=hermitian)
711713

@@ -819,7 +821,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
819821
"""
820822

821823
dpnp.check_supported_arrays_type(a)
822-
dpnp.check_supported_arrays_type(rcond, scalar_type=True)
824+
dpnp.check_supported_arrays_type(rcond, scalar_type=True, all_scalars=True)
823825
check_stacked_2d(a)
824826

825827
return dpnp_pinv(a, rcond=rcond, hermitian=hermitian)

tests/test_mathematical.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2554,14 +2554,15 @@ def setup_method(self):
25542554
def test_matmul(self, order_pair, shape_pair):
25552555
order1, order2 = order_pair
25562556
shape1, shape2 = shape_pair
2557-
a1 = numpy.arange(numpy.prod(shape1)).reshape(shape1)
2558-
a2 = numpy.arange(numpy.prod(shape2)).reshape(shape2)
2557+
dtype = dpnp.default_float_type()
2558+
a1 = numpy.arange(numpy.prod(shape1), dtype=dtype).reshape(shape1)
2559+
a2 = numpy.arange(numpy.prod(shape2), dtype=dtype).reshape(shape2)
25592560
a1 = numpy.array(a1, order=order1)
25602561
a2 = numpy.array(a2, order=order2)
25612562

25622563
b1 = dpnp.asarray(a1)
25632564
b2 = dpnp.asarray(a2)
2564-
2565+
print(a1.dtype)
25652566
result = dpnp.matmul(b1, b2)
25662567
expected = numpy.matmul(a1, a2)
25672568
assert_dtype_allclose(result, expected)

0 commit comments

Comments
 (0)