Skip to content

Commit ed0f10b

Browse files
committed
ENH: add tensordot
1 parent c7d3b53 commit ed0f10b

File tree

3 files changed

+18
-23
lines changed

3 files changed

+18
-23
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,6 @@ def binary_repr(num, width=None):
134134
raise NotImplementedError
135135

136136

137-
def blackman(M):
138-
raise NotImplementedError
139-
140-
141137
def block(arrays):
142138
raise NotImplementedError
143139

@@ -337,14 +333,6 @@ def gradient(f, *varargs, axis=None, edge_order=1):
337333
raise NotImplementedError
338334

339335

340-
def hamming(M):
341-
raise NotImplementedError
342-
343-
344-
def hanning(M):
345-
raise NotImplementedError
346-
347-
348336
def histogram(a, bins=10, range=None, normed=None, weights=None, density=None):
349337
raise NotImplementedError
350338

@@ -409,10 +397,6 @@ def ix_(*args):
409397
raise NotImplementedError
410398

411399

412-
def kaiser(M, beta):
413-
raise NotImplementedError
414-
415-
416400
def lexsort(keys, axis=-1):
417401
raise NotImplementedError
418402

@@ -759,10 +743,6 @@ def take(a, indices, axis=None, out=None, mode="raise"):
759743
raise NotImplementedError
760744

761745

762-
def tensordot(a, b, axes=2):
763-
raise NotImplementedError
764-
765-
766746
def trapz(y, x=None, dx=1.0, axis=-1):
767747
raise NotImplementedError
768748

torch_np/_funcs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1223,6 +1223,15 @@ def vdot(a: ArrayLike, b: ArrayLike, /):
12231223
return result.item()
12241224

12251225

1226+
@normalizer
1227+
def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
1228+
if isinstance(axes, (list, tuple)):
1229+
axes = [[ax] if isinstance(ax, int) else ax for ax in axes]
1230+
result = torch.tensordot(a, b, dims=axes)
1231+
1232+
return result
1233+
1234+
12261235
@normalizer
12271236
def dot(a: ArrayLike, b: ArrayLike, out: Optional[NDArray] = None) -> OutArray:
12281237
dtype = _dtypes_impl.result_type_impl((a.dtype, b.dtype))

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2988,15 +2988,21 @@ def test_shape_mismatch_error_message(self):
29882988
np.broadcast([[1, 2, 3]], [[4], [5]], [6, 7])
29892989

29902990

2991-
@pytest.mark.xfail(reason="TODO")
29922991
class TestTensordot:
29932992

29942993
def test_zero_dimension(self):
29952994
# Test resolution to issue #5663
2996-
a = np.ndarray((3,0))
2997-
b = np.ndarray((0,4))
2995+
a = np.zeros((3,0))
2996+
b = np.zeros((0,4))
29982997
td = np.tensordot(a, b, (1, 0))
29992998
assert_array_equal(td, np.dot(a, b))
2999+
3000+
@pytest.mark.xfail(reason="no einsum")
3001+
def test_zero_dimension_einsum(self):
3002+
# Test resolution to issue #5663
3003+
a = np.zeros((3,0))
3004+
b = np.zeros((0,4))
3005+
td = np.tensordot(a, b, (1, 0))
30003006
assert_array_equal(td, np.einsum('ij,jk', a, b))
30013007

30023008
def test_zero_dimensional(self):

0 commit comments

Comments
 (0)