Skip to content

Commit 4143973

Browse files
Remove theanof.set_theano_conf and instead use the config context properly
1 parent 6f15cbb commit 4143973

File tree

4 files changed

+9
-56
lines changed

4 files changed

+9
-56
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
### Maintenance
66
- Fixed bug whereby partial traces returns after keyboard interrupt during parallel sampling had fewer draws than would've been available [#4318](https://github.com/pymc-devs/pymc3/pull/4318)
77
- Make `sample_shape` same across all contexts in `draw_values` (see [#4305](https://github.com/pymc-devs/pymc3/pull/4305)).
8+
- Removed `theanof.set_theano_config` because it illegally touched Theano's privates (see [#4329](https://github.com/pymc-devs/pymc3/pull/4329)).
9+
810

911
## PyMC3 3.10.0 (7 December 2020)
1012

pymc3/model.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,6 @@
4242
gradient,
4343
hessian,
4444
inputvars,
45-
set_theano_conf,
4645
)
4746
from pymc3.util import get_transformed_name, get_var_name
4847
from pymc3.vartypes import continuous_types, discrete_types, isgenerator, typefilter
@@ -288,15 +287,17 @@ def __new__(cls, name, bases, dct, **kargs): # pylint: disable=unused-argument
288287
def __enter__(self):
289288
self.__class__.context_class.get_contexts().append(self)
290289
# self._theano_config is set in Model.__new__
290+
self._config_context = None
291291
if hasattr(self, "_theano_config"):
292-
self._old_theano_config = set_theano_conf(self._theano_config)
292+
self._config_context = theano.change_flags(**self._theano_config)
293+
self._config_context.__enter__()
293294
return self
294295

295296
def __exit__(self, typ, value, traceback): # pylint: disable=unused-argument
296297
self.__class__.context_class.get_contexts().pop()
297298
# self._theano_config is set in Model.__new__
298-
if hasattr(self, "_old_theano_config"):
299-
set_theano_conf(self._old_theano_config)
299+
if self._config_context:
300+
self._config_context.__exit__(typ, value, traceback)
300301

301302
dct[__enter__.__name__] = __enter__
302303
dct[__exit__.__name__] = __exit__

pymc3/tests/test_theanof.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import theano
2222
import theano.tensor as tt
2323

24-
from pymc3.theanof import _conversion_map, set_theano_conf, take_along_axis
24+
from pymc3.theanof import _conversion_map, take_along_axis
2525
from pymc3.vartypes import int_types
2626

2727
FLOATX = str(theano.config.floatX)
@@ -72,27 +72,6 @@ def np_take_along_axis(arr, indices, axis):
7272
return arr[_make_along_axis_idx(arr.shape, indices, _axis)]
7373

7474

75-
class TestSetTheanoConfig:
76-
def test_invalid_key(self):
77-
with pytest.raises(ValueError) as e:
78-
set_theano_conf({"bad_key": True})
79-
e.match("Unknown")
80-
81-
def test_restore_when_bad_key(self):
82-
with theano.configparser.change_flags(compute_test_value="off"):
83-
with pytest.raises(ValueError):
84-
conf = collections.OrderedDict([("compute_test_value", "raise"), ("bad_key", True)])
85-
set_theano_conf(conf)
86-
assert theano.config.compute_test_value == "off"
87-
88-
def test_restore(self):
89-
with theano.configparser.change_flags(compute_test_value="off"):
90-
conf = set_theano_conf({"compute_test_value": "raise"})
91-
assert conf == {"compute_test_value": "off"}
92-
conf = set_theano_conf(conf)
93-
assert conf == {"compute_test_value": "raise"}
94-
95-
9675
class TestTakeAlongAxis:
9776
def setup_class(self):
9877
self.inputs_buffer = dict()

pymc3/theanof.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,8 @@
1515
import numpy as np
1616
import theano
1717

18-
from theano import scalar
18+
from theano import change_flags, scalar
1919
from theano import tensor as tt
20-
from theano.configparser import change_flags
2120
from theano.gof import Op
2221
from theano.gof.graph import inputs
2322
from theano.sandbox.rng_mrg import MRG_RandomStreams
@@ -442,34 +441,6 @@ def floatX_array(x):
442441
return floatX(np.array(x))
443442

444443

445-
def set_theano_conf(values):
446-
"""Change the theano configuration and return old values.
447-
448-
This is similar to `theano.configparser.change_flags`, but it
449-
returns the original values in a pickleable form.
450-
"""
451-
variables = {}
452-
unknown = set(values.keys())
453-
for variable in theano.configparser._config_var_list:
454-
if variable.fullname in values:
455-
variables[variable.fullname] = variable
456-
unknown.remove(variable.fullname)
457-
if len(unknown) > 0:
458-
raise ValueError("Unknown theano config settings: %s" % unknown)
459-
460-
old = {}
461-
for name, variable in variables.items():
462-
old_value = variable.__get__(True, None)
463-
try:
464-
variable.__set__(None, values[name])
465-
except Exception:
466-
for key, old_value in old.items():
467-
variables[key].__set__(None, old_value)
468-
raise
469-
old[name] = old_value
470-
return old
471-
472-
473444
def ix_(*args):
474445
"""
475446
Theano np.ix_ analog

0 commit comments

Comments
 (0)