File tree Expand file tree Collapse file tree 2 files changed +32
-4
lines changed
src/sagemaker/modules/train
tests/unit/sagemaker/modules/train Expand file tree Collapse file tree 2 files changed +32
-4
lines changed Original file line number Diff line number Diff line change @@ -205,7 +205,9 @@ class ModelTrainer(BaseModel):
205
205
"LOCAL_CONTAINER" mode.
206
206
"""
207
207
208
- model_config = ConfigDict (arbitrary_types_allowed = True , extra = "forbid" )
208
+ model_config = ConfigDict (
209
+ arbitrary_types_allowed = True , validate_assignment = True , extra = "forbid"
210
+ )
209
211
210
212
training_mode : Mode = Mode .SAGEMAKER_TRAINING_JOB
211
213
sagemaker_session : Optional [Session ] = None
@@ -363,9 +365,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):
363
365
364
366
def __del__ (self ):
365
367
"""Destructor method to clean up the temporary directory."""
366
- # Clean up the temporary directory if it exists
367
- if self ._temp_recipe_train_dir is not None :
368
- self ._temp_recipe_train_dir .cleanup ()
368
+ # Clean up the temporary directory if it exists and class was initialized
369
+ if hasattr (self , "__pydantic_fields_set__" ):
370
+ if self ._temp_recipe_train_dir is not None :
371
+ self ._temp_recipe_train_dir .cleanup ()
369
372
370
373
def _validate_training_image_and_algorithm_name (
371
374
self , training_image : Optional [str ], algorithm_name : Optional [str ]
Original file line number Diff line number Diff line change 18
18
import json
19
19
import os
20
20
import pytest
21
+ from pydantic import ValidationError
21
22
from unittest .mock import patch , MagicMock , ANY
22
23
23
24
from sagemaker import image_uris
@@ -1068,3 +1069,27 @@ def test_safe_configs():
1068
1069
# Test invalid type fails
1069
1070
with pytest .raises (ValueError ):
1070
1071
SourceCode (entry_script = 1 )
1072
+
1073
+
1074
+ @patch ("sagemaker.modules.train.model_trainer.TemporaryDirectory" )
1075
+ def test_destructor_cleanup (mock_tmp_dir , modules_session ):
1076
+
1077
+ with pytest .raises (ValidationError ):
1078
+ model_trainer = ModelTrainer (
1079
+ training_image = DEFAULT_IMAGE ,
1080
+ role = DEFAULT_ROLE ,
1081
+ sagemaker_session = modules_session ,
1082
+ compute = "test"
1083
+ )
1084
+ mock_tmp_dir .cleanup .assert_not_called ()
1085
+
1086
+ model_trainer = ModelTrainer (
1087
+ training_image = DEFAULT_IMAGE ,
1088
+ role = DEFAULT_ROLE ,
1089
+ sagemaker_session = modules_session ,
1090
+ compute = DEFAULT_COMPUTE_CONFIG ,
1091
+ )
1092
+ model_trainer ._temp_recipe_train_dir = mock_tmp_dir
1093
+ mock_tmp_dir .assert_not_called ()
1094
+ del model_trainer
1095
+ mock_tmp_dir .cleanup .assert_called_once ()
You can’t perform that action at this time.
0 commit comments