Skip to content

Commit 135171f

Browse files
chuyang-dengDan
authored andcommitted
fix: take checkpoint_s3_uri and checkpoint_local_path in Framework class (#1080)
1 parent 64834c2 commit 135171f

File tree

3 files changed

+55
-0
lines changed

3 files changed

+55
-0
lines changed

src/sagemaker/estimator.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1227,6 +1227,8 @@ def __init__(
12271227
dependencies=None,
12281228
enable_network_isolation=False,
12291229
git_config=None,
1230+
checkpoint_s3_uri=None,
1231+
checkpoint_local_path=None,
12301232
**kwargs
12311233
):
12321234
"""Base class initializer. Subclasses which override ``__init__`` should
@@ -1363,6 +1365,17 @@ def __init__(
13631365
authentication if they are provided; otherwise, python SDK will
13641366
try to use either CodeCommit credential helper or local
13651367
credential storage for authentication.
1368+
checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
1369+
that the algorithm persists (if any) during training. (default:
1370+
``None``).
1371+
checkpoint_local_path (str): The local path that the algorithm
1372+
writes its checkpoints to. SageMaker will persist all files
1373+
under this path to `checkpoint_s3_uri` continually during
1374+
training. On job startup the reverse happens - data from the
1375+
s3 location is downloaded to this path before the algorithm is
1376+
started. If the path is unset then SageMaker assumes the
1377+
checkpoints will be provided under `/opt/ml/checkpoints/`.
1378+
(default: ``None``).
13661379
**kwargs: Additional kwargs passed to the ``EstimatorBase``
13671380
constructor.
13681381
"""
@@ -1391,6 +1404,8 @@ def __init__(
13911404
self.uploaded_code = None
13921405

13931406
self._hyperparameters = hyperparameters or {}
1407+
self.checkpoint_s3_uri = checkpoint_s3_uri
1408+
self.checkpoint_local_path = checkpoint_local_path
13941409

13951410
def enable_network_isolation(self):
13961411
"""Return True if this Estimator can use network isolation to run.

tests/integ/test_tf_script_mode.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,42 @@ def test_mnist(sagemaker_session, instance_type):
6666
assert df.size > 0
6767

6868

69+
@pytest.mark.skipif(
70+
tests.integ.test_region() != "us-east-1",
71+
reason="checkpoint s3 bucket is in us-east-1, ListObjectsV2 will fail in other regions",
72+
)
73+
def test_checkpoint_config(sagemaker_session, instance_type):
74+
checkpoint_s3_uri = "s3://142577830533-us-east-1-sagemaker-checkpoint"
75+
checkpoint_local_path = "/test/checkpoint/path"
76+
estimator = TensorFlow(
77+
entry_point=SCRIPT,
78+
role="SageMakerRole",
79+
train_instance_count=1,
80+
train_instance_type=instance_type,
81+
sagemaker_session=sagemaker_session,
82+
script_mode=True,
83+
framework_version=TensorFlow.LATEST_VERSION,
84+
py_version=tests.integ.PYTHON_VERSION,
85+
checkpoint_s3_uri=checkpoint_s3_uri,
86+
checkpoint_local_path=checkpoint_local_path,
87+
)
88+
inputs = estimator.sagemaker_session.upload_data(
89+
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="script/mnist"
90+
)
91+
training_job_name = unique_name_from_base("test-tf-sm-checkpoint")
92+
with tests.integ.timeout.timeout(minutes=tests.integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
93+
estimator.fit(inputs=inputs, job_name=training_job_name)
94+
95+
expected_training_checkpoint_config = {
96+
"S3Uri": checkpoint_s3_uri,
97+
"LocalPath": checkpoint_local_path,
98+
}
99+
actual_training_checkpoint_config = sagemaker_session.sagemaker_client.describe_training_job(
100+
TrainingJobName=training_job_name
101+
)["CheckpointConfig"]
102+
assert actual_training_checkpoint_config == expected_training_checkpoint_config
103+
104+
69105
def test_server_side_encryption(sagemaker_session):
70106
boto_session = sagemaker_session.boto_session
71107
with kms_utils.bucket_with_encryption(boto_session, ROLE) as (bucket_with_kms, kms_key):

tests/unit/test_estimator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ def test_framework_all_init_args(sagemaker_session):
203203
security_group_ids=["789", "012"],
204204
metric_definitions=[{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
205205
encrypt_inter_container_traffic=True,
206+
checkpoint_s3_uri="s3://bucket/checkpoint",
207+
checkpoint_local_path="file://local/checkpoint",
206208
)
207209
_TrainingJob.start_new(f, "s3://mydata")
208210
sagemaker_session.train.assert_called_once()
@@ -237,6 +239,8 @@ def test_framework_all_init_args(sagemaker_session):
237239
},
238240
"metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
239241
"encrypt_inter_container_traffic": True,
242+
"checkpoint_s3_uri": "s3://bucket/checkpoint",
243+
"checkpoint_local_path": "file://local/checkpoint",
240244
}
241245

242246

0 commit comments

Comments
 (0)