Skip to content

Commit 2ff2e00

Browse files
committed
TST: inner, vdot
1 parent 54e09b8 commit 2ff2e00

File tree

2 files changed

+43
-7
lines changed

2 files changed

+43
-7
lines changed

torch_np/_detail/implementations.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -692,11 +692,12 @@ def where(condition, x, y):
692692

693693
# ### dot and other linalg ###
694694

695+
695696
def inner(t_a, t_b):
696-
is_half = t_a.dtype == torch.float16 or t_b.dtype == torch.float16
697-
is_bool = t_a.dtype == torch.bool or t_b.dtype == torch.bool
697+
dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
698+
is_half = dtype == torch.float16
699+
is_bool = dtype == torch.bool
698700

699-
dtype = None
700701
if is_half:
701702
# work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
702703
dtype = torch.float32
@@ -718,11 +719,32 @@ def inner(t_a, t_b):
718719

719720

720721
def vdot(t_a, t_b, /):
721-
# torch only accepts 1D arrays, numpy ravels
722+
# 1. torch only accepts 1D arrays, numpy ravels
723+
# 2. torch requires matching dtype, while numpy casts (?)
722724
t_a, t_b = torch.atleast_1d(t_a, t_b)
723725
if t_a.ndim > 1:
724726
t_a = t_a.ravel()
725727
if t_b.ndim > 1:
726728
t_b = t_b.ravel()
729+
730+
dtype = _dtypes_impl.result_type_impl((t_a.dtype, t_b.dtype))
731+
is_half = dtype == torch.float16
732+
is_bool = dtype == torch.bool
733+
734+
# work around torch's "dot" not implemented for 'Half', 'Bool'
735+
if is_half:
736+
dtype = torch.float32
737+
if is_bool:
738+
dtype = torch.uint8
739+
740+
t_a = _util.cast_if_needed(t_a, dtype)
741+
t_b = _util.cast_if_needed(t_b, dtype)
742+
727743
result = torch.vdot(t_a, t_b)
744+
745+
if is_half:
746+
result = result.to(torch.float16)
747+
if is_bool:
748+
result = result.to(torch.bool)
749+
728750
return result

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5421,23 +5421,22 @@ def __array_finalize__(self, obj):
54215421
assert_(res.info == dat.info)
54225422

54235423

5424-
@pytest.mark.xfail(reason='TODO')
54255424
class TestVdot:
54265425
def test_basic(self):
54275426
dt_numeric = np.typecodes['AllFloat'] + np.typecodes['AllInteger']
54285427
dt_complex = np.typecodes['Complex']
54295428

54305429
# test real
54315430
a = np.eye(3)
5432-
for dt in dt_numeric + 'O':
5431+
for dt in dt_numeric:
54335432
b = a.astype(dt)
54345433
res = np.vdot(b, b)
54355434
assert_(np.isscalar(res))
54365435
assert_equal(np.vdot(b, b), 3)
54375436

54385437
# test complex
54395438
a = np.eye(3) * 1j
5440-
for dt in dt_complex + 'O':
5439+
for dt in dt_complex:
54415440
b = a.astype(dt)
54425441
res = np.vdot(b, b)
54435442
assert_(np.isscalar(res))
@@ -5449,6 +5448,7 @@ def test_basic(self):
54495448
assert_(np.isscalar(res))
54505449
assert_equal(np.vdot(b, b), True)
54515450

5451+
@pytest.mark.xfail(reason="implement order='F'")
54525452
def test_vdot_array_order(self):
54535453
a = np.array([[1, 2], [3, 4]], order='C')
54545454
b = np.array([[1, 2], [3, 4]], order='F')
@@ -5476,6 +5476,20 @@ def test_vdot_uncontiguous(self):
54765476
np.vdot(a.flatten(), b.flatten()))
54775477
assert_equal(np.vdot(a.copy(), b),
54785478
np.vdot(a.flatten(), b.flatten()))
5479+
5480+
@pytest.mark.xfail(reason="implement order='F'")
5481+
def test_vdot_uncontiguous_2(self):
5482+
# test order='F' separately
5483+
for size in [2, 1000]:
5484+
# Different sizes match different branches in vdot.
5485+
a = np.zeros((size, 2, 2))
5486+
b = np.zeros((size, 2, 2))
5487+
a[:, 0, 0] = np.arange(size)
5488+
b[:, 0, 0] = np.arange(size) + 1
5489+
# Make a and b uncontiguous:
5490+
a = a[..., 0]
5491+
b = b[..., 0]
5492+
54795493
assert_equal(np.vdot(a.copy('F'), b),
54805494
np.vdot(a.flatten(), b.flatten()))
54815495
assert_equal(np.vdot(a, b.copy('F')),

0 commit comments

Comments
 (0)