Skip to content

Commit edc498d

Browse files
committed
allow using scalar when possible
1 parent 6f2364e commit edc498d

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _append_to_diff_array(a, axis, combined, values):
143143
144144
"""
145145

146-
dpnp.check_supported_arrays_type(values, scalar_type=True)
146+
dpnp.check_supported_arrays_type(values, scalar_type=True, all_scalars=True)
147147
if dpnp.isscalar(values):
148148
values = dpnp.asarray(
149149
values, sycl_queue=a.sycl_queue, usm_type=a.usm_type

dpnp/linalg/dpnp_iface_linalg.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,9 @@ def matrix_rank(A, tol=None, hermitian=False):
581581

582582
dpnp.check_supported_arrays_type(A)
583583
if tol is not None:
584-
dpnp.check_supported_arrays_type(tol, scalar_type=True)
584+
dpnp.check_supported_arrays_type(
585+
tol, scalar_type=True, all_scalars=True
586+
)
585587

586588
return dpnp_matrix_rank(A, tol=tol, hermitian=hermitian)
587589

@@ -695,7 +697,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
695697
"""
696698

697699
dpnp.check_supported_arrays_type(a)
698-
dpnp.check_supported_arrays_type(rcond, scalar_type=True)
700+
dpnp.check_supported_arrays_type(rcond, scalar_type=True, all_scalars=True)
699701
check_stacked_2d(a)
700702

701703
return dpnp_pinv(a, rcond=rcond, hermitian=hermitian)

tests/test_mathematical.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2585,14 +2585,15 @@ def setup_method(self):
25852585
def test_matmul(self, order_pair, shape_pair):
25862586
order1, order2 = order_pair
25872587
shape1, shape2 = shape_pair
2588-
a1 = numpy.arange(numpy.prod(shape1)).reshape(shape1)
2589-
a2 = numpy.arange(numpy.prod(shape2)).reshape(shape2)
2588+
dtype = dpnp.default_float_type()
2589+
a1 = numpy.arange(numpy.prod(shape1), dtype=dtype).reshape(shape1)
2590+
a2 = numpy.arange(numpy.prod(shape2), dtype=dtype).reshape(shape2)
25902591
a1 = numpy.array(a1, order=order1)
25912592
a2 = numpy.array(a2, order=order2)
25922593

25932594
b1 = dpnp.asarray(a1)
25942595
b2 = dpnp.asarray(a2)
2595-
2596+
print(a1.dtype)
25962597
result = dpnp.matmul(b1, b2)
25972598
expected = numpy.matmul(a1, a2)
25982599
assert_dtype_allclose(result, expected)

0 commit comments

Comments
 (0)