Skip to content

Commit 8f82823

Browse files
Raise AttributeError in lightning_getattr and lightning_setattr when attribute not found (Lightning-AI#6024)
* Empty commit * Raise AttributeError instead of ValueError * Make functions private * Update tests * Add match string * Apply suggestions from code review Co-authored-by: Adrian Wälchli <[email protected]> * lightning to Lightning Co-authored-by: Adrian Wälchli <[email protected]>
1 parent b0074a4 commit 8f82823

File tree

2 files changed

+60
-25
lines changed

2 files changed

+60
-25
lines changed

pytorch_lightning/utilities/parsing.py

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,11 @@ def __repr__(self):
196196
return out
197197

198198

199-
def lightning_get_all_attr_holders(model, attribute):
200-
""" Special attribute finding for lightning. Gets all of the objects or dicts that holds attribute.
201-
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule. """
199+
def _lightning_get_all_attr_holders(model, attribute):
200+
"""
201+
Special attribute finding for Lightning. Gets all of the objects or dicts that holds attribute.
202+
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule.
203+
"""
202204
trainer = getattr(model, 'trainer', None)
203205

204206
holders = []
@@ -219,31 +221,40 @@ def lightning_get_all_attr_holders(model, attribute):
219221
return holders
220222

221223

222-
def lightning_get_first_attr_holder(model, attribute):
224+
def _lightning_get_first_attr_holder(model, attribute):
225+
"""
226+
Special attribute finding for Lightning. Gets the object or dict that holds attribute, or None.
227+
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
228+
returns the last one that has it.
223229
"""
224-
Special attribute finding for lightning. Gets the object or dict that holds attribute, or None.
225-
Checks for attribute in model namespace, the old hparams namespace/dict, and the datamodule,
226-
returns the last one that has it.
227-
"""
228-
holders = lightning_get_all_attr_holders(model, attribute)
230+
holders = _lightning_get_all_attr_holders(model, attribute)
229231
if len(holders) == 0:
230232
return None
231233
# using the last holder to preserve backwards compatibility
232234
return holders[-1]
233235

234236

235237
def lightning_hasattr(model, attribute):
236-
""" Special hasattr for lightning. Checks for attribute in model namespace,
237-
the old hparams namespace/dict, and the datamodule. """
238-
return lightning_get_first_attr_holder(model, attribute) is not None
238+
"""
239+
Special hasattr for Lightning. Checks for attribute in model namespace,
240+
the old hparams namespace/dict, and the datamodule.
241+
"""
242+
return _lightning_get_first_attr_holder(model, attribute) is not None
239243

240244

241245
def lightning_getattr(model, attribute):
242-
""" Special getattr for lightning. Checks for attribute in model namespace,
243-
the old hparams namespace/dict, and the datamodule. """
244-
holder = lightning_get_first_attr_holder(model, attribute)
246+
"""
247+
Special getattr for Lightning. Checks for attribute in model namespace,
248+
the old hparams namespace/dict, and the datamodule.
249+
250+
Raises:
251+
AttributeError:
252+
If ``model`` doesn't have ``attribute`` in any of
253+
model namespace, the hparams namespace/dict, and the datamodule.
254+
"""
255+
holder = _lightning_get_first_attr_holder(model, attribute)
245256
if holder is None:
246-
raise ValueError(
257+
raise AttributeError(
247258
f'{attribute} is neither stored in the model namespace'
248259
' nor the `hparams` namespace/dict, nor the datamodule.'
249260
)
@@ -254,13 +265,19 @@ def lightning_getattr(model, attribute):
254265

255266

256267
def lightning_setattr(model, attribute, value):
257-
""" Special setattr for lightning. Checks for attribute in model namespace
258-
and the old hparams namespace/dict.
259-
Will also set the attribute on datamodule, if it exists.
260268
"""
261-
holders = lightning_get_all_attr_holders(model, attribute)
269+
Special setattr for Lightning. Checks for attribute in model namespace
270+
and the old hparams namespace/dict.
271+
Will also set the attribute on datamodule, if it exists.
272+
273+
Raises:
274+
AttributeError:
275+
If ``model`` doesn't have ``attribute`` in any of
276+
model namespace, the hparams namespace/dict, and the datamodule.
277+
"""
278+
holders = _lightning_get_all_attr_holders(model, attribute)
262279
if len(holders) == 0:
263-
raise ValueError(
280+
raise AttributeError(
264281
f'{attribute} is neither stored in the model namespace'
265282
' nor the `hparams` namespace/dict, nor the datamodule.'
266283
)

tests/utilities/test_parsing.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import pytest
1415

1516
from pytorch_lightning.utilities.parsing import lightning_getattr, lightning_hasattr, lightning_setattr
1617

@@ -74,8 +75,8 @@ class TestModel7: # test for datamodule w/ hparams w/ attribute (should use dat
7475

7576

7677
def test_lightning_hasattr(tmpdir):
77-
""" Test that the lightning_hasattr works in all cases"""
78-
model1, model2, model3, model4, model5, model6, model7 = _get_test_cases()
78+
"""Test that the lightning_hasattr works in all cases"""
79+
model1, model2, model3, model4, model5, model6, model7 = models = _get_test_cases()
7980
assert lightning_hasattr(model1, 'learning_rate'), \
8081
'lightning_hasattr failed to find namespace variable'
8182
assert lightning_hasattr(model2, 'learning_rate'), \
@@ -91,9 +92,12 @@ def test_lightning_hasattr(tmpdir):
9192
assert lightning_hasattr(model7, 'batch_size'), \
9293
'lightning_hasattr failed to find batch_size in hparams w/ datamodule present'
9394

95+
for m in models:
96+
assert not lightning_hasattr(m, "this_attr_not_exist")
97+
9498

9599
def test_lightning_getattr(tmpdir):
96-
""" Test that the lightning_getattr works in all cases"""
100+
"""Test that the lightning_getattr works in all cases"""
97101
models = _get_test_cases()
98102
for i, m in enumerate(models[:3]):
99103
value = lightning_getattr(m, 'learning_rate')
@@ -107,9 +111,16 @@ def test_lightning_getattr(tmpdir):
107111
assert lightning_getattr(model7, 'batch_size') == 8, \
108112
'batch_size not correctly extracted'
109113

114+
for m in models:
115+
with pytest.raises(
116+
AttributeError,
117+
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
118+
):
119+
lightning_getattr(m, "this_attr_not_exist")
120+
110121

111122
def test_lightning_setattr(tmpdir):
112-
""" Test that the lightning_setattr works in all cases"""
123+
"""Test that the lightning_setattr works in all cases"""
113124
models = _get_test_cases()
114125
for m in models[:3]:
115126
lightning_setattr(m, 'learning_rate', 10)
@@ -126,3 +137,10 @@ def test_lightning_setattr(tmpdir):
126137
'batch_size not correctly set'
127138
assert lightning_getattr(model7, 'batch_size') == 128, \
128139
'batch_size not correctly set'
140+
141+
for m in models:
142+
with pytest.raises(
143+
AttributeError,
144+
match="is neither stored in the model namespace nor the `hparams` namespace/dict, nor the datamodule."
145+
):
146+
lightning_setattr(m, "this_attr_not_exist", None)

0 commit comments

Comments
 (0)