ENH: add a naive divmod, un-xfail relevant tests #84
Conversation
Two more ufuncs with two optional out= inputs and two OutArray outputs: ldexp and frexp. Need a dedicated result annotation, apparently.
35e647b to bc9bd0c
```python
quot = floor_divide(x1, x2, out=out1, **kwds)
rem = remainder(x1, x2, out=out2, **kwds)
```
Note that here (and in general) if we want these functions to be differentiable in PyTorch, we should not use their out= variant. We should implement the out= behaviour manually.
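To illustrate the point, here is a minimal sketch of what handling out= manually could look like (hypothetical `divmod_with_out` helper; NumPy arrays stand in for wrapped tensors for illustration): compute with the plain, differentiable operations first, then copy into any caller-provided buffers.

```python
import numpy as np

def divmod_with_out(x1, x2, out1=None, out2=None):
    # Compute without out= so that, in the real PyTorch case, autograd
    # only sees ordinary differentiable ops.
    quot = x1 // x2
    rem = x1 % x2
    # Implement the out= behaviour manually: copy into the buffers.
    if out1 is not None:
        out1[...] = quot
        quot = out1
    if out2 is not None:
        out2[...] = rem
        rem = out2
    return quot, rem
```

This keeps the out= semantics (the returned objects are the caller's buffers) without routing out= through the underlying ops.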
Here these are wrapper floor_divide and remainder, so torch.remainder and torch.floor_divide never see the out=.
Also NB: this will need a rework anyway; apparently there are more (out1, out2) ufuncs, ldexp and frexp. So the Tuple[NDArray] return annotation will need to appear after we settle on the generic machinery.
Right! And sure, let's just push this discussion then. I still think that the out= machinery can be implemented in a generic way.
```python
# postprocess return values

def postprocess_ndarray(result, **kwds):
    return _helpers.array_from(result)


def postprocess_out(result, **kwds):
    result, out = result
    return _helpers.result_or_out(result, out, **kwds)


def postprocess_tuple(result, **kwds):
    return _helpers.tuple_arrays_from(result)


def postprocess_list(result, **kwds):
    return list(_helpers.tuple_arrays_from(result))


def postprocess_variadic(result, **kwds):
    # a variadic return: a single NDArray or tuple/list of NDArrays, e.g. atleast_1d
    if isinstance(result, (tuple, list)):
        seq = type(result)
        return seq(_helpers.tuple_arrays_from(result))
    else:
        return _helpers.array_from(result)


postprocessors = {
    NDArray: postprocess_ndarray,
    OutArray: postprocess_out,
    NDArrayOrSequence: postprocess_variadic,
    tuple[NDArray]: postprocess_tuple,
    list[NDArray]: postprocess_list,
}
```
This is duplicated from the other PR.
```python
postprocess = postprocessors.get(r, None)
if postprocess:
    kwds = {"promote_scalar": promote_scalar_out}
    result = postprocess(result, **kwds)
return result
```
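The annotation-keyed dispatch above can be sketched in a self-contained way (simplified stand-in annotations and postprocessors here; the real code keys on NDArray, OutArray, and friends):

```python
# Toy postprocessors: tag results so we can see which one ran.
def postprocess_single(result, **kwds):
    return ("array", result)

def postprocess_pair(result, **kwds):
    return tuple(("array", r) for r in result)

# Map a return annotation to the postprocessor that handles it.
postprocessors = {int: postprocess_single, tuple: postprocess_pair}

def wrap_result(result, annotation, **kwds):
    # Look up the postprocessor for the annotated return type;
    # unknown annotations pass the result through unchanged.
    postprocess = postprocessors.get(annotation, None)
    if postprocess:
        result = postprocess(result, **kwds)
    return result
```

Results whose annotation has no registered postprocessor fall through untouched, which is what makes the mechanism easy to extend with new (out1, out2)-style return shapes.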
Also duplicated?
Let's figure out what's going on with that duplicated return code; otherwise the implementations LGTM. Note that divmod can also be implemented via a division and a subtraction, but I'm happy with this implementation as well, as it's simpler.
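For reference, the division-plus-subtraction variant mentioned here would amount to something like the following sketch (plain Python for illustration; it relies on the identity x1 % x2 == x1 - (x1 // x2) * x2, which holds under floor-division semantics):

```python
def divmod_via_subtraction(x1, x2):
    # Derive the remainder from the quotient instead of computing
    # a second modulo: r = x1 - (x1 // x2) * x2.
    quot = x1 // x2
    rem = x1 - quot * x2
    return quot, rem
```

This trades the second elementwise op (remainder) for a multiply and a subtract; as noted, the (x1 // x2, x1 % x2) version is simpler.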
I meant
That's just an artefact from giving optional return arguments to
Merged in #91
Implement np.divmod(x1, x2) and ndarray.__divmod__. These are unusual in that they allow two output arrays, so it's either complicate and obfuscate the ufunc machinery just for these two, or special-case them. This PR opts for the latter :-).

Since there is no pytorch-native divmod (cf. pytorch/pytorch#90820), this PR just does (x1 // x2, x1 % x2) at the wrapper level.

This PR is on top of gh-83, for clarity.