Skip to content

Commit 0e85b45

Browse files
committed
fix: safer destructor in ModelTrainer
1 parent b898f76 commit 0e85b45

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ class ModelTrainer(BaseModel):
205205
"LOCAL_CONTAINER" mode.
206206
"""
207207

208-
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
208+
model_config = ConfigDict(
209+
arbitrary_types_allowed=True, validate_assignment=True, extra="forbid"
210+
)
209211

210212
training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB
211213
sagemaker_session: Optional[Session] = None
@@ -363,9 +365,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):
363365

364366
def __del__(self):
365367
"""Destructor method to clean up the temporary directory."""
366-
# Clean up the temporary directory if it exists
367-
if self._temp_recipe_train_dir is not None:
368-
self._temp_recipe_train_dir.cleanup()
368+
# Clean up the temporary directory if it exists and class was initialized
369+
if hasattr(self, "__pydantic_fields_set__"):
370+
if self._temp_recipe_train_dir is not None:
371+
self._temp_recipe_train_dir.cleanup()
369372

370373
def _validate_training_image_and_algorithm_name(
371374
self, training_image: Optional[str], algorithm_name: Optional[str]

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,3 +1068,20 @@ def test_safe_configs():
10681068
# Test invalid type fails
10691069
with pytest.raises(ValueError):
10701070
SourceCode(entry_script=1)
1071+
1072+
1073+
@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory")
1074+
def test_destructor_cleanup(mock_tmp_dir, modules_session):
1075+
with pytest.raises(ValueError):
1076+
ModelTrainer()
1077+
mock_tmp_dir.cleanup.assert_not_called()
1078+
1079+
model_trainer = ModelTrainer(
1080+
training_image=DEFAULT_IMAGE,
1081+
role=DEFAULT_ROLE,
1082+
sagemaker_session=modules_session,
1083+
)
1084+
model_trainer._temp_recipe_train_dir = mock_tmp_dir
1085+
mock_tmp_dir.assert_not_called()
1086+
del model_trainer
1087+
mock_tmp_dir.cleanup.assert_called_once()

0 commit comments

Comments
 (0)