@@ -95,20 +95,20 @@ def test_context_passes_vars_to_parent_model(self):
95
95
usermodel2 .register_rv (pm .Normal .dist (), "v3" )
96
96
pm .Normal ("v4" )
97
97
# this variable is created in parent model too
98
- assert "another_v2 " in model .named_vars
99
- assert "another_v3 " in model .named_vars
100
- assert "another_v3 " in usermodel2 .named_vars
101
- assert "another_v4 " in model .named_vars
102
- assert "another_v4 " in usermodel2 .named_vars
98
+ assert "another/v2 " in model .named_vars
99
+ assert "another/v3 " in model .named_vars
100
+ assert "another/v3 " in usermodel2 .named_vars
101
+ assert "another/v4 " in model .named_vars
102
+ assert "another/v4 " in usermodel2 .named_vars
103
103
assert hasattr (usermodel2 , "v3" )
104
104
assert hasattr (usermodel2 , "v2" )
105
105
assert hasattr (usermodel2 , "v4" )
106
106
# When you create a class based model you should follow some rules
107
107
with model :
108
108
m = NewModel ("one_more" )
109
- assert m .d is model ["one_more_d " ]
110
- assert m ["d" ] is model ["one_more_d " ]
111
- assert m ["one_more_d " ] is model ["one_more_d " ]
109
+ assert m .d is model ["one_more/d " ]
110
+ assert m ["d" ] is model ["one_more/d " ]
111
+ assert m ["one_more/d " ] is model ["one_more/d " ]
112
112
113
113
114
114
class TestNested :
@@ -124,8 +124,8 @@ def test_nest_context_works(self):
124
124
def test_named_context (self ):
125
125
with pm .Model () as m :
126
126
NewModel (name = "new" )
127
- assert "new_v1 " in m .named_vars
128
- assert "new_v2 " in m .named_vars
127
+ assert "new/v1 " in m .named_vars
128
+ assert "new/v2 " in m .named_vars
129
129
130
130
def test_docstring_example1 (self ):
131
131
usage1 = DocstringModel ()
@@ -138,10 +138,10 @@ def test_docstring_example1(self):
138
138
def test_docstring_example2 (self ):
139
139
with pm .Model () as model :
140
140
DocstringModel (name = "prefix" )
141
- assert "prefix_v1 " in model .named_vars
142
- assert "prefix_v2 " in model .named_vars
143
- assert "prefix_v3 " in model .named_vars
144
- assert "prefix_v3_sq " in model .named_vars
141
+ assert "prefix/v1 " in model .named_vars
142
+ assert "prefix/v2 " in model .named_vars
143
+ assert "prefix/v3 " in model .named_vars
144
+ assert "prefix/v3_sq " in model .named_vars
145
145
assert len (model .potentials ), 1
146
146
147
147
def test_duplicates_detection (self ):
@@ -156,6 +156,20 @@ def test_model_root(self):
156
156
with pm .Model () as sub :
157
157
assert model is sub .root
158
158
159
+ def test_nested_named_model_repeated (self ):
160
+ with pm .Model ("sub" ) as model :
161
+ b = pm .Normal ("var" )
162
+ with pm .Model ("sub" ):
163
+ b = pm .Normal ("var" )
164
+ assert {"sub/var" , "sub/sub/var" } == set (model .named_vars .keys ())
165
+
166
+ def test_nested_named_model (self ):
167
+ with pm .Model ("sub1" ) as model :
168
+ b = pm .Normal ("var" )
169
+ with pm .Model ("sub2" ):
170
+ b = pm .Normal ("var" )
171
+ assert {"sub1/var" , "sub1/sub2/var" } == set (model .named_vars .keys ())
172
+
159
173
160
174
class TestObserved :
161
175
def test_observed_rv_fail (self ):
@@ -658,14 +672,14 @@ def test_datalogpt_multiple_shapes():
658
672
659
673
660
674
def test_nested_model_coords ():
661
- COORDS = { "dim" : range (10 )}
662
- with pm .Model ( name = "m1 " , coords = COORDS ) as m1 :
663
- a = pm .Normal ( "a" )
664
- with pm .Model ( name = "m2" ) as m2 :
665
- b = pm . Normal ( "b" )
666
- c = pm .HalfNormal ("c" )
667
- d = pm .Normal ("d" , b , c , dims = "dim " )
668
- e = pm .Normal ("e" , a + d , dims = "dim" )
675
+ with pm . Model ( name = "m1" , coords = dict ( dim1 = range (2 ))) as m1 :
676
+ a = pm .Normal ( "a " , dims = "dim1" )
677
+ with pm .Model ( name = "m2" , coords = dict ( dim2 = range ( 4 ))) as m2 :
678
+ b = pm .Normal ( "b" , dims = "dim1" )
679
+ m1 . add_coord ( "dim3" , range ( 4 ) )
680
+ c = pm .HalfNormal ("c" , dims = "dim3" )
681
+ d = pm .Normal ("d" , b , c , dims = "dim2 " )
682
+ e = pm .Normal ("e" , a [ None ] + d [:, None ], dims = ( "dim2" , "dim1" ) )
669
683
assert m1 .coords is m2 .coords
670
684
assert m1 .dim_lengths is m2 .dim_lengths
671
685
assert set (m2 .RV_dims ) < set (m1 .RV_dims )
0 commit comments