Skip to content

Commit 12e0770

Browse files
committed
ENH: dot
1 parent f135442 commit 12e0770

File tree

5 files changed

+64
-69
lines changed

5 files changed

+64
-69
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -220,10 +220,6 @@ def disp(mesg, device=None, linefeed=True):
220220
raise NotImplementedError
221221

222222

223-
def dot(a, b, out=None):
224-
raise NotImplementedError
225-
226-
227223
def ediff1d(ary, to_end=None, to_begin=None):
228224
raise NotImplementedError
229225

torch_np/_detail/_reductions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def mean(tensor, axis=None, dtype=None, *, where=NoValue):
145145

146146
is_half = dtype == torch.float16
147147
if is_half:
148-
dtype=torch.float32
148+
dtype = torch.float32
149149

150150
if axis is None:
151151
result = tensor.mean(dtype=dtype)

torch_np/_detail/implementations.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,6 @@ def diagonal(tensor, offset=0, axis1=0, axis2=1):
185185
return result
186186

187187

188-
189188
# ### splits ###
190189

191190

@@ -721,7 +720,6 @@ def inner(t_a, t_b):
721720
result = result.to(torch.bool)
722721

723722
return result
724-
725723

726724

727725
def vdot(t_a, t_b, /):
@@ -754,3 +752,17 @@ def vdot(t_a, t_b, /):
754752
result = result.to(torch.bool)
755753

756754
return result
755+
756+
757+
def dot(t_a, t_b):
758+
if t_a.ndim == 0 or t_b.ndim == 0:
759+
result = t_a * t_b
760+
elif t_a.ndim == 1 and t_b.ndim == 1:
761+
result = torch.dot(t_a, t_b)
762+
elif t_a.ndim == 1:
763+
result = torch.mv(t_b.T, t_a).T
764+
elif t_b.ndim == 1:
765+
result = torch.mv(t_a, t_b)
766+
else:
767+
result = torch.matmul(t_a, t_b)
768+
return result

torch_np/_wrapper.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,6 @@ def fill_diagonal(a, val, wrap=False):
400400
return asarray(result)
401401

402402

403-
404403
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
405404
arr = asarray(a)
406405
return arr.trace(offset, axis1, axis2, dtype=dtype, out=out)
@@ -500,6 +499,13 @@ def vdot(a, b, /):
500499
result = _impl.vdot(t_a, t_b)
501500
return result.item()
502501

502+
503+
def dot(a, b, out=None):
504+
t_a, t_b = _helpers.to_tensors(a, b)
505+
result = _impl.dot(t_a, t_b)
506+
return _helpers.result_or_out(result, out)
507+
508+
503509
###### module-level queries of object properties
504510

505511

@@ -1100,10 +1106,6 @@ def array_equiv(a1, a2):
11001106
return result
11011107

11021108

1103-
def dot():
1104-
raise NotImplementedError
1105-
1106-
11071109
def common_type():
11081110
raise NotImplementedError
11091111

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 42 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -5464,15 +5464,30 @@ def test_vdot_uncontiguous_2(self):
54645464
np.vdot(a.flatten(), b.flatten()))
54655465

54665466

5467-
@pytest.mark.xfail(reason='TODO')
54685467
class TestDot:
54695468
def setup_method(self):
54705469
np.random.seed(128)
5471-
self.A = np.random.rand(4, 2)
5472-
self.b1 = np.random.rand(2, 1)
5473-
self.b2 = np.random.rand(2)
5474-
self.b3 = np.random.rand(1, 2)
5475-
self.b4 = np.random.rand(4)
5470+
5471+
# Numpy guarantees the random stream, and we don't. So inline the
5472+
# values from numpy 1.24.1
5473+
# self.A = np.random.rand(4, 2)
5474+
self.A = np.array([[0.86663704, 0.26314485],
5475+
[0.13140848, 0.04159344],
5476+
[0.23892433, 0.6454746 ],
5477+
[0.79059935, 0.60144244]])
5478+
5479+
# self.b1 = np.random.rand(2, 1)
5480+
self.b1 = np.array([[0.33429937], [0.11942846]])
5481+
5482+
# self.b2 = np.random.rand(2)
5483+
self.b2 = np.array([0.30913305, 0.10972379])
5484+
5485+
# self.b3 = np.random.rand(1, 2)
5486+
self.b3 = np.array([[0.60211331, 0.25128496]])
5487+
5488+
# self.b4 = np.random.rand(4)
5489+
self.b4 = np.array([0.29968129, 0.517116, 0.71520252, 0.9314494])
5490+
54765491
self.N = 7
54775492

54785493
def test_dotmatmat(self):
@@ -5541,16 +5556,26 @@ def test_dotcolumnvect2(self):
55415556

55425557
def test_dotvecscalar(self):
55435558
np.random.seed(100)
5544-
b1 = np.random.rand(1, 1)
5545-
b2 = np.random.rand(1, 4)
5559+
# Numpy guarantees the random stream, and we don't. So inline the
5560+
# values from numpy 1.24.1
5561+
# b1 = np.random.rand(1, 1)
5562+
b1 = np.array([[0.54340494]])
5563+
5564+
# b2 = np.random.rand(1, 4)
5565+
b2 = np.array([[0.27836939, 0.42451759, 0.84477613, 0.00471886]])
5566+
55465567
res = np.dot(b1, b2)
55475568
tgt = np.array([[0.15126730, 0.23068496, 0.45905553, 0.00256425]])
55485569
assert_almost_equal(res, tgt, decimal=self.N)
55495570

55505571
def test_dotvecscalar2(self):
55515572
np.random.seed(100)
5552-
b1 = np.random.rand(4, 1)
5553-
b2 = np.random.rand(1, 1)
5573+
# b1 = np.random.rand(4, 1)
5574+
b1 = np.array([[0.54340494], [0.27836939], [0.42451759], [0.84477613]])
5575+
5576+
# b2 = np.random.rand(1, 1)
5577+
b2 = np.array([[0.00471886]])
5578+
55545579
res = np.dot(b1, b2)
55555580
tgt = np.array([[0.00256425],[0.00131359],[0.00200324],[ 0.00398638]])
55565581
assert_almost_equal(res, tgt, decimal=self.N)
@@ -5566,39 +5591,7 @@ def test_all(self):
55665591
assert_(res.shape == tgt.shape)
55675592
assert_almost_equal(res, tgt, decimal=self.N)
55685593

5569-
def test_vecobject(self):
5570-
class Vec:
5571-
def __init__(self, sequence=None):
5572-
if sequence is None:
5573-
sequence = []
5574-
self.array = np.array(sequence)
5575-
5576-
def __add__(self, other):
5577-
out = Vec()
5578-
out.array = self.array + other.array
5579-
return out
5580-
5581-
def __sub__(self, other):
5582-
out = Vec()
5583-
out.array = self.array - other.array
5584-
return out
5585-
5586-
def __mul__(self, other): # with scalar
5587-
out = Vec(self.array.copy())
5588-
out.array *= other
5589-
return out
5590-
5591-
def __rmul__(self, other):
5592-
return self*other
5593-
5594-
U_non_cont = np.transpose([[1., 1.], [1., 2.]])
5595-
U_cont = np.ascontiguousarray(U_non_cont)
5596-
x = np.array([Vec([1., 0.]), Vec([0., 1.])])
5597-
zeros = np.array([Vec([0., 0.]), Vec([0., 0.])])
5598-
zeros_test = np.dot(U_cont, x) - np.dot(U_non_cont, x)
5599-
assert_equal(zeros[0].array, zeros_test[0].array)
5600-
assert_equal(zeros[1].array, zeros_test[1].array)
5601-
5594+
@pytest.mark.skip(reason='numpy internals')
56025595
def test_dot_2args(self):
56035596
from numpy.core.multiarray import dot
56045597

@@ -5609,6 +5602,7 @@ def test_dot_2args(self):
56095602
d = dot(a, b)
56105603
assert_allclose(c, d)
56115604

5605+
@pytest.mark.skip(reason='numpy internals')
56125606
def test_dot_3args(self):
56135607
from numpy.core.multiarray import dot
56145608

@@ -5631,6 +5625,7 @@ def test_dot_3args(self):
56315625
assert_(r is dot(f, v, r))
56325626
assert_array_equal(r2, r)
56335627

5628+
@pytest.mark.skip(reason='numpy internals')
56345629
def test_dot_3args_errors(self):
56355630
from numpy.core.multiarray import dot
56365631

@@ -5661,6 +5656,7 @@ def test_dot_3args_errors(self):
56615656
r = np.empty((1024, 32), dtype=int)
56625657
assert_raises(ValueError, dot, f, v, r)
56635658

5659+
@pytest.mark.skip(reason="TODO order='F'")
56645660
def test_dot_array_order(self):
56655661
a = np.array([[1, 2], [3, 4]], order='C')
56665662
b = np.array([[1, 2], [3, 4]], order='F')
@@ -5671,6 +5667,7 @@ def test_dot_array_order(self):
56715667
assert_equal(np.dot(b, a), res)
56725668
assert_equal(np.dot(b, b), res)
56735669

5670+
@pytest.mark.skip(reason='TODO: nbytes, view')
56745671
def test_accelerate_framework_sgemv_fix(self):
56755672

56765673
def aligned_array(shape, align, dtype, order='C'):
@@ -5745,26 +5742,14 @@ def test_huge_vectordot(self, dtype):
57455742
res = np.dot(data, data)
57465743
assert res == 2**30+100
57475744

5748-
def test_dtype_discovery_fails(self):
5749-
# See gh-14247, error checking was missing for failed dtype discovery
5750-
class BadObject(object):
5751-
def __array__(self):
5752-
raise TypeError("just this tiny mint leaf")
5753-
5754-
with pytest.raises(TypeError):
5755-
np.dot(BadObject(), BadObject())
5756-
5757-
with pytest.raises(TypeError):
5758-
np.dot(3.0, BadObject())
5759-
57605745

57615746
class MatmulCommon:
57625747
"""Common tests for '@' operator and numpy.matmul.
57635748
57645749
"""
57655750
# Should work with these types. Will want to add
57665751
# "O" at some point
5767-
types = "?bhilqBHILQefdgFDGO"
5752+
types = "?bhilqBefdFD"
57685753

57695754
def test_exceptions(self):
57705755
dims = [
@@ -5975,7 +5960,7 @@ def test_matrix_matrix_values(self):
59755960
assert_equal(res, tgt12_21)
59765961

59775962

5978-
@pytest.mark.xfail(reason='TODO')
5963+
@pytest.mark.xfail(reason='TODO: matmul (ufunc wrapping goes south?)')
59795964
class TestMatmul(MatmulCommon):
59805965
matmul = np.matmul
59815966

0 commit comments

Comments
 (0)