Skip to content

Commit 2fe8a9a

Browse files
committed
move array_from/tuple_arrays_from to normalizations.py
1 parent 6d40249 commit 2fe8a9a

File tree

3 files changed

+16
-17
lines changed

3 files changed

+16
-17
lines changed

torch_np/_helpers.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -81,19 +81,11 @@ def result_or_out(result_tensor, out_array=None, promote_scalar=False):
8181
out_tensor.copy_(result_tensor)
8282
return out_array
8383
else:
84-
return array_from(result_tensor)
84+
from ._ndarray import ndarray
8585

86+
return ndarray(result_tensor)
8687

87-
def array_from(tensor, base=None):
88-
from ._ndarray import ndarray
8988

90-
return ndarray(tensor)
91-
92-
93-
def tuple_arrays_from(result):
94-
from ._ndarray import asarray
95-
96-
return tuple(asarray(x) for x in result)
9789

9890

9991
# ### Various ways of converting array-likes to tensors ###

torch_np/_normalizations.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,15 @@ def normalize_this(arg, parm, return_on_failure=_sentinel):
107107

108108
# postprocess return values
109109

110+
def _array_from(tensor):
111+
# hide the circular import
112+
from ._ndarray import ndarray
113+
114+
return ndarray(tensor)
115+
110116

111117
def postprocess_ndarray(result, **kwds):
112-
return _helpers.array_from(result)
118+
return _array_from(result)
113119

114120

115121
def postprocess_out(result, **kwds):
@@ -118,20 +124,20 @@ def postprocess_out(result, **kwds):
118124

119125

120126
def postprocess_tuple(result, **kwds):
121-
return _helpers.tuple_arrays_from(result)
127+
return tuple(_array_from(x) for x in result)
122128

123129

124130
def postprocess_list(result, **kwds):
125-
return list(_helpers.tuple_arrays_from(result))
131+
return list(_array_from(x) for x in result)
126132

127133

128134
def postprocess_variadic(result, **kwds):
129135
# a variadic return: a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
130136
if isinstance(result, (tuple, list)):
131137
seq = type(result)
132-
return seq(_helpers.tuple_arrays_from(result))
138+
return seq(_array_from(x) for x in result)
133139
else:
134-
return _helpers.array_from(result)
140+
return _array_from(result)
135141

136142

137143
postprocessors = {

torch_np/random.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010

1111
import torch
1212

13-
from . import _helpers
13+
from ._ndarray import ndarray
14+
1415
from ._detail import _dtypes_impl, _util
1516
from ._normalizations import ArrayLike, NDArray, normalizer
1617

@@ -35,7 +36,7 @@ def array_or_scalar(values, py_type=float, return_scalar=False):
3536
if return_scalar:
3637
return py_type(values.item())
3738
else:
38-
return _helpers.array_from(values)
39+
return ndarray(values)
3940

4041

4142
def seed(seed=None):

0 commit comments

Comments
 (0)