Skip to content

Commit 00695b9

Browse files
Ao Guo authored and goelakash committed
feat: Support job checkpoint in remote function
1 parent 5d0f6b3 commit 00695b9

File tree

5 files changed

+319
-0
lines changed

5 files changed

+319
-0
lines changed

src/sagemaker/remote_function/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,4 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.remote_function.client import remote, RemoteExecutor # noqa: F401
17+
from sagemaker.remote_function.checkpoint_location import CheckpointLocation # noqa: F401
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""This module is used to define the CheckpointLocation to remote function."""
14+
from __future__ import absolute_import
15+
16+
from os import PathLike
17+
import re
18+
19+
# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CheckpointConfig.html
S3_URI_REGEX_PATTERN = r"^(https|s3)://([^/]+)/?(.*)$"

# Local directory of checkpoints inside the job.
_JOB_CHECKPOINT_LOCATION = "/opt/ml/checkpoints/"


def _validate_s3_uri_for_checkpoint(s3_uri: str):
    """Validate if checkpoint location is specified with a valid s3 URI.

    Args:
        s3_uri (str): The candidate checkpoint location.

    Returns:
        The match object when ``s3_uri`` matches ``S3_URI_REGEX_PATTERN``;
        ``None`` otherwise, including when ``s3_uri`` is not a string.
    """
    # re.match raises TypeError on non-string input; report it as "not valid"
    # so callers can raise a single, consistent error type.
    if not isinstance(s3_uri, str):
        return None
    return re.match(S3_URI_REGEX_PATTERN, s3_uri)


class CheckpointLocation(PathLike):
    """Class to represent the location where checkpoints are accessed in a remote function.

    To save or load checkpoints in a remote function, pass a CheckpointLocation object as a
    function parameter and use it as an os.PathLike object. This CheckpointLocation object
    represents the local directory (/opt/ml/checkpoints/) of checkpoints inside the job.
    """

    _local_path = _JOB_CHECKPOINT_LOCATION

    def __init__(self, s3_uri):
        """Initialize a CheckpointLocation.

        Args:
            s3_uri (str): The S3 URI associated with the checkpoints.

        Raises:
            ValueError: If ``s3_uri`` is not a string matching ``S3_URI_REGEX_PATTERN``.
        """
        if not _validate_s3_uri_for_checkpoint(s3_uri):
            raise ValueError("CheckpointLocation should be specified with valid s3 URI.")
        self._s3_uri = s3_uri

    def __fspath__(self):
        """Return job local path where checkpoints are stored."""
        return self._local_path

    def __repr__(self):
        """Return a debugging representation that shows the configured S3 URI."""
        return "{}(s3_uri={!r})".format(type(self).__name__, self._s3_uri)

src/sagemaker/remote_function/job.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
from sagemaker.experiments.run import Run
4747
from sagemaker.image_uris import get_base_python_image_uri
4848
from sagemaker import image_uris
49+
from sagemaker.remote_function.checkpoint_location import CheckpointLocation
4950
from sagemaker.session import get_execution_role, _logs_for_job, Session
5051
from sagemaker.utils import name_from_base, _tmpdir, resolve_value_from_config
5152
from sagemaker.s3 import s3_path_join, S3Uploader
@@ -673,6 +674,8 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
673674
RetryStrategy={"MaximumRetryAttempts": job_settings.max_retry_attempts},
674675
)
675676

677+
_update_job_request_with_checkpoint_config(func_args, func_kwargs, request_dict)
678+
676679
if job_settings.tags:
677680
request_dict["Tags"] = job_settings.tags
678681

@@ -1171,6 +1174,50 @@ def _extend_spark_config_to_request(
11711174
return extended_request
11721175

11731176

1177+
def _update_job_request_with_checkpoint_config(args, kwargs, request_dict):
    """Extend job request with checkpoint config based on CheckpointLocation in function args.

    The request is left untouched when no argument is a CheckpointLocation.

    Args:
        args (tuple): The positional arguments of the remote function.
        kwargs (Dict): The keyword arguments of the remote function.
        request_dict (Dict): create training job request dict.

    Raises:
        ValueError: If more than one argument is of type CheckpointLocation.
    """
    # A missing argument collection is treated as empty.
    candidates = [*(args or ()), *(kwargs or {}).values()]
    checkpoint_locations = [
        candidate for candidate in candidates if isinstance(candidate, CheckpointLocation)
    ]

    if not checkpoint_locations:
        return

    if len(checkpoint_locations) > 1:
        raise ValueError(
            "Remote function cannot have more than one argument of type CheckpointLocation."
        )

    checkpoint_location_arg = checkpoint_locations[0]
    request_dict["CheckpointConfig"] = {
        "LocalPath": checkpoint_location_arg._local_path,
        "S3Uri": checkpoint_location_arg._s3_uri,
    }
1219+
1220+
11741221
@dataclasses.dataclass
11751222
class _RunInfo:
11761223
"""Data class to hold information of the run object from context."""

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414
import time
15+
from typing import Union
1516

1617

1718
import pytest
@@ -23,6 +24,7 @@
2324
import subprocess
2425
import shlex
2526
from sagemaker.experiments.run import Run, load_run
27+
from sagemaker.remote_function import CheckpointLocation
2628
from tests.integ.sagemaker.experiments.helpers import cleanup_exp_resources
2729
from sagemaker.experiments.trial_component import _TrialComponent
2830
from sagemaker.experiments._api_types import _TrialComponentStatusType
@@ -40,15 +42,27 @@
4042

4143
from tests.integ.kms_utils import get_or_create_kms_key
4244
from tests.integ import DATA_DIR
45+
from tests.integ.s3_utils import assert_s3_files_exist
4346

4447
ROLE = "SageMakerRole"
48+
CHECKPOINT_FILE_CONTENT = "test checkpoint file"
4549

4650

4751
@pytest.fixture(scope="module")
def s3_kms_key(sagemaker_session):
    """Return the result of ``get_or_create_kms_key`` for the given session."""
    kms_key = get_or_create_kms_key(sagemaker_session=sagemaker_session)
    return kms_key
5054

5155

56+
@pytest.fixture(scope="module")
def checkpoint_s3_location(sagemaker_session):
    """Return an S3 URI in the session's default bucket ending in a random 10-character suffix."""
    bucket = sagemaker_session.default_bucket()
    # The suffix is drawn from uppercase ASCII letters and digits.
    alphabet = string.ascii_uppercase + string.digits
    suffix = "".join(random.choices(alphabet, k=10))
    return "s3://{}/rm-func-checkpoints/{}".format(bucket, suffix)
64+
65+
5266
def test_decorator(sagemaker_session, dummy_container_without_error, cpu_instance_type):
5367
@remote(
5468
role=ROLE,
@@ -626,6 +640,61 @@ def divide(x, y):
626640
assert divide(20, 2) == 10
627641

628642

643+
def test_decorator_with_spot_instances_save_and_load_checkpoints(
    sagemaker_session,
    dummy_container_without_error,
    cpu_instance_type,
    checkpoint_s3_location,
):
    """Write two checkpoint files in one remote function call and read them back in another."""
    # Both remote functions are configured identically, with spot instances enabled.
    spot_job_settings = dict(
        role=ROLE,
        image_uri=dummy_container_without_error,
        instance_type=cpu_instance_type,
        sagemaker_session=sagemaker_session,
        use_spot_instances=True,
        max_wait_time_in_seconds=48 * 60 * 60,
    )

    @remote(**spot_job_settings)
    def save_checkpoints(checkpoint_path: Union[str, os.PathLike]):
        for file_name in ("checkpoint_1.json", "checkpoint_2.json"):
            with open(os.path.join(checkpoint_path, file_name), "w") as f:
                f.write(CHECKPOINT_FILE_CONTENT)

        return CHECKPOINT_FILE_CONTENT

    @remote(**spot_job_settings)
    def load_checkpoints(checkpoint_path: Union[str, os.PathLike]):
        contents = []
        for file_name in ("checkpoint_1.json", "checkpoint_2.json"):
            with open(os.path.join(checkpoint_path, file_name), "r") as file:
                contents.append(file.read())

        return "".join(contents)

    saved = save_checkpoints(CheckpointLocation(checkpoint_s3_location))
    assert saved == CHECKPOINT_FILE_CONTENT
    assert_s3_files_exist(
        sagemaker_session, checkpoint_s3_location, ["checkpoint_1.json", "checkpoint_2.json"]
    )

    loaded = load_checkpoints(CheckpointLocation(checkpoint_s3_location))
    assert loaded == CHECKPOINT_FILE_CONTENT + CHECKPOINT_FILE_CONTENT
696+
697+
629698
@pytest.mark.skip
630699
def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
631700
@remote(

0 commit comments

Comments
 (0)