Skip to content

Commit a6eb581

Browse files
committed
MAINT: add a comment on axis=() in reductions
1 parent 27cb10f commit a6eb581

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torch_np/_detail/_reductions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@ def wrapped(tensor, axis, *args, **kwds):
3535
axis = _util.normalize_axis_tuple(axis, tensor.ndim)
3636

3737
if axis == ():
38+
# NumPy does essentially an identity operation:
39+
# >>> np.sum(np.ones(2), axis=())
40+
# array([1., 1.])
41+
# So we insert a length-one axis and run the reduction along it.
3842
newshape = _util.expand_shape(tensor.shape, axis=0)
3943
tensor = tensor.reshape(newshape)
4044
axis = (0,)

0 commit comments

Comments
 (0)