
Commit 23d3a68

aoguo64 (Ao Guo) and co-author authored
feat: Enable spot training on remote decorator and executor (#4077)
Co-authored-by: Ao Guo <[email protected]>
1 parent 3eb45a6 commit 23d3a68

File tree

7 files changed: +128 -3 lines changed


src/sagemaker/remote_function/client.py

Lines changed: 24 additions & 0 deletions
@@ -83,6 +83,8 @@ def remote(
     volume_size: int = 30,
     encrypt_inter_container_traffic: bool = None,
     spark_config: SparkConfig = None,
+    use_spot_instances=False,
+    max_wait_time_in_seconds=None,
 ):
     """Decorator for running the annotated function as a SageMaker training job.
@@ -255,6 +257,14 @@ def remote(
             Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
             will be used for training. Note that ``image_uri`` can not be specified at the
             same time otherwise a ``ValueError`` is thrown. Defaults to ``None``.
+
+        use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for
+            training. If enabled then the ``max_wait_time_in_seconds`` arg should also be set.
+            Defaults to ``False``.
+
+        max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
+            After this amount of time Amazon SageMaker will stop waiting for managed spot training
+            job to complete. Defaults to ``None``.
     """

     def _remote(func):
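
For context, a minimal usage sketch of the two new decorator arguments; the role ARN, image URI, and instance type below are placeholders rather than values from this change:

    from sagemaker.remote_function import remote

    # Placeholder role, image, and instance type -- substitute your own.
    @remote(
        role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",
        image_uri="111122223333.dkr.ecr.us-west-2.amazonaws.com/example-image:latest",
        instance_type="ml.m5.xlarge",
        use_spot_instances=True,               # request Managed Spot capacity
        max_wait_time_in_seconds=2 * 60 * 60,  # stop waiting for the spot job after 2 hours
    )
    def multiply(a, b):
        return a * b

    assert multiply(6, 7) == 42  # runs as a SageMaker training job on spot instances

SageMaker requires the wait time to be at least as large as the job's maximum runtime, so max_wait_time_in_seconds should be set accordingly when spot training is enabled.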
@@ -284,6 +294,8 @@ def _remote(func):
             volume_size=volume_size,
             encrypt_inter_container_traffic=encrypt_inter_container_traffic,
             spark_config=spark_config,
+            use_spot_instances=use_spot_instances,
+            max_wait_time_in_seconds=max_wait_time_in_seconds,
         )

         @functools.wraps(func)
@@ -492,6 +504,8 @@ def __init__(
         volume_size: int = 30,
         encrypt_inter_container_traffic: bool = None,
         spark_config: SparkConfig = None,
+        use_spot_instances=False,
+        max_wait_time_in_seconds=None,
     ):
         """Constructor for RemoteExecutor
@@ -670,6 +684,14 @@ def __init__(
                 Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
                 will be used for training. Note that ``image_uri`` can not be specified at the
                 same time otherwise a ``ValueError`` is thrown. Defaults to ``None``.
+
+            use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for
+                training. If enabled then the ``max_wait_time_in_seconds`` arg should also be set.
+                Defaults to ``False``.
+
+            max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
+                After this amount of time Amazon SageMaker will stop waiting for managed spot
+                training job to complete. Defaults to ``None``.
         """
         self.max_parallel_jobs = max_parallel_jobs

@@ -707,6 +729,8 @@ def __init__(
             volume_size=volume_size,
             encrypt_inter_container_traffic=encrypt_inter_container_traffic,
             spark_config=spark_config,
+            use_spot_instances=use_spot_instances,
+            max_wait_time_in_seconds=max_wait_time_in_seconds,
         )

         self._state_condition = threading.Condition()
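
The same pair of arguments flows through RemoteExecutor. A sketch mirroring the integration test added further below, with a placeholder role and instance type:

    from sagemaker.remote_function import RemoteExecutor

    def square(x):
        return x * x

    # Placeholder role and instance type -- substitute your own.
    with RemoteExecutor(
        max_parallel_jobs=1,
        role="arn:aws:iam::111122223333:role/ExampleSageMakerRole",
        instance_type="ml.m5.xlarge",
        use_spot_instances=True,
        max_wait_time_in_seconds=48 * 60 * 60,  # 172800 seconds
    ) as executor:
        future = executor.submit(square, 12)

    assert future.result() == 144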

src/sagemaker/remote_function/job.py

Lines changed: 21 additions & 3 deletions
@@ -191,6 +191,8 @@ def __init__(
         volume_size: int = 30,
         encrypt_inter_container_traffic: bool = None,
         spark_config: SparkConfig = None,
+        use_spot_instances=False,
+        max_wait_time_in_seconds=None,
     ):
         """Initialize a _JobSettings instance which configures the remote job.
@@ -353,6 +355,14 @@ def __init__(
                 Spark image. If ``spark_config`` is specified, a SageMaker Spark image uri
                 will be used for training. Note that ``image_uri`` can not be specified at the
                 same time otherwise a ``ValueError`` is thrown. Defaults to ``None``.
+
+            use_spot_instances (bool): Specifies whether to use SageMaker Managed Spot instances for
+                training. If enabled then the ``max_wait_time_in_seconds`` arg should also be set.
+                Defaults to ``False``.
+
+            max_wait_time_in_seconds (int): Timeout in seconds waiting for spot training job.
+                After this amount of time Amazon SageMaker will stop waiting for managed spot
+                training job to complete. Defaults to ``None``.
         """
         self.sagemaker_session = sagemaker_session or Session()
         self.environment_variables = resolve_value_from_config(
@@ -439,6 +449,8 @@ def __init__(
         self.max_retry_attempts = max_retry_attempts
         self.keep_alive_period_in_seconds = keep_alive_period_in_seconds
         self.spark_config = spark_config
+        self.use_spot_instances = use_spot_instances
+        self.max_wait_time_in_seconds = max_wait_time_in_seconds
         self.job_conda_env = resolve_value_from_config(
             direct_input=job_conda_env,
             config_path=REMOTE_FUNCTION_JOB_CONDA_ENV,
@@ -648,12 +660,16 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None

         stored_function.save(func, *func_args, **func_kwargs)

+        stopping_condition = {
+            "MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds,
+        }
+        if job_settings.max_wait_time_in_seconds is not None:
+            stopping_condition["MaxWaitTimeInSeconds"] = job_settings.max_wait_time_in_seconds
+
         request_dict = dict(
             TrainingJobName=job_name,
             RoleArn=job_settings.role,
-            StoppingCondition={
-                "MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds,
-            },
+            StoppingCondition=stopping_condition,
             RetryStrategy={"MaximumRetryAttempts": job_settings.max_retry_attempts},
         )

@@ -742,6 +758,8 @@ def start(job_settings: _JobSettings, func, func_args, func_kwargs, run_info=None
         if job_settings.vpc_config:
             request_dict["VpcConfig"] = job_settings.vpc_config

+        request_dict["EnableManagedSpotTraining"] = job_settings.use_spot_instances
+
         request_dict["Environment"] = job_settings.environment_variables

         extended_request = _extend_spark_config_to_request(request_dict, job_settings, s3_base_uri)
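
Taken together, the two request changes above mean that MaxWaitTimeInSeconds is only added to the StoppingCondition when a wait time was supplied, while EnableManagedSpotTraining is always sent (False for on-demand jobs). A standalone sketch of that behavior, using hypothetical values rather than the SDK's internal code path:

    # Hypothetical values standing in for job_settings fields.
    max_runtime_in_seconds = 86400         # default: 1 day
    max_wait_time_in_seconds = 172800      # set only when spot training is requested
    use_spot_instances = True

    stopping_condition = {"MaxRuntimeInSeconds": max_runtime_in_seconds}
    if max_wait_time_in_seconds is not None:
        stopping_condition["MaxWaitTimeInSeconds"] = max_wait_time_in_seconds

    request_dict = {
        "StoppingCondition": stopping_condition,
        "EnableManagedSpotTraining": use_spot_instances,
    }
    print(request_dict)
    # {'StoppingCondition': {'MaxRuntimeInSeconds': 86400, 'MaxWaitTimeInSeconds': 172800},
    #  'EnableManagedSpotTraining': True}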

tests/integ/sagemaker/remote_function/test_decorator.py

Lines changed: 18 additions & 0 deletions
@@ -601,6 +601,24 @@ def get_file_content(file_names):
     assert "line 2: bws: command not found" in str(e)


+def test_decorator_with_spot_instances(
+    sagemaker_session, dummy_container_without_error, cpu_instance_type
+):
+    @remote(
+        role=ROLE,
+        image_uri=dummy_container_without_error,
+        instance_type=cpu_instance_type,
+        sagemaker_session=sagemaker_session,
+        use_spot_instances=True,
+        max_wait_time_in_seconds=48 * 60 * 60,
+    )
+    def divide(x, y):
+        return x / y
+
+    assert divide(10, 2) == 5
+    assert divide(20, 2) == 10
+
+
 @pytest.mark.skip
 def test_decorator_with_spark_job(sagemaker_session, cpu_instance_type):
     @remote(

tests/integ/sagemaker/remote_function/test_executor.py

Lines changed: 40 additions & 0 deletions
@@ -195,6 +195,46 @@ def cube(x):
     assert metric_summary.avg == 550


+def test_executor_submit_using_spot_instances(
+    sagemaker_session, dummy_container_without_error, cpu_instance_type
+):
+    def square_on_spot_instance(x):
+        return x * x
+
+    def cube_on_spot_instance(x):
+        return x * x * x
+
+    with RemoteExecutor(
+        max_parallel_jobs=1,
+        role=ROLE,
+        image_uri=dummy_container_without_error,
+        instance_type=cpu_instance_type,
+        sagemaker_session=sagemaker_session,
+        use_spot_instances=True,
+        max_wait_time_in_seconds=48 * 60 * 60,
+    ) as e:
+        future_1 = e.submit(square_on_spot_instance, 10)
+        future_2 = e.submit(cube_on_spot_instance, 10)
+
+    assert future_1.result() == 100
+    assert future_2.result() == 1000
+
+    assert get_future(future_1._job.job_name, sagemaker_session).result() == 100
+    assert get_future(future_2._job.job_name, sagemaker_session).result() == 1000
+
+    describe_job_1 = next(
+        list_futures(job_name_prefix="square-on-spot-instance", sagemaker_session=sagemaker_session)
+    )._job.describe()
+    assert describe_job_1["EnableManagedSpotTraining"] is True
+    assert describe_job_1["StoppingCondition"]["MaxWaitTimeInSeconds"] == 172800
+
+    describe_job_2 = next(
+        list_futures(job_name_prefix="cube-on-spot-instance", sagemaker_session=sagemaker_session)
+    )._job.describe()
+    assert describe_job_2["EnableManagedSpotTraining"] is True
+    assert describe_job_2["StoppingCondition"]["MaxWaitTimeInSeconds"] == 172800
+
+
 def test_executor_map_with_run(sagemaker_session, dummy_container_without_error, cpu_instance_type):
     def square(x):
         with load_run() as run:
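
The integration test reads these fields back through the futures' _job.describe() helper. Outside the test harness, the same check can be sketched with boto3 (the job name below is a hypothetical placeholder):

    import boto3

    sagemaker_client = boto3.client("sagemaker")
    # "example-remote-function-job" is a placeholder training job name.
    desc = sagemaker_client.describe_training_job(TrainingJobName="example-remote-function-job")

    print(desc["EnableManagedSpotTraining"])                      # True for spot-backed jobs
    print(desc["StoppingCondition"].get("MaxWaitTimeInSeconds"))  # e.g. 172800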

tests/unit/sagemaker/feature_store/feature_processor/test_feature_scheduler.py

Lines changed: 2 additions & 0 deletions
@@ -873,6 +873,8 @@ def test_remote_decorator_fields_consistency(get_execution_role, session):
         "volume_kms_key",
         "vpc_config",
         "tags",
+        "use_spot_instances",
+        "max_wait_time_in_seconds",
     }

     job_settings = _JobSettings(

tests/unit/sagemaker/remote_function/test_client.py

Lines changed: 20 additions & 0 deletions
@@ -455,6 +455,26 @@ def square(x):
     assert mock_job_settings.call_args.kwargs["spark_config"] == spark_config


+@patch(
+    "sagemaker.remote_function.core.serialization.deserialize_obj_from_s3",
+    return_value=EXPECTED_JOB_RESULT,
+)
+@patch("sagemaker.remote_function.client._JobSettings")
+@patch("sagemaker.remote_function.client._Job.start")
+def test_decorator_with_spot_instances(mock_start, mock_job_settings, mock_deserialize):
+    mock_job = Mock(job_name=TRAINING_JOB_NAME)
+    mock_job.describe.return_value = COMPLETED_TRAINING_JOB
+
+    mock_start.return_value = mock_job
+
+    @remote(use_spot_instances=True, max_wait_time_in_seconds=48 * 60 * 60)
+    def square(x):
+        pass
+
+    assert mock_job_settings.call_args.kwargs["use_spot_instances"] is True
+    assert mock_job_settings.call_args.kwargs["max_wait_time_in_seconds"] == 172800
+
+
 @pytest.mark.parametrize(
     "args, kwargs, error_message",
     [
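
For reference, 48 * 60 * 60 = 172,800, so the 48-hour wait passed to the decorator above and the literal 172800 in the assertion are the same value expressed in seconds.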

tests/unit/sagemaker/remote_function/test_job.py

Lines changed: 3 additions & 0 deletions
@@ -412,6 +412,7 @@ def test_start(
         ),
         EnableNetworkIsolation=False,
         EnableInterContainerTrafficEncryption=True,
+        EnableManagedSpotTraining=False,
         Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
     )

@@ -537,6 +538,7 @@ def test_start_with_complete_job_settings(
         EnableNetworkIsolation=False,
         EnableInterContainerTrafficEncryption=False,
         VpcConfig=dict(Subnets=["subnet"], SecurityGroupIds=["sg"]),
+        EnableManagedSpotTraining=False,
         Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
     )

@@ -667,6 +669,7 @@ def test_start_with_spark(
         ),
         EnableNetworkIsolation=False,
         EnableInterContainerTrafficEncryption=True,
+        EnableManagedSpotTraining=False,
         Environment={"AWS_DEFAULT_REGION": "us-west-2", "REMOTE_FUNCTION_SECRET_KEY": HMAC_KEY},
     )

0 commit comments
