29
29
]
30
30
31
31
32
- def array_or_scalar (values , py_type = float ):
33
- if values . numel () == 1 :
32
+ def array_or_scalar (values , py_type = float , size = None ):
33
+ if size is None :
34
34
return py_type (values .item ())
35
35
else :
36
36
return asarray (values )
@@ -45,7 +45,7 @@ def random_sample(size=None):
45
45
if size is None :
46
46
size = ()
47
47
values = torch .empty (size , dtype = _default_dtype ).uniform_ ()
48
- return array_or_scalar (values )
48
+ return array_or_scalar (values , size = size )
49
49
50
50
51
51
def rand (* size ):
@@ -60,19 +60,19 @@ def uniform(low=0.0, high=1.0, size=None):
60
60
if size is None :
61
61
size = ()
62
62
values = torch .empty (size , dtype = _default_dtype ).uniform_ (low , high )
63
- return array_or_scalar (values )
63
+ return array_or_scalar (values , size = size )
64
64
65
65
66
66
def randn (* size ):
67
67
values = torch .randn (size , dtype = _default_dtype )
68
- return array_or_scalar (values )
68
+ return array_or_scalar (values , size = size )
69
69
70
70
71
71
def normal (loc = 0.0 , scale = 1.0 , size = None ):
72
72
if size is None :
73
73
size = ()
74
74
values = torch .empty (size , dtype = _default_dtype ).normal_ (loc , scale )
75
- return array_or_scalar (values )
75
+ return array_or_scalar (values , size = size )
76
76
77
77
78
78
def shuffle (x ):
@@ -90,7 +90,7 @@ def randint(low, high=None, size=None):
90
90
if high is None :
91
91
low , high = 0 , low
92
92
values = torch .randint (low , high , size = size )
93
- return array_or_scalar (values , int )
93
+ return array_or_scalar (values , int , size = size )
94
94
95
95
96
96
def choice (a , size = None , replace = True , p = None ):
0 commit comments