Skip to content

Commit 9471d79

Browse files
Add tests for tensorsolve
1 parent a009ebd commit 9471d79

File tree

4 files changed

+76
-0
lines changed

4 files changed

+76
-0
lines changed

tests/test_linalg.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,3 +1409,24 @@ def test_pinv_errors(self):
14091409
a_dp_q = inp.array(a_dp, sycl_queue=a_queue)
14101410
rcond_dp_q = inp.array([0.5], dtype="float32", sycl_queue=rcond_queue)
14111411
assert_raises(ValueError, inp.linalg.pinv, a_dp_q, rcond_dp_q)
1412+
1413+
1414+
class TestTensorsolve:
1415+
def test_tensorsolve_errors(self):
1416+
a_dp = inp.eye(24, dtype="float32").reshape(4, 6, 8, 3)
1417+
b_dp = inp.ones(a_dp.shape[:2], dtype="float32")
1418+
1419+
# unsupported type `a` and `b`
1420+
a_np = inp.asnumpy(a_dp)
1421+
b_np = inp.asnumpy(b_dp)
1422+
assert_raises(TypeError, inp.linalg.tensorsolve, a_np, b_dp)
1423+
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, b_np)
1424+
1425+
# unsupported type `axes`
1426+
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, 2.0)
1427+
assert_raises(TypeError, inp.linalg.tensorsolve, a_dp, -2)
1428+
1429+
# incorrect axes
1430+
assert_raises(
1431+
inp.linalg.LinAlgError, inp.linalg.tensorsolve, a_dp, b_dp, (1,)
1432+
)

tests/test_sycl_queue.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,3 +1849,24 @@ def test_pinv(shape, hermitian, rcond_as_array, device):
18491849
B_queue = B_result.sycl_queue
18501850

18511851
assert_sycl_queue_equal(B_queue, a_dp.sycl_queue)
1852+
1853+
1854+
@pytest.mark.parametrize(
1855+
"device",
1856+
valid_devices,
1857+
ids=[device.filter_string for device in valid_devices],
1858+
)
1859+
def test_tensorsolve(device):
1860+
a_np = numpy.random.randn(3, 2, 6).astype(dpnp.default_float_type())
1861+
b_np = numpy.ones(a_np.shape[:2], dtype=a_np.dtype)
1862+
1863+
a_dp = dpnp.array(a_np, device=device)
1864+
b_dp = dpnp.array(b_np, device=device)
1865+
1866+
result = dpnp.linalg.tensorsolve(a_dp, b_dp)
1867+
expected = numpy.linalg.tensorsolve(a_np, b_np)
1868+
assert_dtype_allclose(result, expected)
1869+
1870+
result_queue = result.sycl_queue
1871+
1872+
assert_sycl_queue_equal(result_queue, a_dp.sycl_queue)

tests/test_usm_type.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,3 +1014,17 @@ def test_qr(shape, mode, usm_type):
10141014

10151015
assert a.usm_type == dp_q.usm_type
10161016
assert a.usm_type == dp_r.usm_type
1017+
1018+
1019+
@pytest.mark.parametrize("usm_type_a", list_of_usm_types, ids=list_of_usm_types)
1020+
@pytest.mark.parametrize("usm_type_b", list_of_usm_types, ids=list_of_usm_types)
1021+
def test_tensorsolve(usm_type_a, usm_type_b):
1022+
data = numpy.random.randn(3, 2, 6)
1023+
a = dp.array(data, usm_type=usm_type_a)
1024+
b = dp.ones(a.shape[:2], dtype=a.dtype, usm_type=usm_type_b)
1025+
1026+
result = dp.linalg.tensorsolve(a, b)
1027+
1028+
assert a.usm_type == usm_type_a
1029+
assert b.usm_type == usm_type_b
1030+
assert result.usm_type == du.get_coerced_usm_type([usm_type_a, usm_type_b])

tests/third_party/cupy/linalg_tests/test_solve.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,26 @@ def test_invalid_shape(self):
9595
self.check_shape((0, 3, 4), (3,), linalg_errors)
9696

9797

98+
@testing.parameterize(
99+
*testing.product(
100+
{
101+
"a_shape": [(2, 3, 6), (3, 4, 4, 3)],
102+
"axes": [None, (0, 2)],
103+
}
104+
)
105+
)
106+
@testing.fix_random()
107+
class TestTensorSolve(unittest.TestCase):
108+
@testing.for_dtypes("ifdFD")
109+
@testing.numpy_cupy_allclose(atol=0.02, type_check=has_support_aspect64())
110+
def test_tensorsolve(self, xp, dtype):
111+
a_shape = self.a_shape
112+
b_shape = self.a_shape[:2]
113+
a = testing.shaped_random(a_shape, xp, dtype=dtype, seed=0)
114+
b = testing.shaped_random(b_shape, xp, dtype=dtype, seed=1)
115+
return xp.linalg.tensorsolve(a, b, axes=self.axes)
116+
117+
98118
@testing.parameterize(
99119
*testing.product(
100120
{

0 commit comments

Comments
 (0)