Skip to content

Commit 4309dfc

Browse files
committed
Update random test to work with symbolic inputs
1 parent 4645169 commit 4309dfc

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

pymc/tests/test_distributions_random.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,11 @@ def check_pymc_params_match_rv_op(self):
261261
for (expected_name, expected_value), actual_variable in zip(
262262
self.expected_rv_op_params.items(), aesara_dist_inputs
263263
):
264+
265+
# Add additional line to evaluate symbolic inputs to distributions
266+
if isinstance(expected_value, aesara.tensor.Variable):
267+
expected_value = expected_value.eval()
268+
264269
assert_almost_equal(expected_value, actual_variable.eval(), decimal=self.decimal)
265270

266271
def check_rv_size(self):

pymc/tests/test_distributions_timeseries.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
15-
import aesara.tensor as at
1614
import numpy as np
1715
import pytest
1816

@@ -131,7 +129,6 @@ class TestGaussianRandomWalk(BaseTestDistributionRandom):
131129
pymc_dist = pm.GaussianRandomWalk
132130
pymc_dist_params = {"mu": 1.0, "sigma": 2, "init": pm.Constant.dist(0), "steps": 4}
133131
expected_rv_op_params = {"mu": 1.0, "sigma": 2, "init": pm.Constant.dist(0), "steps": 4}
134-
# reference_dist_params = {"b": 1.0, "kappa": 1.0, "mu": 0.0}
135132

136133
checks_to_run = [
137134
"check_pymc_params_match_rv_op",

0 commit comments

Comments
 (0)