Skip to content

Commit 8f37665

Browse files
committed
BUG: random: size=() returns a 0D array, not scalar
1 parent 12e0770 commit 8f37665

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

torch_np/random.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
]
3030

3131

32-
def array_or_scalar(values, py_type=float):
33-
if values.numel() == 1:
32+
def array_or_scalar(values, py_type=float, size=None):
33+
if size is None:
3434
return py_type(values.item())
3535
else:
3636
return asarray(values)
@@ -45,7 +45,7 @@ def random_sample(size=None):
4545
if size is None:
4646
size = ()
4747
values = torch.empty(size, dtype=_default_dtype).uniform_()
48-
return array_or_scalar(values)
48+
return array_or_scalar(values, size=size)
4949

5050

5151
def rand(*size):
@@ -60,19 +60,19 @@ def uniform(low=0.0, high=1.0, size=None):
6060
if size is None:
6161
size = ()
6262
values = torch.empty(size, dtype=_default_dtype).uniform_(low, high)
63-
return array_or_scalar(values)
63+
return array_or_scalar(values, size=size)
6464

6565

6666
def randn(*size):
6767
values = torch.randn(size, dtype=_default_dtype)
68-
return array_or_scalar(values)
68+
return array_or_scalar(values, size=size)
6969

7070

7171
def normal(loc=0.0, scale=1.0, size=None):
7272
if size is None:
7373
size = ()
7474
values = torch.empty(size, dtype=_default_dtype).normal_(loc, scale)
75-
return array_or_scalar(values)
75+
return array_or_scalar(values, size=size)
7676

7777

7878
def shuffle(x):
@@ -90,7 +90,7 @@ def randint(low, high=None, size=None):
9090
if high is None:
9191
low, high = 0, low
9292
values = torch.randint(low, high, size=size)
93-
return array_or_scalar(values, int)
93+
return array_or_scalar(values, int, size=size)
9494

9595

9696
def choice(a, size=None, replace=True, p=None):

torch_np/tests/numpy_tests/core/test_multiarray.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3627,7 +3627,6 @@ def test_assign_mask2(self):
36273627
assert_array_equal(x, np.array([[1, 10, 3, 4], [5, 6, 7, 8]]))
36283628

36293629

3630-
@pytest.mark.xfail(reason='TODO')
36313630
class TestArgmaxArgminCommon:
36323631

36333632
sizes = [(), (3,), (3, 2), (2, 3),

0 commit comments

Comments
 (0)