Skip to content

Commit 6ef1452

Browse files
danhphanricardoV94
authored andcommitted
Fix tensor.zeros and tensor.ones with symbolic scalar
1 parent 5a0fb0e commit 6ef1452

File tree

2 files changed

+27
-2
lines changed

2 files changed

+27
-2
lines changed

aesara/tensor/basic.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -978,7 +978,10 @@ def zeros_like(model, dtype=None, opt=False):
978978

979979
def zeros(shape, dtype=None):
980980
"""Create a `TensorVariable` filled with zeros, closer to NumPy's syntax than ``alloc``."""
981-
if not isinstance(shape, (np.ndarray, Sequence, TensorVariable)):
981+
if not (
982+
isinstance(shape, (np.ndarray, Sequence))
983+
or (isinstance(shape, TensorVariable) and shape.ndim > 0)
984+
):
982985
shape = [shape]
983986
if dtype is None:
984987
dtype = config.floatX
@@ -987,7 +990,10 @@ def zeros(shape, dtype=None):
987990

988991
def ones(shape, dtype=None):
989992
"""Create a `TensorVariable` filled with ones, closer to NumPy's syntax than ``alloc``."""
990-
if not isinstance(shape, (np.ndarray, Sequence, TensorVariable)):
993+
if not (
994+
isinstance(shape, (np.ndarray, Sequence))
995+
or (isinstance(shape, TensorVariable) and shape.ndim > 0)
996+
):
991997
shape = [shape]
992998
if dtype is None:
993999
dtype = config.floatX
@@ -4274,6 +4280,11 @@ def empty(shape, dtype=None):
42744280
Desired output data-type for the array, e.g, `numpy.int8`. Default is
42754281
`numpy.float64`.
42764282
"""
4283+
if not (
4284+
isinstance(shape, (np.ndarray, Sequence))
4285+
or (isinstance(shape, TensorVariable) and shape.ndim > 0)
4286+
):
4287+
shape = [shape]
42774288
if dtype is None:
42784289
dtype = config.floatX
42794290
return AllocEmpty(dtype)(*shape)

tests/tensor/test_basic.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -754,6 +754,11 @@ def test_ones(self):
754754
for shp in [[], 1, [1], [1, 2], [1, 2, 3], np.r_[1, 2, 3]]:
755755
ones = aesara.function([], [at.ones(shp)], mode=self.mode)
756756
assert np.allclose(ones(), np.ones(shp))
757+
# When shape is a TensorConstant
758+
ones_const = aesara.function(
759+
[], [at.ones(at.constant(shp))], mode=self.mode
760+
)
761+
assert np.allclose(ones_const(), np.ones(shp))
757762

758763
# scalar doesn't have to be provided as input
759764
x = scalar()
@@ -771,6 +776,11 @@ def test_zeros(self):
771776
for shp in [[], 1, [1], [1, 2], [1, 2, 3], np.r_[1, 2, 3]]:
772777
zeros = aesara.function([], [at.zeros(shp)], mode=self.mode)
773778
assert np.allclose(zeros(), np.zeros(shp))
779+
# When shape is a TensorConstant
780+
zeros_const = aesara.function(
781+
[], [at.zeros(at.constant(shp))], mode=self.mode
782+
)
783+
assert np.allclose(zeros_const(), np.zeros(shp))
774784

775785
# scalar doesn't have to be provided as input
776786
x = scalar()
@@ -4381,6 +4391,10 @@ def test_empty():
43814391
assert out.shape == (2, 3)
43824392
assert out.dtype == "float32"
43834393

4394+
empty_at = at.empty(3)
4395+
res = aesara.function([], empty_at)()
4396+
assert res.shape == (3,)
4397+
43844398
empty_at = at.empty((2, 3), dtype=None)
43854399
res = aesara.function([], empty_at)()
43864400
assert res.shape == (2, 3)

0 commit comments

Comments
 (0)