Skip to content

Commit 8567f73

Browse files
authored
Merge branch 'master' into fix-pytorch-inference-test
2 parents a30e223 + 656a17d commit 8567f73

File tree

2 files changed

25 additions and 2 deletions (+25 / -2 lines changed)

2 files changed

25 additions and 2 deletions (+25 / -2 lines changed)

src/sagemaker/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -871,6 +871,12 @@ class constructor
871871
init_params["model_uri"] = channel["DataSource"]["S3DataSource"]["S3Uri"]
872872
break
873873

874+
if job_details.get("EnableManagedSpotTraining", False):
875+
init_params["use_spot_instances"] = True
876+
max_wait = job_details.get("StoppingCondition", {}).get("MaxWaitTimeInSeconds")
877+
if max_wait:
878+
init_params["max_wait"] = max_wait
879+
874880
return init_params
875881

876882
def transformer(

tests/unit/test_estimator.py

Lines changed: 19 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2125,7 +2125,6 @@ def test_generic_deploy_accelerator_type(sagemaker_session):
21252125
e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, accelerator_type=ACCELERATOR_TYPE)
21262126

21272127
args = e.sagemaker_session.endpoint_from_production_variants.call_args[1]
2128-
print(args)
21292128
assert args["name"].startswith(IMAGE_URI)
21302129
assert args["production_variants"][0]["AcceleratorType"] == ACCELERATOR_TYPE
21312130
assert args["production_variants"][0]["InitialInstanceCount"] == INSTANCE_COUNT
@@ -2182,7 +2181,6 @@ def test_local_mode(session_class, local_session_class):
21822181
session_class.return_value = session
21832182

21842183
e = Estimator(IMAGE_URI, ROLE, INSTANCE_COUNT, "local")
2185-
print(e.sagemaker_session.local_mode)
21862184
assert e.sagemaker_session.local_mode is True
21872185

21882186
e2 = Estimator(IMAGE_URI, ROLE, INSTANCE_COUNT, "local_gpu")
@@ -2248,6 +2246,25 @@ def test_prepare_init_params_from_job_description_with_algorithm_training_job():
22482246
)
22492247

22502248

2249+
def test_prepare_init_params_from_job_description_with_spot_training():
2250+
job_description = RETURNED_JOB_DESCRIPTION.copy()
2251+
job_description["EnableManagedSpotTraining"] = True
2252+
job_description["StoppingCondition"] = {
2253+
"MaxRuntimeInSeconds": 86400,
2254+
"MaxWaitTimeInSeconds": 87000,
2255+
}
2256+
2257+
init_params = EstimatorBase._prepare_init_params_from_job_description(
2258+
job_details=job_description
2259+
)
2260+
2261+
assert init_params["role"] == "arn:aws:iam::366:role/SageMakerRole"
2262+
assert init_params["instance_count"] == 1
2263+
assert init_params["use_spot_instances"]
2264+
assert init_params["max_run"] == 86400
2265+
assert init_params["max_wait"] == 87000
2266+
2267+
22512268
def test_prepare_init_params_from_job_description_with_invalid_training_job():
22522269

22532270
invalid_job_description = RETURNED_JOB_DESCRIPTION.copy()

0 commit comments

Comments (0)