-
Notifications
You must be signed in to change notification settings - Fork 137
Conversation
@dan-zheng I have change precondition of |
@@ -2229,7 +2248,9 @@ public extension Tensor where Scalar: Numeric { | |||
exclusive: Bool = false, | |||
reverse: Bool = false | |||
) -> Tensor { | |||
_Raw.cumsum(self, axis: axis, exclusive: exclusive, reverse: reverse) | |||
precondition(axis.rank == 0, "Axis must have rank 0.") | |||
precondition(areAxesInRange(axis), "All axis must be in the range `[-rank, rank)`.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In cumprod
and cumsum
, if the axis is plural then I think the documentation must read, All axes must be in range ...,
Like the previous conditions.
If it's a singular value, is calling the function necessary instead of just doing the check locally ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Previously it was given axis instead of axes and also the same comment was written above that axis
must be in the range -rank..<rank
. So I followed the same pattern in writing the precondition. Changing the input variable from axis to axes may conflict with some written test cases.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You don't have to write all axis
, it's just referring to a singular axis. You can start the statement with Axis must be in range...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think rank 0
tensor means that there is only one value in axis tensor. So previously it was written as singular.
var input: Tensor<Int32> = Tensor(5)
input.rank
is giving 0
as output
@dan-zheng can you please review it? |
@@ -2229,7 +2248,9 @@ public extension Tensor where Scalar: Numeric { | |||
exclusive: Bool = false, | |||
reverse: Bool = false | |||
) -> Tensor { | |||
_Raw.cumsum(self, axis: axis, exclusive: exclusive, reverse: reverse) | |||
precondition(axis.rank == 0, "Axis must have rank 0.") | |||
precondition(areAxesInRange(axis), "Axis must be in the range `[-rank, rank)`.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dan-zheng I have change precondition of
areAxesInRange()
as it is not necessary that the input axes tensor must haverank 1
. Eg:cumlativeSum()
andmoments
the former requirerank 0
tensor while later takesrank 1
tensor as input
Thanks for the clarification!
The correct fix is actually to add another precondition helper method called isAxisInRange
taking a Tensor<Int32>
, not to call areAxesInRange(_: Tensor<Int32>)
on a scalar Tensor<Int32>
representing a single axis.
Tensor.isAxisInRange(_: Tensor<Int32>)
should have a precondition:
precondition(axis.rank == 0, "Axis must have rank 0.")
And Tensor.areAxesInRange(_: Tensor<Int32>)
should have the old precondition:
precondition(axes.rank == 1, "Axis must have rank 1.")
precondition(areAxesInRange(axis), "Axis must be in the range `[-rank, rank)`.") | |
precondition(isAxisInRange(axis), "Axis must be in the range `[-rank, rank)`.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice!
Weirdly, there's a test failure:
Is this a flaky test? Rerunning tests now. |
#517
Precondiiton for function: