Skip to content

Commit 6223f26

Browse files
committed
fix: safer destructor in ModelTrainer
1 parent b898f76 commit 6223f26

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,9 @@ class ModelTrainer(BaseModel):
205205
"LOCAL_CONTAINER" mode.
206206
"""
207207

208-
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
208+
model_config = ConfigDict(
209+
arbitrary_types_allowed=True, validate_assignment=True, extra="forbid"
210+
)
209211

210212
training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB
211213
sagemaker_session: Optional[Session] = None
@@ -363,9 +365,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):
363365

364366
def __del__(self):
365367
"""Destructor method to clean up the temporary directory."""
366-
# Clean up the temporary directory if it exists
367-
if self._temp_recipe_train_dir is not None:
368-
self._temp_recipe_train_dir.cleanup()
368+
# Clean up the temporary directory if it exists and class was initialized
369+
if hasattr(self, "__pydantic_fields_set__"):
370+
if self._temp_recipe_train_dir is not None:
371+
self._temp_recipe_train_dir.cleanup()
369372

370373
def _validate_training_image_and_algorithm_name(
371374
self, training_image: Optional[str], algorithm_name: Optional[str]

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import os
2020
import pytest
21+
from pydantic import ValidationError
2122
from unittest.mock import patch, MagicMock, ANY
2223

2324
from sagemaker import image_uris
@@ -1068,3 +1069,27 @@ def test_safe_configs():
10681069
# Test invalid type fails
10691070
with pytest.raises(ValueError):
10701071
SourceCode(entry_script=1)
1072+
1073+
1074+
@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory")
1075+
def test_destructor_cleanup(mock_tmp_dir, modules_session):
1076+
1077+
with pytest.raises(ValidationError):
1078+
model_trainer = ModelTrainer(
1079+
training_image=DEFAULT_IMAGE,
1080+
role=DEFAULT_ROLE,
1081+
sagemaker_session=modules_session,
1082+
compute="test"
1083+
)
1084+
mock_tmp_dir.cleanup.assert_not_called()
1085+
1086+
model_trainer = ModelTrainer(
1087+
training_image=DEFAULT_IMAGE,
1088+
role=DEFAULT_ROLE,
1089+
sagemaker_session=modules_session,
1090+
compute=DEFAULT_COMPUTE_CONFIG,
1091+
)
1092+
model_trainer._temp_recipe_train_dir = mock_tmp_dir
1093+
mock_tmp_dir.assert_not_called()
1094+
del model_trainer
1095+
mock_tmp_dir.cleanup.assert_called_once()

0 commit comments

Comments
 (0)