Skip to content

Commit d10f245

Browse files
authored
Fix wrong dtype arguments (#1456)
1 parent 0ea61bc commit d10f245

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

pytensor/gradient.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2302,7 +2302,7 @@ def _is_zero(x):
23022302

23032303
class ZeroGrad(ViewOp):
23042304
def grad(self, args, g_outs):
2305-
return [g_out.zeros_like(g_out) for g_out in g_outs]
2305+
return [g_out.zeros_like() for g_out in g_outs]
23062306

23072307
def R_op(self, inputs, eval_points):
23082308
if eval_points[0] is None:

pytensor/scalar/basic.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3237,7 +3237,7 @@ def L_op(self, inputs, outputs, gout):
32373237
else:
32383238
return [x.zeros_like()]
32393239

3240-
return (gz * exp2(x) * log(np.array(2, dtype=x.type)),)
3240+
return (gz * exp2(x) * log(np.array(2, dtype=x.dtype)),)
32413241

32423242
def c_code(self, node, name, inputs, outputs, sub):
32433243
(x,) = inputs
@@ -3376,7 +3376,7 @@ def L_op(self, inputs, outputs, gout):
33763376
else:
33773377
return [x.zeros_like()]
33783378

3379-
return (gz * np.array(np.pi / 180, dtype=gz.type),)
3379+
return (gz * np.array(np.pi / 180, dtype=gz.dtype),)
33803380

33813381
def c_code(self, node, name, inputs, outputs, sub):
33823382
(x,) = inputs
@@ -3411,7 +3411,7 @@ def L_op(self, inputs, outputs, gout):
34113411
else:
34123412
return [x.zeros_like()]
34133413

3414-
return (gz * np.array(180.0 / np.pi, dtype=gz.type),)
3414+
return (gz * np.array(180.0 / np.pi, dtype=gz.dtype),)
34153415

34163416
def c_code(self, node, name, inputs, outputs, sub):
34173417
(x,) = inputs
@@ -3484,7 +3484,7 @@ def L_op(self, inputs, outputs, gout):
34843484
else:
34853485
return [x.zeros_like()]
34863486

3487-
return (-gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),)
3487+
return (-gz / sqrt(np.array(1, dtype=x.dtype) - sqr(x)),)
34883488

34893489
def c_code(self, node, name, inputs, outputs, sub):
34903490
(x,) = inputs
@@ -3558,7 +3558,7 @@ def L_op(self, inputs, outputs, gout):
35583558
else:
35593559
return [x.zeros_like()]
35603560

3561-
return (gz / sqrt(np.array(1, dtype=x.type) - sqr(x)),)
3561+
return (gz / sqrt(np.array(1, dtype=x.dtype) - sqr(x)),)
35623562

35633563
def c_code(self, node, name, inputs, outputs, sub):
35643564
(x,) = inputs
@@ -3630,7 +3630,7 @@ def L_op(self, inputs, outputs, gout):
36303630
else:
36313631
return [x.zeros_like()]
36323632

3633-
return (gz / (np.array(1, dtype=x.type) + sqr(x)),)
3633+
return (gz / (np.array(1, dtype=x.dtype) + sqr(x)),)
36343634

36353635
def c_code(self, node, name, inputs, outputs, sub):
36363636
(x,) = inputs
@@ -3753,7 +3753,7 @@ def L_op(self, inputs, outputs, gout):
37533753
else:
37543754
return [x.zeros_like()]
37553755

3756-
return (gz / sqrt(sqr(x) - np.array(1, dtype=x.type)),)
3756+
return (gz / sqrt(sqr(x) - np.array(1, dtype=x.dtype)),)
37573757

37583758
def c_code(self, node, name, inputs, outputs, sub):
37593759
(x,) = inputs
@@ -3830,7 +3830,7 @@ def L_op(self, inputs, outputs, gout):
38303830
else:
38313831
return [x.zeros_like()]
38323832

3833-
return (gz / sqrt(sqr(x) + np.array(1, dtype=x.type)),)
3833+
return (gz / sqrt(sqr(x) + np.array(1, dtype=x.dtype)),)
38343834

38353835
def c_code(self, node, name, inputs, outputs, sub):
38363836
(x,) = inputs
@@ -3908,7 +3908,7 @@ def L_op(self, inputs, outputs, gout):
39083908
else:
39093909
return [x.zeros_like()]
39103910

3911-
return (gz / (np.array(1, dtype=x.type) - sqr(x)),)
3911+
return (gz / (np.array(1, dtype=x.dtype) - sqr(x)),)
39123912

39133913
def c_code(self, node, name, inputs, outputs, sub):
39143914
(x,) = inputs

pytensor/sparse/rewriting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ def c_code_cache_version(self):
193193
def local_inplace_addsd_ccode(fgraph, node):
194194
"""Rewrite to insert inplace versions of `AddSD`."""
195195
if isinstance(node.op, sparse.AddSD) and config.cxx:
196-
out_dtype = ps.upcast(*node.inputs)
196+
out_dtype = ps.upcast(*[inp.type.dtype for inp in node.inputs])
197197
if out_dtype != node.inputs[1].dtype:
198198
return
199199
new_node = AddSD_ccode(format=node.inputs[0].type.format, inplace=True)(

0 commit comments

Comments
 (0)