Skip to content

Commit 9bbe3d3

Browse files
committed
Make tests compatible with latest release of PyTensor
1 parent 9bb3cf0 commit 9bbe3d3

File tree

4 files changed

+9
-4
lines changed

4 files changed

+9
-4
lines changed

pymc/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1232,7 +1232,7 @@ def set_data(
12321232
if isinstance(length_tensor_origin, TensorConstant):
12331233
raise ShapeError(
12341234
f"Resizing dimension '{dname}' with values of length {new_length} would lead to incompatibilities, "
1235-
f"because the dimension length is tied to a {length_tensor_origin}. "
1235+
f"because the dimension length is tied to a TensorConstant. "
12361236
f"Check if the dimension was defined implicitly before the shared variable '{name}' was created, "
12371237
f"for example by another model variable.",
12381238
actual=new_length,

tests/distributions/test_multivariate.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,12 @@ def test_mvnormal_indef(self):
316316
f_logp(cov_val, np.ones(2))
317317
dlogp = pt.grad(mvn_logp, cov)
318318
f_dlogp = pytensor.function([cov, x], dlogp)
319-
assert not np.all(np.isfinite(f_dlogp(cov_val, np.ones(2))))
319+
try:
320+
res = f_dlogp(cov_val, np.ones(2))
321+
except ValueError:
322+
pass # Op raises internally
323+
else:
324+
assert not np.all(np.isfinite(res)) # Otherwise, should return nan
320325

321326
def test_mvnormal_init_fail(self):
322327
with pm.Model():

tests/logprob/test_basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def test_probability_inference(func, scipy_func, test_value):
426426
def test_probability_inference_fails(func, func_name):
427427
with pytest.raises(
428428
NotImplementedError,
429-
match=f"{func_name} method not implemented for Elemwise{{cos,no_inplace}}",
429+
match=f"{func_name} method not implemented for (Elemwise{{cos,no_inplace}}|Cos)",
430430
):
431431
func(pt.cos(pm.Normal.dist()), 1)
432432

tests/test_pytensorf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -575,7 +575,7 @@ def step_wo_update(x, rng):
575575

576576
with pytest.raises(
577577
ValueError,
578-
match=r"No update found for at least one RNG used in Scan Op for\{cpu,test_scan\}",
578+
match="No update found for at least one RNG used in Scan Op",
579579
):
580580
collect_default_updates([xs])
581581

0 commit comments

Comments
 (0)