Skip to content

Commit e737658

Browse files
Add test_tensorsolve_axes
1 parent 9471d79 commit e737658

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

tests/test_linalg.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1412,6 +1412,29 @@ def test_pinv_errors(self):
14121412

14131413

14141414
class TestTensorsolve:
1415+
@pytest.mark.parametrize("dtype", get_all_dtypes())
1416+
@pytest.mark.parametrize(
1417+
"axes",
1418+
[None, (1,), (2,)],
1419+
ids=[
1420+
"None",
1421+
"(1,)",
1422+
"(2,)",
1423+
],
1424+
)
1425+
def test_tensorsolve_axes(self, dtype, axes):
1426+
a = numpy.eye(12).reshape(12, 3, 4).astype(dtype)
1427+
b = numpy.ones(a.shape[0], dtype=dtype)
1428+
1429+
a_dp = inp.array(a)
1430+
b_dp = inp.array(b)
1431+
1432+
res_np = numpy.linalg.tensorsolve(a, b, axes=axes)
1433+
res_dp = inp.linalg.tensorsolve(a_dp, b_dp, axes=axes)
1434+
1435+
assert res_np.shape == res_dp.shape
1436+
assert_dtype_allclose(res_dp, res_np)
1437+
14151438
def test_tensorsolve_errors(self):
14161439
a_dp = inp.eye(24, dtype="float32").reshape(4, 6, 8, 3)
14171440
b_dp = inp.ones(a_dp.shape[:2], dtype="float32")

0 commit comments

Comments
 (0)