Skip to content

Commit a781856

Browse files
committed
ENH: add where, searchsorted, inner and outer, diag_indices{_from}
1 parent c3095ab commit a781856

File tree

5 files changed

+86
-72
lines changed

5 files changed

+86
-72
lines changed

autogen/numpy_api_dump.py

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -216,22 +216,6 @@ def deprecate_with_doc(msg):
216216
raise NotImplementedError
217217

218218

219-
def diag_indices(n, ndim=2):
220-
raise NotImplementedError
221-
222-
223-
def diag_indices_from(arr):
224-
raise NotImplementedError
225-
226-
227-
def diagflat(v, k=0):
228-
raise NotImplementedError
229-
230-
231-
def diagonal(a, offset=0, axis1=0, axis2=1):
232-
raise NotImplementedError
233-
234-
235219
def digitize(x, bins, right=False):
236220
raise NotImplementedError
237221

@@ -397,10 +381,6 @@ def indices(dimensions, dtype=int, sparse=False):
397381
raise NotImplementedError
398382

399383

400-
def inner(a, b, /):
401-
raise NotImplementedError
402-
403-
404384
def insert(arr, obj, values, axis=None):
405385
raise NotImplementedError
406386

@@ -417,10 +397,6 @@ def is_busday(dates, weekmask="1111100", holidays=None, busdaycal=None, out=None
417397
raise NotImplementedError
418398

419399

420-
def isclose(a, b, rtol=1e-05, atol=1e-08, equal_nan=False):
421-
raise NotImplementedError
422-
423-
424400
def isfortran(a):
425401
raise NotImplementedError
426402

@@ -437,14 +413,6 @@ def issctype(rep):
437413
raise NotImplementedError
438414

439415

440-
def issubclass_(arg1, arg2):
441-
raise NotImplementedError
442-
443-
444-
def issubdtype(arg1, arg2):
445-
raise NotImplementedError
446-
447-
448416
def issubsctype(arg1, arg2):
449417
raise NotImplementedError
450418

@@ -624,10 +592,6 @@ def obj2sctype(rep, default=None):
624592
raise NotImplementedError
625593

626594

627-
def outer(a, b, out=None):
628-
raise NotImplementedError
629-
630-
631595
def packbits(a, /, axis=None, bitorder="big"):
632596
raise NotImplementedError
633597

@@ -696,10 +660,6 @@ def putmask(a, mask, values):
696660
raise NotImplementedError
697661

698662

699-
def ravel(a, order="C"):
700-
raise NotImplementedError
701-
702-
703663
def recfromcsv(fname, **kwargs):
704664
raise NotImplementedError
705665

@@ -754,10 +714,6 @@ def sctype2char(sctype):
754714
raise NotImplementedError
755715

756716

757-
def searchsorted(a, v, side="left", sorter=None):
758-
raise NotImplementedError
759-
760-
761717
def select(condlist, choicelist, default=0):
762718
raise NotImplementedError
763719

@@ -810,23 +766,10 @@ def shares_memory(a, b, max_work=None):
810766
def show():
811767
raise NotImplementedError
812768

813-
814-
def sinc(x):
815-
raise NotImplementedError
816-
817-
818-
def sometrue(*args, **kwargs):
819-
raise NotImplementedError
820-
821-
822769
def sort_complex(a):
823770
raise NotImplementedError
824771

825772

826-
def swapaxes(a, axis1, axis2):
827-
raise NotImplementedError
828-
829-
830773
def take(a, indices, axis=None, out=None, mode="raise"):
831774
raise NotImplementedError
832775

@@ -879,9 +822,5 @@ def vdot(a, b, /):
879822
raise NotImplementedError
880823

881824

882-
def where(condition, x, y, /):
883-
raise NotImplementedError
884-
885-
886825
def who(vardict=None):
887826
raise NotImplementedError

torch_np/_detail/implementations.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,22 @@ def triu_indices(n, k=0, m=None):
127127
return result
128128

129129

130+
def diag_indices(n, ndim=2):
131+
idx = torch.arange(n)
132+
return (idx,)*ndim
133+
134+
135+
def diag_indices_from(tensor):
136+
if not tensor.ndim >= 2:
137+
raise ValueError("input array must be at least 2-d")
138+
# For more than d=2, the strided formula is only valid for arrays with
139+
# all dimensions equal, so we check first.
140+
s = tensor.shape
141+
if any(s[1:] != s[:-1]):
142+
raise ValueError("All dimensions of input must be of equal length")
143+
return diag_indices(s[0], tensor.ndim)
144+
145+
130146
# ### splits ###
131147

132148

@@ -589,3 +605,18 @@ def argsort(tensor, axis=-1, kind=None, order=None):
589605
tensor, axis, stable = _sort_helper(tensor, axis, kind, order)
590606
result = torch.argsort(tensor, dim=axis, stable=stable)
591607
return result
608+
609+
610+
# ### logic and selection ###
611+
612+
613+
def where(condition, x, y):
614+
selector = (x is None) == (y is None)
615+
if not selector:
616+
raise ValueError("either both or neither of x and y should be given")
617+
618+
if x is None and y is None:
619+
result = torch.where(condition)
620+
else:
621+
result = torch.where(condition, x, y)
622+
return result

torch_np/_ndarray.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,10 @@ def repeat(self, repeats, axis=None):
395395
axis_none_ravel_wrapper(dtype_to_torch(_reductions.cumsum))
396396
)
397397

398+
def diagonal(self, offset=0, axis1=0, axis2=1):
399+
result = torch.diagonal(self._tensor, offset, axis1, axis2)
400+
return asarray(result)
401+
398402
### indexing ###
399403
@staticmethod
400404
def _upcast_int_indices(index):
@@ -427,6 +431,11 @@ def argsort(self, axis=-1, kind=None, order=None):
427431
result = _impl.argsort(self.tensor, axis, kind, order)
428432
return asarray(result)
429433

434+
def searchsorted(self, v, side="left", sorter=None):
435+
v_t, sorter_t = _helpers.to_tensors_or_none(v, sorter)
436+
result = torch.searchsorted(self._tensor, v_t, side=side, sorter=sorter_t)
437+
return asarray(result)
438+
430439

431440
# This is the ideally the only place which talks to ndarray directly.
432441
# The rest goes through asarray (preferred) or array.

torch_np/_wrapper.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,28 @@ def diag(v, k=0):
372372
return asarray(result)
373373

374374

375+
def diagonal(a, offset=0, axis1=0, axis2=1):
376+
arr = asarray(a)
377+
return arr.diagonal(offset, axis1, axis2)
378+
379+
380+
def diagflat(v, k=0):
381+
tensor = asarray(v).get()
382+
result = torch.diagflat(tensor, k)
383+
return result
384+
385+
386+
def diag_indices(n, ndim=2):
387+
result = _impl.diag_indices(n, ndim)
388+
return tuple(asarray(x) for x in result)
389+
390+
391+
def diag_indices_from(arr):
392+
tensor = asarray(arr).get()
393+
result = _impl.diag_indices_from(tensor)
394+
return tuple(asarray(x) for x in result)
395+
396+
375397
###### misc/unordered
376398

377399

@@ -446,17 +468,19 @@ def bincount(x, /, weights=None, minlength=0):
446468
return asarray(result)
447469

448470

449-
# YYY: pattern: sequence of arrays
450471
def where(condition, x=None, y=None, /):
451-
selector = (x is None) == (y is None)
452-
if not selector:
453-
raise ValueError("either both or neither of x and y should be given")
454-
condition = asarray(condition).get()
455-
if x is None and y is None:
456-
return tuple(asarray(_) for _ in torch.where(condition))
457-
x = asarray(condition).get()
458-
y = asarray(condition).get()
459-
return asarray(torch.where(condition, x, y))
472+
cond_t, x_t, y_t = _helpers.to_tensors_or_none(condition, x, y)
473+
result = _impl.where(cond_t, x_t, y_t)
474+
if isinstance(result, tuple):
475+
# single-argument where(condition)
476+
return tuple(asarray(x) for x in result)
477+
else:
478+
return asarray(result)
479+
480+
481+
def searchsorted(a, v, side="left", sorter=None):
482+
arr = asarray(a)
483+
return arr.searchsorted(v, side=side, sorter=sorter)
460484

461485

462486
###### module-level queries of object properties
@@ -847,6 +871,18 @@ def median(a, axis=None, out=None, overwrite_input=False, keepdims=False):
847871
)
848872

849873

874+
def inner(a, b, /):
875+
a_t, b_t = _helpers.to_tensors(a, b)
876+
result = torch.inner(a_t, b_t)
877+
return asarray(result)
878+
879+
880+
def outer(a, b, out=None):
881+
a_t, b_t = _helpers.to_tensors(a, b)
882+
result = torch.outer(a_t, b_t)
883+
return _helpers.result_or_out(result, out)
884+
885+
850886
@asarray_replacer()
851887
def nanmean(a, axis=None, dtype=None, out=None, keepdims=NoValue, *, where=NoValue):
852888
if where is not NoValue:

torch_np/tests/numpy_tests/core/test_numeric.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def test_cumproduct(self):
125125
A = [[1, 2, 3], [4, 5, 6]]
126126
assert_(np.all(np.cumproduct(A) == np.array([1, 2, 6, 24, 120, 720])))
127127

128-
@pytest.mark.xfail(reason="TODO implement diagonal(...)")
129128
def test_diagonal(self):
130129
a = [[0, 1, 2, 3],
131130
[4, 5, 6, 7],

0 commit comments

Comments
 (0)