Skip to content

Commit a05684b

Browse files
authored
apply black formatter to mlda (#4162)
1 parent 87f603b commit a05684b

File tree

1 file changed

+22
-29
lines changed

1 file changed

+22
-29
lines changed

pymc3/step_methods/mlda.py

Lines changed: 22 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -106,21 +106,21 @@ class MLDA(ArrayStepShared):
106106
... datum = 1
107107
...
108108
... with pm.Model() as coarse_model:
109-
... x = Normal("x", mu=0, sigma=10)
110-
... y = Normal("y", mu=x, sigma=1, observed=datum - 0.1)
109+
... x = pm.Normal("x", mu=0, sigma=10)
110+
... y = pm.Normal("y", mu=x, sigma=1, observed=datum - 0.1)
111111
...
112112
... with pm.Model():
113-
... x = Normal("x", mu=0, sigma=10)
114-
... y = Normal("y", mu=x, sigma=1, observed=datum)
115-
... step_method = pm.MLDA(coarse_models=[coarse_model]
113+
... x = pm.Normal("x", mu=0, sigma=10)
114+
... y = pm.Normal("y", mu=x, sigma=1, observed=datum)
115+
... step_method = pm.MLDA(coarse_models=[coarse_model],
116116
... subsampling_rates=5)
117-
... trace = pm.sample(ndraws=500, chains=2,
117+
... trace = pm.sample(500, chains=2,
118118
... tune=100, step=step_method,
119119
... random_seed=123)
120120
...
121-
... pm.summary(trace)
122-
mean sd hpd_3% hpd_97%
123-
x 1.011 0.975 -0.925 2.824
121+
... pm.summary(trace, kind="stats")
122+
mean sd hdi_3% hdi_97%
123+
x 0.99 0.987 -0.734 2.992
124124
125125
References
126126
----------
@@ -161,7 +161,7 @@ def __init__(
161161
mode: Optional = None,
162162
subsampling_rates: List[int] = 5,
163163
base_blocked: bool = False,
164-
**kwargs
164+
**kwargs,
165165
) -> None:
166166

167167
warnings.warn(
@@ -174,9 +174,7 @@ def __init__(
174174
# assign internal state
175175
self.coarse_models = coarse_models
176176
if not isinstance(coarse_models, list):
177-
raise ValueError(
178-
"MLDA step method cannot use coarse_models if it is not a list"
179-
)
177+
raise ValueError("MLDA step method cannot use coarse_models if it is not a list")
180178
if len(self.coarse_models) == 0:
181179
raise ValueError(
182180
"MLDA step method was given an empty "
@@ -233,9 +231,7 @@ def __init__(
233231
if self.num_levels == 2:
234232
with self.next_model:
235233
# make sure the correct variables are selected from next_model
236-
vars_next = [
237-
var for var in self.next_model.vars if var.name in self.var_names
238-
]
234+
vars_next = [var for var in self.next_model.vars if var.name in self.var_names]
239235
# MetropolisMLDA sampler in base level (level=0), targeting self.next_model
240236
self.next_step_method = pm.MetropolisMLDA(
241237
vars=vars_next,
@@ -253,9 +249,7 @@ def __init__(
253249
next_subsampling_rates = self.subsampling_rates[:-1]
254250
with self.next_model:
255251
# make sure the correct variables are selected from next_model
256-
vars_next = [
257-
var for var in self.next_model.vars if var.name in self.var_names
258-
]
252+
vars_next = [var for var in self.next_model.vars if var.name in self.var_names]
259253
# MLDA sampler in some intermediate level, targeting self.next_model
260254
self.next_step_method = pm.MLDA(
261255
vars=vars_next,
@@ -335,9 +329,7 @@ def astep(self, q0):
335329
self.base_scaling_stats = {"base_scaling": np.array(scaling_list)}
336330
elif not isinstance(self.next_step_method, MLDA):
337331
# next method is any block sampler
338-
self.base_scaling_stats = {
339-
"base_scaling": np.array(self.next_step_method.scaling)
340-
}
332+
self.base_scaling_stats = {"base_scaling": np.array(self.next_step_method.scaling)}
341333
else:
342334
# next method is MLDA - propagate dict from lower levels
343335
self.base_scaling_stats = self.next_step_method.base_scaling_stats
@@ -366,19 +358,20 @@ class RecursiveDAProposal(Proposal):
366358
each of which is used to propose samples to the chain above.
367359
"""
368360

369-
def __init__(self,
370-
next_step_method: Union[MLDA, Metropolis, CompoundStep],
371-
next_model: Model,
372-
tune: bool,
373-
subsampling_rate: int) -> None:
361+
def __init__(
362+
self,
363+
next_step_method: Union[MLDA, Metropolis, CompoundStep],
364+
next_model: Model,
365+
tune: bool,
366+
subsampling_rate: int,
367+
) -> None:
374368

375369
self.next_step_method = next_step_method
376370
self.next_model = next_model
377371
self.tune = tune
378372
self.subsampling_rate = subsampling_rate
379373

380-
def __call__(self,
381-
q0_dict: dict) -> dict:
374+
def __call__(self, q0_dict: dict) -> dict:
382375
"""Returns proposed sample given the current sample
383376
in dictionary form (q0_dict).
384377
"""

0 commit comments

Comments
 (0)