Skip to content

feat: Support job checkpoint in remote function #4171

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/sagemaker/remote_function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from __future__ import absolute_import

from sagemaker.remote_function.client import remote, RemoteExecutor # noqa: F401
from sagemaker.remote_function.checkpoint_location import CheckpointLocation # noqa: F401
47 changes: 47 additions & 0 deletions src/sagemaker/remote_function/checkpoint_location.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module is used to define the CheckpointLocation to remote function."""
from __future__ import absolute_import

from os import PathLike
import re

# Regex is taken from https://docs.aws.amazon.com/sagemaker/latest/APIReference/API_CheckpointConfig.html
S3_URI_REGEX_PATTERN = r"^(https|s3)://([^/]+)/?(.*)$"

_JOB_CHECKPOINT_LOCATION = "/opt/ml/checkpoints/"


def _validate_s3_uri_for_checkpoint(s3_uri: str):
"""Validate if checkpoint location is specified with a valid s3 URI."""
return re.match(S3_URI_REGEX_PATTERN, s3_uri)


class CheckpointLocation(PathLike):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to add this to the readthedoc.

"""Class to represent the location where checkpoints are accessed in a remote function.

To save or load checkpoints in a remote function, pass an CheckpointLocation object as a
function parameter and use it as a os.PathLike object. This CheckpointLocation object
represents the local directory (/opt/ml/checkpoints/) of checkpoints in side the job.
"""

_local_path = _JOB_CHECKPOINT_LOCATION

def __init__(self, s3_uri):
if not _validate_s3_uri_for_checkpoint(s3_uri):
raise ValueError("CheckpointLocation should be specified with valid s3 URI.")
self._s3_uri = s3_uri

def __fspath__(self):
"""Return job local path where checkpoints are stored."""
return self._local_path
47 changes: 47 additions & 0 deletions src/sagemaker/remote_function/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from sagemaker.experiments.run import Run
from sagemaker.image_uris import get_base_python_image_uri
from sagemaker import image_uris
from sagemaker.remote_function.checkpoint_location import CheckpointLocation
from sagemaker.session import get_execution_role, _logs_for_job, Session
from sagemaker.utils import name_from_base, _tmpdir, resolve_value_from_config
from sagemaker.s3 import s3_path_join, S3Uploader
Expand Down Expand Up @@ -681,6 +682,8 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
RetryStrategy={"MaximumRetryAttempts": job_settings.max_retry_attempts},
)

_update_job_request_with_checkpoint_config(func_args, func_kwargs, request_dict)

if job_settings.tags:
request_dict["Tags"] = job_settings.tags

Expand Down Expand Up @@ -1180,6 +1183,50 @@ def _extend_spark_config_to_request(
return extended_request


def _update_job_request_with_checkpoint_config(args, kwargs, request_dict):
"""Extend job request with checkpoint config based on CheckpointLocation in function args.

Args:
args (tuple): The positional arguments of the remote function.
kwargs (Dict): The keyword arguments of the remote function.
request_dict (Dict): create training job request dict.
"""
checkpoint_location_index_in_args = None
checkpoint_location_key_in_kwargs = None
checkpoint_location_count = 0

for index, arg in enumerate(args):
if isinstance(arg, CheckpointLocation):
checkpoint_location_index_in_args = index
checkpoint_location_count += 1

for key, value in kwargs.items():
if isinstance(value, CheckpointLocation):
checkpoint_location_key_in_kwargs = key
checkpoint_location_count += 1

if checkpoint_location_count < 1:
return

if checkpoint_location_count > 1:
raise ValueError(
"Remote function cannot have more than one argument of type CheckpointLocation."
)

if checkpoint_location_index_in_args is not None:
checkpoint_location_arg = args[checkpoint_location_index_in_args]
else:
checkpoint_location_arg = kwargs[checkpoint_location_key_in_kwargs]

checkpoint_s3_uri = checkpoint_location_arg._s3_uri
checkpoint_local_path = checkpoint_location_arg._local_path

request_dict["CheckpointConfig"] = {
"LocalPath": checkpoint_local_path,
"S3Uri": checkpoint_s3_uri,
}


@dataclasses.dataclass
class _RunInfo:
"""Data class to hold information of the run object from context."""
Expand Down
69 changes: 69 additions & 0 deletions tests/integ/sagemaker/remote_function/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import
import time
from typing import Union


import pytest
Expand All @@ -23,6 +24,7 @@
import subprocess
import shlex
from sagemaker.experiments.run import Run, load_run
from sagemaker.remote_function import CheckpointLocation
from tests.integ.sagemaker.experiments.helpers import cleanup_exp_resources
from sagemaker.experiments.trial_component import _TrialComponent
from sagemaker.experiments._api_types import _TrialComponentStatusType
Expand All @@ -40,15 +42,27 @@

from tests.integ.kms_utils import get_or_create_kms_key
from tests.integ import DATA_DIR
from tests.integ.s3_utils import assert_s3_files_exist

ROLE = "SageMakerRole"
CHECKPOINT_FILE_CONTENT = "test checkpoint file"


@pytest.fixture(scope="module")
def s3_kms_key(sagemaker_session):
return get_or_create_kms_key(sagemaker_session=sagemaker_session)


@pytest.fixture(scope="module")
def checkpoint_s3_location(sagemaker_session):
def random_s3_uri():
return "".join(random.choices(string.ascii_uppercase + string.digits, k=10))

return "s3://{}/rm-func-checkpoints/{}".format(
sagemaker_session.default_bucket(), random_s3_uri()
)


def test_decorator(sagemaker_session, dummy_container_without_error, cpu_instance_type):
@remote(
role=ROLE,
Expand Down Expand Up @@ -626,6 +640,61 @@ def divide(x, y):
assert divide(20, 2) == 10


def test_decorator_with_spot_instances_save_and_load_checkpoints(
sagemaker_session,
dummy_container_without_error,
cpu_instance_type,
checkpoint_s3_location,
):
@remote(
role=ROLE,
image_uri=dummy_container_without_error,
instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
use_spot_instances=True,
max_wait_time_in_seconds=48 * 60 * 60,
)
def save_checkpoints(checkpoint_path: Union[str, os.PathLike]):
file_path_1 = os.path.join(checkpoint_path, "checkpoint_1.json")
with open(file_path_1, "w") as f:
f.write(CHECKPOINT_FILE_CONTENT)

file_path_2 = os.path.join(checkpoint_path, "checkpoint_2.json")
with open(file_path_2, "w") as f:
f.write(CHECKPOINT_FILE_CONTENT)

return CHECKPOINT_FILE_CONTENT

@remote(
role=ROLE,
image_uri=dummy_container_without_error,
instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
use_spot_instances=True,
max_wait_time_in_seconds=48 * 60 * 60,
)
def load_checkpoints(checkpoint_path: Union[str, os.PathLike]):
file_path_1 = os.path.join(checkpoint_path, "checkpoint_1.json")
with open(file_path_1, "r") as file:
file_content_1 = file.read()

file_path_2 = os.path.join(checkpoint_path, "checkpoint_2.json")
with open(file_path_2, "r") as file:
file_content_2 = file.read()

return file_content_1 + file_content_2

assert save_checkpoints(CheckpointLocation(checkpoint_s3_location)) == CHECKPOINT_FILE_CONTENT
assert_s3_files_exist(
sagemaker_session, checkpoint_s3_location, ["checkpoint_1.json", "checkpoint_2.json"]
)

assert (
load_checkpoints(CheckpointLocation(checkpoint_s3_location))
== CHECKPOINT_FILE_CONTENT + CHECKPOINT_FILE_CONTENT
)


@pytest.mark.skip
def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
@remote(
Expand Down
Loading