Skip to content

Commit 088d949

Browse files
committed
update checkpoint config
1 parent f5791ce commit 088d949

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353

5454
from sagemaker.utils import resolve_value_from_config
5555
from sagemaker.modules import Session, get_execution_role
56+
from sagemaker.modules import configs
5657
from sagemaker.modules.configs import (
5758
Compute,
5859
StoppingCondition,
@@ -1132,7 +1133,9 @@ def with_tensorboard_output_config(
11321133
tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig):
11331134
The TensorBoard output configuration.
11341135
"""
1135-
self._tensorboard_output_config = tensorboard_output_config or TensorBoardOutputConfig()
1136+
self._tensorboard_output_config = (
1137+
tensorboard_output_config or configs.TensorBoardOutputConfig()
1138+
)
11361139
return self
11371140

11381141
def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer": # noqa: D412
@@ -1227,3 +1230,25 @@ def with_remote_debug_config(
12271230
enable_remote_debug=True
12281231
)
12291232
return self
1233+
1234+
def with_checkpoint_config(
1235+
self, checkpoint_config: Optional[CheckpointConfig] = None
1236+
) -> "ModelTrainer":
1237+
"""Set the checkpoint configuration for the training job.
1238+
1239+
Example:
1240+
1241+
.. code:: python
1242+
1243+
from sagemaker.modules.train import ModelTrainer
1244+
1245+
model_trainer = ModelTrainer(
1246+
...
1247+
).with_checkpoint_config()
1248+
1249+
Args:
1250+
checkpoint_config (sagemaker.modules.configs.CheckpointConfig):
1251+
The checkpoint configuration for the training job.
1252+
"""
1253+
self.checkpoint_config = checkpoint_config or configs.CheckpointConfig()
1254+
return self

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,12 +1242,16 @@ def mock_upload_data(path, bucket, key_prefix):
12421242
modules_session.upload_data.side_effect = mock_upload_data
12431243
mock_unique_name.return_value = unique_name
12441244

1245-
model_trainer = ModelTrainer(
1246-
training_image=DEFAULT_IMAGE,
1247-
sagemaker_session=modules_session,
1248-
checkpoint_config=CheckpointConfig(),
1249-
base_job_name=base_name,
1250-
).with_tensorboard_output_config(TensorBoardOutputConfig())
1245+
model_trainer = (
1246+
ModelTrainer(
1247+
training_image=DEFAULT_IMAGE,
1248+
sagemaker_session=modules_session,
1249+
base_job_name=base_name,
1250+
)
1251+
.with_tensorboard_output_config()
1252+
.with_checkpoint_config()
1253+
)
1254+
12511255
model_trainer.train()
12521256

12531257
_, kwargs = mock_training_job.create.call_args

0 commit comments

Comments
 (0)