ENH: add a naive divmod, un-xfail relevant tests #84

Closed
Changes from all commits · 2 commits

@@ -1,12 +1,14 @@
 """ "Normalize" arguments: convert array_likes to tensors, dtypes to torch dtypes and so on.
 """
+import functools
+import inspect
 import operator
 import typing
 from typing import Optional, Sequence, Union
 
 import torch
 
-from . import _helpers
+from . import _dtypes, _helpers
 
 ArrayLike = typing.TypeVar("ArrayLike")
 DTypeLike = typing.TypeVar("DTypeLike")
@@ -22,10 +24,6 @@
 NDArrayOrSequence = Union[NDArray, Sequence[NDArray]]
 OutArray = typing.TypeVar("OutArray")
 
-import inspect
-
-from . import _dtypes
-
 
 def normalize_array_like(x, name=None):
     (tensor,) = _helpers.to_tensors(x)
@@ -52,7 +50,7 @@ def normalize_dtype(dtype, name=None):
     return torch_dtype
 
 
-def normalize_subok_like(arg, name):
+def normalize_subok_like(arg, name="subok"):
     if arg:
         raise ValueError(f"'{name}' parameter is not supported.")
 
@@ -87,7 +85,6 @@ def normalize_ndarray(arg, name=None):
     AxisLike: normalize_axis_like,
 }
 
-import functools
 
 _sentinel = object()
 
@@ -97,7 +94,7 @@ def normalize_this(arg, parm, return_on_failure=_sentinel):
     normalizer = normalizers.get(parm.annotation, None)
     if normalizer:
         try:
-            return normalizer(arg)
+            return normalizer(arg, parm.name)
         except Exception as exc:
             if return_on_failure is not _sentinel:
                 return return_on_failure
@@ -108,6 +105,44 @@ def normalize_this(arg, parm, return_on_failure=_sentinel):
     return arg
 
 
+# postprocess return values
+
+
+def postprocess_ndarray(result, **kwds):
+    return _helpers.array_from(result)
+
+
+def postprocess_out(result, **kwds):
+    result, out = result
+    return _helpers.result_or_out(result, out, **kwds)
+
+
+def postprocess_tuple(result, **kwds):
+    return _helpers.tuple_arrays_from(result)
+
+
+def postprocess_list(result, **kwds):
+    return list(_helpers.tuple_arrays_from(result))
+
+
+def postprocess_variadic(result, **kwds):
+    # a variadic return: a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
+    if isinstance(result, (tuple, list)):
+        seq = type(result)
+        return seq(_helpers.tuple_arrays_from(result))
+    else:
+        return _helpers.array_from(result)
+
+
+postprocessors = {
+    NDArray: postprocess_ndarray,
+    OutArray: postprocess_out,
+    NDArrayOrSequence: postprocess_variadic,
+    tuple[NDArray]: postprocess_tuple,
+    list[NDArray]: postprocess_list,
+}

Review comment on lines +108 to +145: This is duplicated from the other PR.
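
For concreteness, the postprocessors table keys the postprocessing step on the wrapped function's return annotation. A minimal, self-contained sketch of the same dispatch pattern (illustrative names only, not the module's actual helpers):

```python
import typing

NDArray = typing.TypeVar("NDArray")

def postprocess_ndarray(result, **kwds):
    # stand-in for _helpers.array_from: wrap a raw tensor into an ndarray
    return f"wrapped({result})"

# annotation objects are hashable, so they can key a dict directly;
# tuple[NDArray] works too, since types.GenericAlias is hashable
postprocessors = {
    NDArray: postprocess_ndarray,
    tuple[NDArray]: lambda result, **kwds: tuple(map(postprocess_ndarray, result)),
}

def finalize(result, return_annotation, **kwds):
    postprocess = postprocessors.get(return_annotation, None)
    if postprocess:
        return postprocess(result, **kwds)
    return result  # unannotated returns pass through unchanged

print(finalize("t1", NDArray))                 # wrapped(t1)
print(finalize(("t1", "t2"), tuple[NDArray]))  # ('wrapped(t1)', 'wrapped(t2)')
print(finalize("t1", None))                    # t1
```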

 def normalizer(_func=None, *, return_on_failure=_sentinel, promote_scalar_out=False):
     def normalizer_inner(func):
         @functools.wraps(func)
@@ -154,33 +189,17 @@ def wrapped(*args, **kwds):
             raise TypeError(
                 f"{func.__name__}() takes {len(ba.args)} positional argument but {len(args)} were given."
             )
 
             # finally, pass normalized arguments through
             result = func(*ba.args, **ba.kwargs)
 
             # handle returns
             r = sig.return_annotation
-            if r == NDArray:
-                return _helpers.array_from(result)
-            elif r == inspect._empty:
-                return result
-            elif hasattr(r, "__origin__") and r.__origin__ in (list, tuple):
-                # this is tuple[NDArray] or list[NDArray]
-                # XXX: change to separate tuple and list normalizers?
-                return r.__origin__(_helpers.tuple_arrays_from(result))
-            elif r == NDArrayOrSequence:
-                # a variadic return: a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
-                if isinstance(result, (tuple, list)):
-                    seq = type(result)
-                    return seq(_helpers.tuple_arrays_from(result))
-                else:
-                    return _helpers.array_from(result)
-            elif r == OutArray:
-                result, out = result
-                return _helpers.result_or_out(
-                    result, out, promote_scalar=promote_scalar_out
-                )
-            else:
-                raise ValueError(f"Unknown return annotation {return_annotation}")
+            postprocess = postprocessors.get(r, None)
+            if postprocess:
+                kwds = {"promote_scalar": promote_scalar_out}
+                result = postprocess(result, **kwds)
+            return result

Review comment on lines +198 to +202: Also dupped?

 
         return wrapped
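
As a usage illustration, a two-output ufunc written against this machinery would look roughly like the sketch below. This is hypothetical (the PR's actual divmod wrapper lives in a file not shown in this excerpt), and it assumes the module's own names (normalizer, ArrayLike, NDArray) are in scope:

```python
import torch

@normalizer
def divmod(x1: ArrayLike, x2: ArrayLike) -> tuple[NDArray]:
    # naive divmod: compute quotient and remainder separately;
    # the ArrayLike annotations normalize the inputs to tensors, and the
    # tuple[NDArray] return annotation selects postprocess_tuple
    return torch.floor_divide(x1, x2), torch.remainder(x1, x2)
```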
Review comment: Note that here (and in general), if we want these functions to be differentiable in PyTorch, we should not use their out= variant. We should implement the out= behaviour manually.
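
For concreteness, a minimal sketch of the failure mode and the manual alternative (illustrative only; torch.add stands in for any function with an out= variant):

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
out = torch.empty(2)

# The native out= variant refuses graph-tracked inputs; this raises
# "add(): functions with out=... arguments don't support automatic
# differentiation, but one of the arguments requires grad":
#
#   torch.add(x, 1.0, out=out)

# Implementing the out= behaviour manually keeps the graph intact:
result = torch.add(x, 1.0)  # fresh, differentiable tensor
out.copy_(result)           # then copy into the user-supplied buffer
out.sum().backward()        # gradients still flow back to x
print(x.grad)               # tensor([1., 1.])
```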

Reply: Here these are the wrappers floor_divide and remainder, so torch.remainder and torch.floor_divide never see the out=.

Also NB: this will need a rework anyway; apparently there are more (out1, out2) ufuncs, ldexp and frexp. So the Tuple[NDArray] return annotation will need to appear after we settle on the generic machinery.
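
For reference, torch.frexp already returns such a pair; its shape is what a Tuple[NDArray]-style annotation would have to cover:

```python
import torch

# frexp decomposes x into mantissa * 2**exponent -- two outputs per input
mantissa, exponent = torch.frexp(torch.tensor([8.0, 10.0]))
print(mantissa)   # tensor([0.5000, 0.6250])
print(exponent)   # tensor([4, 4], dtype=torch.int32)
```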

Reply: Right! And sure, let's just push this discussion then. I still think that the out= machinery can be implemented in a generic way.
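
One possible shape for that generic machinery, as a hedged sketch (apply_out is a hypothetical helper, not code from this PR):

```python
import torch

def apply_out(result, out):
    # hypothetical helper: implement out= behaviour manually so the
    # underlying torch call stays differentiable
    if out is None:
        return result
    if isinstance(result, tuple):
        # multi-output ufuncs (e.g. frexp): copy each output element-wise
        return tuple(o.copy_(r) for o, r in zip(out, result))
    return out.copy_(result)  # single-output case
```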