Skip to content

Commit 8499a7d

Browse files
Make sure forward transformed input is a TensorVariable in TestMatchesScipy
1 parent 9f8d459 commit 8499a7d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pymc3/tests/test_distributions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1876,7 +1876,7 @@ def test_dirichlet_with_batch_shapes(self, dist_shape):
18761876
d_value = d.tag.value_var
18771877
d_point = d.eval()
18781878
if hasattr(d_value.tag, "transform"):
1879-
d_point_trans = d_value.tag.transform.forward(d, d_point).eval()
1879+
d_point_trans = d_value.tag.transform.forward(d, aet.as_tensor(d_point)).eval()
18801880
else:
18811881
d_point_trans = d_point
18821882

0 commit comments

Comments
 (0)