Skip to content

Commit 144b0ba

Browse files
ferrinericardoV94
andauthored
WIP Improve scoped models (#5607)
* use slashes to separate context * add fix for nested prefix * fix test * refactor prefix * revert precommit change * fix windows issue * fis smc test * Update pymc/tests/test_model.py Co-authored-by: Ricardo Vieira <[email protected]> * Update pymc/tests/test_model.py Co-authored-by: Ricardo Vieira <[email protected]> * update tests Co-authored-by: Ricardo Vieira <[email protected]>
1 parent e8c07ef commit 144b0ba

File tree

4 files changed

+53
-33
lines changed

4 files changed

+53
-33
lines changed

pymc/model.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ def __init__(self, mean=0, sigma=1, name=''):
473473
474474
# 3) you can create variables with Var method
475475
self.Var('v1', Normal.dist(mu=mean, sigma=sd))
476-
# this will create variable named like '{prefix_}v1'
476+
# this will create variable named like '{prefix/}v1'
477477
# and assign attribute 'v1' to instance created
478478
# variable can be accessed with self.v1 or self['v1']
479479
@@ -515,6 +515,8 @@ def __init__(self, mean=0, sigma=1, name=''):
515515
CustomModel(mean=1, name='first')
516516
CustomModel(mean=2, name='second')
517517
518+
# variables inside both scopes will be named like `first/*`, `second/*`
519+
518520
"""
519521

520522
if TYPE_CHECKING:
@@ -1455,14 +1457,18 @@ def add_random_variable(self, var, dims: Optional[Tuple[Union[str, None], ...]]
14551457
setattr(self, self.name_of(var.name), var)
14561458

14571459
@property
1458-
def prefix(self):
1459-
return f"{self.name}_" if self.name else ""
1460+
def prefix(self) -> str:
1461+
if self.isroot or not self.parent.prefix:
1462+
name = self.name
1463+
else:
1464+
name = f"{self.parent.prefix}/{self.name}"
1465+
return name.strip("/")
14601466

14611467
def name_for(self, name):
14621468
"""Checks if name has prefix and adds if needed"""
14631469
if self.prefix:
14641470
if not name.startswith(self.prefix):
1465-
return f"{self.prefix}{name}"
1471+
return f"{self.prefix}/{name}"
14661472
else:
14671473
return name
14681474
else:
@@ -1472,8 +1478,8 @@ def name_of(self, name):
14721478
"""Checks if name has prefix and deletes if needed"""
14731479
if not self.prefix or not name:
14741480
return name
1475-
elif name.startswith(self.prefix):
1476-
return name[len(self.prefix) :]
1481+
elif name.startswith(self.prefix + "/"):
1482+
return name[len(self.prefix) + 1 :]
14771483
else:
14781484
return name
14791485

pymc/tests/test_data_container.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,8 +403,8 @@ def test_data_naming():
403403
with pm.Model("named_model") as model:
404404
x = pm.ConstantData("x", [1.0, 2.0, 3.0])
405405
y = pm.Normal("y")
406-
assert y.name == "named_model_y"
407-
assert x.name == "named_model_x"
406+
assert y.name == "named_model/y"
407+
assert x.name == "named_model/x"
408408

409409

410410
def test_get_data():

pymc/tests/test_model.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -95,20 +95,20 @@ def test_context_passes_vars_to_parent_model(self):
9595
usermodel2.register_rv(pm.Normal.dist(), "v3")
9696
pm.Normal("v4")
9797
# this variable is created in parent model too
98-
assert "another_v2" in model.named_vars
99-
assert "another_v3" in model.named_vars
100-
assert "another_v3" in usermodel2.named_vars
101-
assert "another_v4" in model.named_vars
102-
assert "another_v4" in usermodel2.named_vars
98+
assert "another/v2" in model.named_vars
99+
assert "another/v3" in model.named_vars
100+
assert "another/v3" in usermodel2.named_vars
101+
assert "another/v4" in model.named_vars
102+
assert "another/v4" in usermodel2.named_vars
103103
assert hasattr(usermodel2, "v3")
104104
assert hasattr(usermodel2, "v2")
105105
assert hasattr(usermodel2, "v4")
106106
# When you create a class based model you should follow some rules
107107
with model:
108108
m = NewModel("one_more")
109-
assert m.d is model["one_more_d"]
110-
assert m["d"] is model["one_more_d"]
111-
assert m["one_more_d"] is model["one_more_d"]
109+
assert m.d is model["one_more/d"]
110+
assert m["d"] is model["one_more/d"]
111+
assert m["one_more/d"] is model["one_more/d"]
112112

113113

114114
class TestNested:
@@ -124,8 +124,8 @@ def test_nest_context_works(self):
124124
def test_named_context(self):
125125
with pm.Model() as m:
126126
NewModel(name="new")
127-
assert "new_v1" in m.named_vars
128-
assert "new_v2" in m.named_vars
127+
assert "new/v1" in m.named_vars
128+
assert "new/v2" in m.named_vars
129129

130130
def test_docstring_example1(self):
131131
usage1 = DocstringModel()
@@ -138,10 +138,10 @@ def test_docstring_example1(self):
138138
def test_docstring_example2(self):
139139
with pm.Model() as model:
140140
DocstringModel(name="prefix")
141-
assert "prefix_v1" in model.named_vars
142-
assert "prefix_v2" in model.named_vars
143-
assert "prefix_v3" in model.named_vars
144-
assert "prefix_v3_sq" in model.named_vars
141+
assert "prefix/v1" in model.named_vars
142+
assert "prefix/v2" in model.named_vars
143+
assert "prefix/v3" in model.named_vars
144+
assert "prefix/v3_sq" in model.named_vars
145145
assert len(model.potentials), 1
146146

147147
def test_duplicates_detection(self):
@@ -156,6 +156,20 @@ def test_model_root(self):
156156
with pm.Model() as sub:
157157
assert model is sub.root
158158

159+
def test_nested_named_model_repeated(self):
160+
with pm.Model("sub") as model:
161+
b = pm.Normal("var")
162+
with pm.Model("sub"):
163+
b = pm.Normal("var")
164+
assert {"sub/var", "sub/sub/var"} == set(model.named_vars.keys())
165+
166+
def test_nested_named_model(self):
167+
with pm.Model("sub1") as model:
168+
b = pm.Normal("var")
169+
with pm.Model("sub2"):
170+
b = pm.Normal("var")
171+
assert {"sub1/var", "sub1/sub2/var"} == set(model.named_vars.keys())
172+
159173

160174
class TestObserved:
161175
def test_observed_rv_fail(self):
@@ -658,14 +672,14 @@ def test_datalogpt_multiple_shapes():
658672

659673

660674
def test_nested_model_coords():
661-
COORDS = {"dim": range(10)}
662-
with pm.Model(name="m1", coords=COORDS) as m1:
663-
a = pm.Normal("a")
664-
with pm.Model(name="m2") as m2:
665-
b = pm.Normal("b")
666-
c = pm.HalfNormal("c")
667-
d = pm.Normal("d", b, c, dims="dim")
668-
e = pm.Normal("e", a + d, dims="dim")
675+
with pm.Model(name="m1", coords=dict(dim1=range(2))) as m1:
676+
a = pm.Normal("a", dims="dim1")
677+
with pm.Model(name="m2", coords=dict(dim2=range(4))) as m2:
678+
b = pm.Normal("b", dims="dim1")
679+
m1.add_coord("dim3", range(4))
680+
c = pm.HalfNormal("c", dims="dim3")
681+
d = pm.Normal("d", b, c, dims="dim2")
682+
e = pm.Normal("e", a[None] + d[:, None], dims=("dim2", "dim1"))
669683
assert m1.coords is m2.coords
670684
assert m1.dim_lengths is m2.dim_lengths
671685
assert set(m2.RV_dims) < set(m1.RV_dims)

pymc/tests/test_smc.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -530,9 +530,9 @@ def test_named_model(self):
530530
s = pm.Simulator("s", self.normal_sim, a, b, observed=self.data)
531531

532532
trace = pm.sample_smc(draws=10, chains=2, return_inferencedata=False)
533-
assert f"{name}_a" in trace.varnames
534-
assert f"{name}_b" in trace.varnames
535-
assert f"{name}_b_log__" in trace.varnames
533+
assert f"{name}/a" in trace.varnames
534+
assert f"{name}/b" in trace.varnames
535+
assert f"{name}/b_log__" in trace.varnames
536536

537537

538538
class TestMHKernel(SeededTest):

0 commit comments

Comments
 (0)