Skip to content

Commit 5a7885a

Browse files
committed
fix scalar issues
1 parent fe0a7ec commit 5a7885a

File tree

2 files changed

+13
-25
lines changed

2 files changed

+13
-25
lines changed

pytensor/scalar/basic.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,17 +1140,23 @@ def output_types(self, types):
11401140
else:
11411141
raise NotImplementedError(f"Cannot calculate the output types for {self}")
11421142

1143+
@staticmethod
1144+
def _cast_scalar(x, dtype):
1145+
if hasattr(x, "astype"):
1146+
return x.astype(dtype)
1147+
else:
1148+
return x
1149+
11431150
def perform(self, node, inputs, output_storage):
11441151
if self.nout == 1:
1145-
output_storage[0][0] = np.asarray(
1146-
self.impl(*inputs),
1147-
dtype=node.outputs[0].dtype,
1148-
)
1152+
dtype = node.outputs[0].dtype
1153+
output_storage[0][0] = self._cast_scalar(self.impl(*inputs), dtype)
11491154
else:
11501155
variables = from_return_values(self.impl(*inputs))
11511156
assert len(variables) == len(output_storage)
11521157
for out, storage, variable in zip(node.outputs, output_storage, variables):
1153-
storage[0] = np.asarray(variable, dtype=out.dtype)
1158+
dtype = out.dtype
1159+
storage[0] = self._cast_scalar(variable, dtype)
11541160

11551161
def impl(self, *inputs):
11561162
raise MethodNotDefined("impl", type(self), self.__class__.__name__)

pytensor/tensor/elemwise.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -767,34 +767,16 @@ def perform(self, node, inputs, output_storage):
767767
for i, (variable, storage, nout) in enumerate(
768768
zip(variables, output_storage, node.outputs)
769769
):
770-
if getattr(variable, "dtype", "") == "object":
771-
# Since numpy 1.6, function created with numpy.frompyfunc
772-
# always return an ndarray with dtype object
773-
variable = np.asarray(variable, dtype=nout.dtype)
770+
storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
774771

775772
if i in self.inplace_pattern:
776773
odat = inputs[self.inplace_pattern[i]]
777774
odat[...] = variable
778775
storage[0] = odat
779776

780-
# Sometimes NumPy return a Python type.
781-
# Some PyTensor op return a different dtype like floor, ceil,
782-
# trunc, eq, ...
783-
elif not isinstance(variable, np.ndarray) or variable.dtype != nout.dtype:
784-
variable = np.asarray(variable, nout.dtype)
785-
# The next line is needed for numpy 1.9. Otherwise
786-
# there are tests that fail in DebugMode.
787-
# Normally we would call pytensor.misc._asarray, but it
788-
# is faster to inline the code. We know that the dtype
789-
# are the same string, just different typenum.
790-
if np.dtype(nout.dtype).num != variable.dtype.num:
791-
variable = variable.view(dtype=nout.dtype)
792-
storage[0] = variable
793777
# numpy.real return a view!
794-
elif not variable.flags.owndata:
778+
if not variable.flags.owndata:
795779
storage[0] = variable.copy()
796-
else:
797-
storage[0] = variable
798780

799781
@staticmethod
800782
def _check_runtime_broadcast(node, inputs):

0 commit comments

Comments
 (0)