feat: Enable spot training on remote decorator and executor #4077
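This change adds two arguments, ``use_spot_instances`` and ``max_wait_time_in_seconds``, to the ``@remote`` decorator, ``RemoteExecutor``, and the internal ``_JobSettings``, and wires them through to the ``CreateTrainingJob`` request as ``EnableManagedSpotTraining`` and ``StoppingCondition.MaxWaitTimeInSeconds``. A minimal usage sketch of the decorator path (the role ARN, image URI, and instance type below are placeholders, not values from this PR):

    from sagemaker.remote_function import remote

    @remote(
        role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role ARN
        image_uri="<your-training-image-uri>",                # placeholder image
        instance_type="ml.m5.xlarge",                         # placeholder instance type
        use_spot_instances=True,                # run the job on Managed Spot capacity
        max_wait_time_in_seconds=48 * 60 * 60,  # wait up to 48h, including spot delays
    )
    def divide(x, y):
        return x / y

    assert divide(10, 2) == 5  # runs as a SageMaker training job on spot instances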

Merged · 2 commits · Aug 28, 2023
24 changes: 24 additions & 0 deletions src/sagemaker/remote_function/client.py
@@ -83,6 +83,8 @@ def remote(
volume_size: int = 30,
encrypt_inter_container_traffic: bool = None,
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
):
"""Decorator for running the annotated function as a SageMaker training job.

@@ -255,6 +257,14 @@ def remote(
Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
will be used for training. Note that ``image_uri`` cannot be specified at the
same time; otherwise a ``ValueError`` is thrown. Defaults to ``None``.

use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for
training. If enabled, the ``max_wait_time_in_seconds`` arg should also be set.
Defaults to ``False``.

max_wait_time_in_seconds (int): Timeout in seconds waiting for the spot training job.
After this amount of time, Amazon SageMaker stops waiting for the managed spot
training job to complete. Defaults to ``None``.
"""

def _remote(func):
Expand Down Expand Up @@ -284,6 +294,8 @@ def _remote(func):
volume_size=volume_size,
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
spark_config=spark_config,
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
)

@functools.wraps(func)
Expand Down Expand Up @@ -492,6 +504,8 @@ def __init__(
volume_size: int = 30,
encrypt_inter_container_traffic: bool = None,
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
):
"""Constructor for RemoteExecutor

Expand Down Expand Up @@ -670,6 +684,14 @@ def __init__(
Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
will be used for training. Note that ``image_uri`` cannot be specified at the
same time; otherwise a ``ValueError`` is thrown. Defaults to ``None``.

use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for
training. If enabled, the ``max_wait_time_in_seconds`` arg should also be set.
Defaults to ``False``.

max_wait_time_in_seconds (int): Timeout in seconds waiting for the spot training job.
After this amount of time, Amazon SageMaker stops waiting for the managed spot
training job to complete. Defaults to ``None``.
"""
self.max_parallel_jobs = max_parallel_jobs

Expand Down Expand Up @@ -707,6 +729,8 @@ def __init__(
volume_size=volume_size,
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
spark_config=spark_config,
use_spot_instances=use_spot_instances,
max_wait_time_in_seconds=max_wait_time_in_seconds,
)

self._state_condition = threading.Condition()
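The same two arguments are accepted by ``RemoteExecutor`` and forwarded to ``_JobSettings``. A sketch of the executor path, mirroring the integration test further down (role, image URI, and instance type are again placeholders):

    from sagemaker.remote_function import RemoteExecutor

    def cube(x):
        return x * x * x

    with RemoteExecutor(
        max_parallel_jobs=1,
        role="arn:aws:iam::123456789012:role/SageMakerRole",  # placeholder role ARN
        image_uri="<your-training-image-uri>",                # placeholder image
        instance_type="ml.m5.xlarge",                         # placeholder instance type
        use_spot_instances=True,
        max_wait_time_in_seconds=48 * 60 * 60,
    ) as e:
        future = e.submit(cube, 10)  # each submitted call becomes a spot training job

    assert future.result() == 1000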
24 changes: 21 additions & 3 deletions src/sagemaker/remote_function/job.py
@@ -191,6 +191,8 @@ def __init__(
volume_size: int = 30,
encrypt_inter_container_traffic: bool = None,
spark_config: SparkConfig = None,
use_spot_instances=False,
max_wait_time_in_seconds=None,
):
"""Initialize a _JobSettings instance which configures the remote job.

@@ -353,6 +355,14 @@ def __init__(
Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
will be used for training. Note that ``image_uri`` cannot be specified at the
same time; otherwise a ``ValueError`` is thrown. Defaults to ``None``.

use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for
training. If enabled, the ``max_wait_time_in_seconds`` arg should also be set.
Defaults to ``False``.

max_wait_time_in_seconds (int): Timeout in seconds waiting for the spot training job.
After this amount of time, Amazon SageMaker stops waiting for the managed spot
training job to complete. Defaults to ``None``.
"""
self.sagemaker_session = sagemaker_session or Session()
self.environment_variables = resolve_value_from_config(
@@ -439,6 +449,8 @@ def __init__(
self.max_retry_attempts = max_retry_attempts
self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
self.spark_config = spark_config
self.use_spot_instances = use_spot_instances
self.max_wait_time_in_seconds = max_wait_time_in_seconds
self.job_conda_env = resolve_value_from_config(
direct_input=job_conda_env,
config_path=REMOTE_FUNCTION_JOB_CONDA_ENV,
@@ -648,12 +660,16 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None

stored_function.save(func, *func_args, **func_kwargs)

stopping_condition = {
"MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds,
}
if job_settings.max_wait_time_in_seconds is not None:
stopping_condition["MaxWaitTimeInSeconds"] = job_settings.max_wait_time_in_seconds

request_dict = dict(
TrainingJobName=job_name,
RoleArn=job_settings.role,
StoppingCondition={
"MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds,
},
StoppingCondition=stopping_condition,
RetryStrategy={"MaximumRetryAttempts": job_settings.max_retry_attempts},
)
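
An illustrative sketch of what the branch above produces for a spot job, assuming the default ``max_runtime_in_seconds`` of 86400 (one day) and ``max_wait_time_in_seconds=48 * 60 * 60``; note that the ``CreateTrainingJob`` API requires ``MaxWaitTimeInSeconds`` to be greater than or equal to ``MaxRuntimeInSeconds``:

    # Illustrative values only, not taken from this PR:
    stopping_condition = {
        "MaxRuntimeInSeconds": 86400,    # default max_runtime_in_seconds
        "MaxWaitTimeInSeconds": 172800,  # 48 * 60 * 60; must be >= MaxRuntimeInSeconds
    }
    # For a non-spot job, max_wait_time_in_seconds is None and the key is omitted.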

Expand Down Expand Up @@ -742,6 +758,8 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=Non
if job_settings.vpc_config:
request_dict["VpcConfig"] = job_settings.vpc_config

request_dict["EnableManagedSpotTraining"] = job_settings.use_spot_instances

request_dict["Environment"] = job_settings.environment_variables

extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
18 changes: 18 additions & 0 deletions tests/integ/sagemaker/remote_function/test_decorator.py
@@ -601,6 +601,24 @@ def get_file_content(file_names):
assert "line 2: bws: command not found" in str(e)


def test_decorator_with_spot_instances(
sagemaker_session, dummy_container_without_error, cpu_instance_type
):
@remote(
role=ROLE,
image_uri=dummy_container_without_error,
instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
use_spot_instances=True,
max_wait_time_in_seconds=48 * 60 * 60,
)
def divide(x, y):
return x / y

assert divide(10, 2) == 5
assert divide(20, 2) == 10


@pytest.mark.skip
def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
@remote(
40 changes: 40 additions & 0 deletions tests/integ/sagemaker/remote_function/test_executor.py
@@ -195,6 +195,46 @@ def cube(x):
assert metric_summary.avg == 550


def test_executor_submit_using_spot_instances(
sagemaker_session, dummy_container_without_error, cpu_instance_type
):
def square_on_spot_instance(x):
return x * x

def cube_on_spot_instance(x):
return x * x * x

with RemoteExecutor(
max_parallel_jobs=1,
role=ROLE,
image_uri=dummy_container_without_error,
instance_type=cpu_instance_type,
sagemaker_session=sagemaker_session,
use_spot_instances=True,
max_wait_time_in_seconds=48 * 60 * 60,
) as e:
future_1 = e.submit(square_on_spot_instance, 10)
future_2 = e.submit(cube_on_spot_instance, 10)

assert future_1.result() == 100
assert future_2.result() == 1000

assert get_future(future_1._job.job_name, sagemaker_session).result() == 100
assert get_future(future_2._job.job_name, sagemaker_session).result() == 1000

describe_job_1 = next(
list_futures(job_name_prefix="square-on-spot-instance", sagemaker_session=sagemaker_session)
)._job.describe()
assert describe_job_1["EnableManagedSpotTraining"] is True
assert describe_job_1["StoppingCondition"]["MaxWaitTimeInSeconds"] == 172800

describe_job_2 = next(
list_futures(job_name_prefix="cube-on-spot-instance", sagemaker_session=sagemaker_session)
)._job.describe()
assert describe_job_2["EnableManagedSpotTraining"] is True
assert describe_job_2["StoppingCondition"]["MaxWaitTimeInSeconds"] == 172800


def test_executor_map_with_run(sagemaker_session, dummy_container_without_error, cpu_instance_type):
def square(x):
with load_run() as run:
@@ -873,6 +873,8 @@ def test_remote_decorator_fields_consistency(get_execution_role, session):
"volume_kms_key",
"vpc_config",
"tags",
"use_spot_instances",
"max_wait_time_in_seconds",
}

job_settings = _JobSettings(
20 changes: 20 additions & 0 deletions tests/unit/sagemaker/remote_function/test_client.py
@@ -455,6 +455,26 @@ def square(x):
assert mock_job_settings.call_args.kwargs["spark_config"] == spark_config


@patch(
"sagemaker.remote_function.core.serialization.deserialize_obj_from_s3",
return_value=EXPECTED_JOB_RESULT,
)
@patch("sagemaker.remote_function.client._JobSettings")
@patch("sagemaker.remote_function.client._Job.start")
def test_decorator_with_spot_instances(mock_start, mock_job_settings, mock_deserialize):
mock_job = Mock(job_name=TRAINING_JOB_NAME)
mock_job.describe.return_value = COMPLETED_TRAINING_JOB

mock_start.return_value = mock_job

@remote(use_spot_instances=True, max_wait_time_in_seconds=48 * 60 * 60)
def square(x):
pass

assert mock_job_settings.call_args.kwargs["use_spot_instances"] is True
assert mock_job_settings.call_args.kwargs["max_wait_time_in_seconds"] == 172800


@pytest.mark.parametrize(
"args, kwargs, error_message",
[
3 changes: 3 additions & 0 deletions tests/unit/sagemaker/remote_function/test_job.py
@@ -412,6 +412,7 @@ def test_start(
),
EnableNetworkIsolation=False,
EnableInterContainerTrafficEncryption=True,
EnableManagedSpotTraining=False,
Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
)

@@ -537,6 +538,7 @@ def test_start_with_complete_job_settings(
EnableNetworkIsolation=False,
EnableInterContainerTrafficEncryption=False,
VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]),
EnableManagedSpotTraining=False,
Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
)

Expand Down Expand Up @@ -667,6 +669,7 @@ def test_start_with_spark(
),
EnableNetworkIsolation=False,
EnableInterContainerTrafficEncryption=True,
EnableManagedSpotTraining=False,
Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
)
