Skip to content

Commit 00d78c4

Browse files
committed
ENH: fill_diagonal, trace, indices, vdot
1 parent 63e0787 commit 00d78c4

File tree

5 files changed

+104
-39
lines changed

5 files changed

+104
-39
lines changed

autogen/numpy_api_dump.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,6 @@ def choose(a, choices, out=None, mode="raise"):
172172
raise NotImplementedError
173173

174174

175-
def clip(a, a_min, a_max, out=None, **kwargs):
176-
raise NotImplementedError
177-
178-
179175
def common_type(*arrays):
180176
raise NotImplementedError
181177

@@ -244,10 +240,6 @@ def extract(condition, arr):
244240
raise NotImplementedError
245241

246242

247-
def fill_diagonal(a, val, wrap=False):
248-
raise NotImplementedError
249-
250-
251243
def find_common_type(array_types, scalar_types):
252244
raise NotImplementedError
253245

@@ -377,10 +369,6 @@ def in1d(ar1, ar2, assume_unique=False, invert=False):
377369
raise NotImplementedError
378370

379371

380-
def indices(dimensions, dtype=int, sparse=False):
381-
raise NotImplementedError
382-
383-
384372
def insert(arr, obj, values, axis=None):
385373
raise NotImplementedError
386374

@@ -766,6 +754,7 @@ def shares_memory(a, b, max_work=None):
766754
def show():
767755
raise NotImplementedError
768756

757+
769758
def sort_complex(a):
770759
raise NotImplementedError
771760

@@ -778,10 +767,6 @@ def tensordot(a, b, axes=2):
778767
raise NotImplementedError
779768

780769

781-
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
782-
raise NotImplementedError
783-
784-
785770
def trapz(y, x=None, dx=1.0, axis=-1):
786771
raise NotImplementedError
787772

@@ -818,9 +803,5 @@ def unwrap(p, discont=None, axis=-1, *, period=6.283185307179586):
818803
raise NotImplementedError
819804

820805

821-
def vdot(a, b, /):
822-
raise NotImplementedError
823-
824-
825806
def who(vardict=None):
826807
raise NotImplementedError

torch_np/_detail/implementations.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def triu_indices(n, k=0, m=None):
129129

130130
def diag_indices(n, ndim=2):
131131
idx = torch.arange(n)
132-
return (idx,)*ndim
132+
return (idx,) * ndim
133133

134134

135135
def diag_indices_from(tensor):
@@ -143,6 +143,34 @@ def diag_indices_from(tensor):
143143
return diag_indices(s[0], tensor.ndim)
144144

145145

146+
def fill_diagonal(tensor, t_val, wrap):
147+
# torch.Tensor.fill_diagonal_ only accepts scalars. Thus vendor the numpy source,
148+
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/index_tricks.py#L786-L917
149+
150+
if tensor.ndim < 2:
151+
raise ValueError("array must be at least 2-d")
152+
end = None
153+
if tensor.ndim == 2:
154+
# Explicit, fast formula for the common case. For 2-d arrays, we
155+
# accept rectangular ones.
156+
step = tensor.shape[1] + 1
157+
# This is needed to don't have tall matrix have the diagonal wrap.
158+
if not wrap:
159+
end = tensor.shape[1] * tensor.shape[1]
160+
else:
161+
# For more than d=2, the strided formula is only valid for arrays with
162+
# all dimensions equal, so we check first.
163+
s = tensor.shape
164+
if s[1:] != s[:-1]:
165+
raise ValueError("All dimensions of input must be of equal length")
166+
sz = torch.as_tensor(tensor.shape[:-1])
167+
step = 1 + (torch.cumprod(sz, 0)).sum()
168+
169+
# Write the value out into the diagonal.
170+
tensor.ravel()[:end:step] = t_val
171+
return tensor
172+
173+
146174
# ### splits ###
147175

148176

@@ -396,6 +424,26 @@ def meshgrid(*xi_tensors, copy=True, sparse=False, indexing="xy"):
396424
return output
397425

398426

427+
def indices(dimensions, dtype=int, sparse=False):
428+
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1691-L1791
429+
dimensions = tuple(dimensions)
430+
N = len(dimensions)
431+
shape = (1,) * N
432+
if sparse:
433+
res = tuple()
434+
else:
435+
res = torch.empty((N,) + dimensions, dtype=dtype)
436+
for i, dim in enumerate(dimensions):
437+
idx = torch.arange(dim, dtype=dtype).reshape(
438+
shape[:i] + (dim,) + shape[i + 1 :]
439+
)
440+
if sparse:
441+
res = res + (idx,)
442+
else:
443+
res[i] = idx
444+
return res
445+
446+
399447
def bincount(x_tensor, /, weights_tensor=None, minlength=0):
400448
int_dtype = _dtypes_impl.default_int_dtype
401449
(x_tensor,) = _util.cast_dont_broadcast((x_tensor,), int_dtype, casting="safe")
@@ -620,3 +668,16 @@ def where(condition, x, y):
620668
else:
621669
result = torch.where(condition, x, y)
622670
return result
671+
672+
673+
# ### dot and other linalg ###
674+
675+
def vdot(t_a, t_b, /):
676+
# torch only accepts 1D arrays, numpy ravels
677+
t_a, t_b = torch.atleast_1d(t_a, t_b)
678+
if t_a.ndim > 1:
679+
t_a = t_a.ravel()
680+
if t_b.ndim > 1:
681+
t_b = t_b.ravel()
682+
result = torch.vdot(t_a, t_b)
683+
return result

torch_np/_wrapper.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,19 @@ def diag_indices_from(arr):
394394
return tuple(asarray(x) for x in result)
395395

396396

397+
def fill_diagonal(a, val, wrap=False):
398+
tensor, t_val = _helpers.to_tensors(a, val)
399+
result = _impl.fill_diagonal(tensor, t_val, wrap)
400+
return asarray(result)
401+
402+
403+
@_decorators.dtype_to_torch
404+
def trace(a, offset=0, axis1=0, axis2=1, dtype=None, out=None):
405+
tensor = asarray(a).get()
406+
result = torch.diagonal(tensor, offset, dim1=axis1, dim2=axis2).sum(-1, dtype=dtype)
407+
return asarray(result)
408+
409+
397410
###### misc/unordered
398411

399412

@@ -483,6 +496,11 @@ def searchsorted(a, v, side="left", sorter=None):
483496
return arr.searchsorted(v, side=side, sorter=sorter)
484497

485498

499+
def vdot(a, b, /):
500+
t_a, t_b = _helpers.to_tensors(a, b)
501+
result = _impl.vdot(t_a, t_b)
502+
return result.item()
503+
486504
###### module-level queries of object properties
487505

488506

@@ -615,6 +633,15 @@ def meshgrid(*xi, copy=True, sparse=False, indexing="xy"):
615633
return [asarray(t) for t in output]
616634

617635

636+
@_decorators.dtype_to_torch
637+
def indices(dimensions, dtype=int, sparse=False):
638+
result = _impl.indices(dimensions, dtype=dtype, sparse=sparse)
639+
if sparse:
640+
return tuple(asarray(x) for x in result)
641+
else:
642+
return asarray(result)
643+
644+
618645
def nonzero(a):
619646
arr = asarray(a)
620647
return arr.nonzero()

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,6 @@ def test_take(self):
291291
out = np.take(a, indices)
292292
assert_equal(out, tgt)
293293

294-
@pytest.mark.xfail(reason="TODO implement trace(...)")
295294
def test_trace(self):
296295
c = [[1, 2], [3, 4], [5, 6]]
297296
assert_equal(np.trace(c), 5)
@@ -2627,7 +2626,7 @@ def test_exceptions(self):
26272626
assert_raises(np.AxisError, np.rollaxis, a, 4, 0)
26282627
assert_raises(np.AxisError, np.rollaxis, a, 0, 5)
26292628

2630-
@pytest.mark.xfail(reason="needs np.indices")
2629+
@pytest.mark.xfail(reason="needs fancy indexing")
26312630
def test_results(self):
26322631
a = np.arange(1*2*3*4).reshape(1, 2, 3, 4).copy()
26332632
aind = np.indices(a.shape)
@@ -2833,7 +2832,6 @@ def test_outer_out_param():
28332832
assert_equal(np.outer(arr2, arr3, out2), out2)
28342833

28352834

2836-
@pytest.mark.xfail(reason="TODO")
28372835
class TestIndices:
28382836

28392837
def test_simple(self):

torch_np/tests/numpy_tests/lib/test_index_tricks.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,12 @@
88
from pytest import raises as assert_raises #, assert_raises_regex,
99

1010
from numpy.lib.index_tricks import (
11-
mgrid, ogrid, ndenumerate, fill_diagonal,
12-
ndindex, r_, s_, ix_
11+
mgrid, ogrid, ndenumerate,
12+
ndindex, r_, ix_
1313
)
1414

15-
from torch_np import diag_indices, diag_indices_from
16-
from torch_np._detail._index_tricks import index_exp
15+
from torch_np import diag_indices, diag_indices_from, fill_diagonal
16+
from torch_np._detail._index_tricks import index_exp, s_
1717

1818

1919
@pytest.mark.xfail(reason='unravel_index not implemented')
@@ -358,7 +358,6 @@ def test_basic(self):
358358
[((0, 0), 1), ((0, 1), 2), ((1, 0), 3), ((1, 1), 4)])
359359

360360

361-
@pytest.mark.xfail(reason='s_ not implemented')
362361
class TestIndexExpression:
363362
def test_regression_1(self):
364363
# ticket #1196
@@ -422,10 +421,9 @@ def test_c_():
422421
assert_equal(a, [[1, 2, 3, 0, 0, 4, 5, 6]])
423422

424423

425-
@pytest.mark.xfail(reason='fill_diagonal not implemented')
426424
class TestFillDiagonal:
427425
def test_basic(self):
428-
a = np.zeros((3, 3), int)
426+
a = np.zeros((3, 3), dtype=int)
429427
fill_diagonal(a, 5)
430428
assert_array_equal(
431429
a, np.array([[5, 0, 0],
@@ -434,7 +432,7 @@ def test_basic(self):
434432
)
435433

436434
def test_tall_matrix(self):
437-
a = np.zeros((10, 3), int)
435+
a = np.zeros((10, 3), dtype=int)
438436
fill_diagonal(a, 5)
439437
assert_array_equal(
440438
a, np.array([[5, 0, 0],
@@ -450,7 +448,7 @@ def test_tall_matrix(self):
450448
)
451449

452450
def test_tall_matrix_wrap(self):
453-
a = np.zeros((10, 3), int)
451+
a = np.zeros((10, 3), dtype=int)
454452
fill_diagonal(a, 5, True)
455453
assert_array_equal(
456454
a, np.array([[5, 0, 0],
@@ -466,7 +464,7 @@ def test_tall_matrix_wrap(self):
466464
)
467465

468466
def test_wide_matrix(self):
469-
a = np.zeros((3, 10), int)
467+
a = np.zeros((3, 10), dtype=int)
470468
fill_diagonal(a, 5)
471469
assert_array_equal(
472470
a, np.array([[5, 0, 0, 0, 0, 0, 0, 0, 0, 0],
@@ -475,22 +473,22 @@ def test_wide_matrix(self):
475473
)
476474

477475
def test_operate_4d_array(self):
478-
a = np.zeros((3, 3, 3, 3), int)
476+
a = np.zeros((3, 3, 3, 3), dtype=int)
479477
fill_diagonal(a, 4)
480478
i = np.array([0, 1, 2])
481479
assert_equal(np.where(a != 0), (i, i, i, i))
482480

483481
def test_low_dim_handling(self):
484482
# raise error with low dimensionality
485-
a = np.zeros(3, int)
486-
with assert_raises_regex(ValueError, "at least 2-d"):
483+
a = np.zeros(3, dtype=int)
484+
with assert_raises(ValueError):
487485
fill_diagonal(a, 5)
488486

489487
def test_hetero_shape_handling(self):
490488
# raise error with high dimensionality and
491489
# shape mismatch
492-
a = np.zeros((3,3,7,3), int)
493-
with assert_raises_regex(ValueError, "equal length"):
490+
a = np.zeros((3,3,7,3), dtype=int)
491+
with assert_raises(ValueError):
494492
fill_diagonal(a, 2)
495493

496494

0 commit comments

Comments
 (0)