Skip to content

Commit 9e0d5b5

Browse files
fix: add tag fixes for pipelines and experiments
fix: units and integs fix: units and integs
1 parent 607f85a commit 9e0d5b5

File tree

13 files changed

+173
-79
lines changed

13 files changed

+173
-79
lines changed

src/sagemaker/experiments/experiment.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import time
1717

18+
from botocore.exceptions import ClientError
19+
1820
from sagemaker.apiutils import _base_types
1921
from sagemaker.experiments.trial import _Trial
2022
from sagemaker.experiments.trial_component import _TrialComponent
@@ -154,17 +156,21 @@ def _load_or_create(
154156
Returns:
155157
experiments.experiment._Experiment: A SageMaker `_Experiment` object
156158
"""
157-
sagemaker_client = sagemaker_session.sagemaker_client
158159
try:
159-
experiment = _Experiment.load(experiment_name, sagemaker_session)
160-
except sagemaker_client.exceptions.ResourceNotFound:
161160
experiment = _Experiment.create(
162161
experiment_name=experiment_name,
163162
display_name=display_name,
164163
description=description,
165164
tags=tags,
166165
sagemaker_session=sagemaker_session,
167166
)
167+
except ClientError as ce:
168+
error_code = ce.response["Error"]["Code"]
169+
error_message = ce.response["Error"]["Message"]
170+
if not (error_code == "ValidationException" and "already exists" in error_message):
171+
raise ce
172+
# already exists
173+
experiment = _Experiment.load(experiment_name, sagemaker_session)
168174
return experiment
169175

170176
def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None):

src/sagemaker/experiments/trial.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Contains the Trial class."""
1414
from __future__ import absolute_import
1515

16+
from botocore.exceptions import ClientError
17+
1618
from sagemaker.apiutils import _base_types
1719
from sagemaker.experiments import _api_types
1820
from sagemaker.experiments.trial_component import _TrialComponent
@@ -268,8 +270,20 @@ def _load_or_create(
268270
Returns:
269271
experiments.trial._Trial: A SageMaker `_Trial` object
270272
"""
271-
sagemaker_client = sagemaker_session.sagemaker_client
272273
try:
274+
trial = _Trial.create(
275+
experiment_name=experiment_name,
276+
trial_name=trial_name,
277+
display_name=display_name,
278+
tags=tags,
279+
sagemaker_session=sagemaker_session,
280+
)
281+
except ClientError as ce:
282+
error_code = ce.response["Error"]["Code"]
283+
error_message = ce.response["Error"]["Message"]
284+
if not (error_code == "ValidationException" and "already exists" in error_message):
285+
raise ce
286+
# already exists
273287
trial = _Trial.load(trial_name, sagemaker_session)
274288
if trial.experiment_name != experiment_name: # pylint: disable=no-member
275289
raise ValueError(
@@ -278,12 +292,4 @@ def _load_or_create(
278292
trial.experiment_name # pylint: disable=no-member
279293
)
280294
)
281-
except sagemaker_client.exceptions.ResourceNotFound:
282-
trial = _Trial.create(
283-
experiment_name=experiment_name,
284-
trial_name=trial_name,
285-
display_name=display_name,
286-
tags=tags,
287-
sagemaker_session=sagemaker_session,
288-
)
289295
return trial

src/sagemaker/experiments/trial_component.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import time
1717

18+
from botocore.exceptions import ClientError
19+
1820
from sagemaker.apiutils import _base_types
1921
from sagemaker.experiments import _api_types
2022
from sagemaker.experiments._api_types import TrialComponentSearchResult
@@ -326,16 +328,20 @@ def _load_or_create(
326328
experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
327329
bool: A boolean variable indicating whether the trail component already exists
328330
"""
329-
sagemaker_client = sagemaker_session.sagemaker_client
330331
is_existed = False
331332
try:
332-
run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
333-
is_existed = True
334-
except sagemaker_client.exceptions.ResourceNotFound:
335333
run_tc = _TrialComponent.create(
336334
trial_component_name=trial_component_name,
337335
display_name=display_name,
338336
tags=tags,
339337
sagemaker_session=sagemaker_session,
340338
)
339+
except ClientError as ce:
340+
error_code = ce.response["Error"]["Code"]
341+
error_message = ce.response["Error"]["Message"]
342+
if not (error_code == "ValidationException" and "already exists" in error_message):
343+
raise ce
344+
# already exists
345+
run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
346+
is_existed = True
341347
return run_tc, is_existed

src/sagemaker/session.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3224,7 +3224,7 @@ def create_model_package_from_containers(
32243224
def submit(request):
32253225
if model_package_group_name is not None:
32263226
_create_resource(
3227-
self.sagemaker_client.create_model_package_group(
3227+
lambda: self.sagemaker_client.create_model_package_group(
32283228
ModelPackageGroupName=request["ModelPackageGroupName"]
32293229
)
32303230
)
@@ -5448,8 +5448,9 @@ def _deployment_entity_exists(describe_fn):
54485448

54495449

54505450
def _create_resource(create_fn):
5451-
"""Call create function and while doing so accepts/passes the resource already exists exception.
5452-
Throws an exception if any exception other than resource already exists.
5451+
"""Call create function and accepts/pass when resource already exists.
5452+
5453+
This is a helper function to use an existing resource if found when creating.
54535454
54545455
Args:
54555456
create_fn: Create resource function.
@@ -5823,9 +5824,9 @@ def _wait_until_training_done(callable_fn, desc, poll=5):
58235824
job_desc, finished = callable_fn(job_desc)
58245825
except botocore.exceptions.ClientError as err:
58255826
# For initial 5 mins we accept/pass AccessDeniedException.
5826-
# The reason is to await tag propagation to avoid false AccessDenied claims for an access
5827-
# policy based on resource tags, The caveat here is for true AccessDenied cases the routine
5828-
# will fail after 5 mins
5827+
# The reason is to await tag propagation to avoid false AccessDenied claims for an
5828+
# access policy based on resource tags, The caveat here is for true AccessDenied
5829+
# cases the routine will fail after 5 mins
58295830
if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
58305831
LOGGER.warning(
58315832
"Received AccessDeniedException. This could mean the IAM role does not "
@@ -5834,8 +5835,6 @@ def _wait_until_training_done(callable_fn, desc, poll=5):
58345835
"continuing to wait for tag propagation.."
58355836
)
58365837
continue
5837-
else:
5838-
raise err
58395838
return job_desc
58405839

58415840

@@ -5861,8 +5860,6 @@ def _wait_until(callable_fn, poll=5):
58615860
"continuing to wait for tag propagation.."
58625861
)
58635862
continue
5864-
else:
5865-
raise err
58665863
return result
58675864

58685865

src/sagemaker/utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,12 +604,17 @@ def retries(
604604
)
605605

606606

607-
def retry_with_backoff(callable_func, num_attempts=8):
607+
def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code=None):
608608
"""Retry with backoff until maximum attempts are reached
609609
610610
Args:
611611
callable_func (callable): The callable function to retry.
612-
num_attempts (int): The maximum number of attempts to retry.
612+
num_attempts (int): The maximum number of attempts to retry.(Default: 8)
613+
botocore_client_error_code (str): The specific Botocore ClientError exception error code
614+
on which to retry on.
615+
If provided other exceptions will be raised directly w/o retry.
616+
If not provided, retry on any exception.
617+
(Default: None)
613618
"""
614619
if num_attempts < 1:
615620
raise ValueError(
@@ -619,7 +624,15 @@ def retry_with_backoff(callable_func, num_attempts=8):
619624
try:
620625
return callable_func()
621626
except Exception as ex: # pylint: disable=broad-except
622-
if i == num_attempts - 1:
627+
if not botocore_client_error_code or (
628+
botocore_client_error_code
629+
and isinstance(ex) == botocore.exceptions.ClientError
630+
and ex.response["Error"]["Code"] # pylint: disable=no-member
631+
== botocore_client_error_code
632+
):
633+
if i == num_attempts - 1:
634+
raise ex
635+
else:
623636
raise ex
624637
logger.error("Retrying in attempt %s, due to %s", (i + 1), str(ex))
625638
time.sleep(2**i)

src/sagemaker/workflow/pipeline.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sagemaker import s3
2727
from sagemaker._studio import _append_project_tags
2828
from sagemaker.session import Session
29+
from sagemaker.utils import retry_with_backoff
2930
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
3031
from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep
3132
from sagemaker.workflow.entities import (
@@ -306,7 +307,12 @@ def start(
306307
update_args(kwargs, PipelineParameters=parameters)
307308
return self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs)
308309
update_args(kwargs, PipelineParameters=format_start_parameters(parameters))
309-
response = self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs)
310+
311+
# retry on AccessDeniedException to cover case of tag propagation delay
312+
response = retry_with_backoff(
313+
lambda: self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs),
314+
botocore_client_error_code="AccessDeniedException",
315+
)
310316
return _PipelineExecution(
311317
arn=response["PipelineExecutionArn"],
312318
sagemaker_session=self.sagemaker_session,

tests/integ/test_mxnet.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sagemaker.mxnet.model import MXNetModel
2525
from sagemaker.mxnet.processing import MXNetProcessor
2626
from sagemaker.serverless import ServerlessInferenceConfig
27-
from sagemaker.utils import sagemaker_timestamp
27+
from sagemaker.utils import unique_name_from_base
2828
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2929
from tests.integ.kms_utils import get_or_create_kms_key
3030
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
@@ -98,7 +98,7 @@ def test_framework_processing_job_with_deps(
9898

9999
@pytest.mark.release
100100
def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type):
101-
endpoint_name = "test-mxnet-attach-deploy-{}".format(sagemaker_timestamp())
101+
endpoint_name = unique_name_from_base("test-mxnet-attach-deploy")
102102

103103
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
104104
estimator = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session)
@@ -165,7 +165,7 @@ def test_deploy_model(
165165
mxnet_inference_latest_py_version,
166166
cpu_instance_type,
167167
):
168-
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
168+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model")
169169

170170
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
171171
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -200,7 +200,7 @@ def test_register_model_package(
200200
mxnet_inference_latest_py_version,
201201
cpu_instance_type,
202202
):
203-
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
203+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model")
204204

205205
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
206206
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -216,7 +216,7 @@ def test_register_model_package(
216216
sagemaker_session=sagemaker_session,
217217
framework_version=mxnet_inference_latest_version,
218218
)
219-
model_package_name = "register-model-package-{}".format(sagemaker_timestamp())
219+
model_package_name = unique_name_from_base("register-model-package")
220220
model_pkg = model.register(
221221
content_types=["application/json"],
222222
response_types=["application/json"],
@@ -239,13 +239,13 @@ def test_register_model_package_versioned(
239239
mxnet_inference_latest_py_version,
240240
cpu_instance_type,
241241
):
242-
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
242+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model")
243243

244244
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
245245
desc = sagemaker_session.sagemaker_client.describe_training_job(
246246
TrainingJobName=mxnet_training_job
247247
)
248-
model_package_group_name = "register-model-package-{}".format(sagemaker_timestamp())
248+
model_package_group_name = unique_name_from_base("register-model-package")
249249
sagemaker_session.sagemaker_client.create_model_package_group(
250250
ModelPackageGroupName=model_package_group_name
251251
)
@@ -287,7 +287,7 @@ def test_deploy_model_with_tags_and_kms(
287287
mxnet_inference_latest_py_version,
288288
cpu_instance_type,
289289
):
290-
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
290+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model")
291291

292292
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
293293
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -347,7 +347,7 @@ def test_deploy_model_and_update_endpoint(
347347
cpu_instance_type,
348348
alternative_cpu_instance_type,
349349
):
350-
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
350+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model")
351351

352352
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
353353
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -395,7 +395,7 @@ def test_deploy_model_with_accelerator(
395395
mxnet_eia_latest_py_version,
396396
cpu_instance_type,
397397
):
398-
endpoint_name = "test-mxnet-deploy-model-ei-{}".format(sagemaker_timestamp())
398+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model-ei")
399399

400400
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
401401
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -426,7 +426,7 @@ def test_deploy_model_with_serverless_inference_config(
426426
mxnet_inference_latest_version,
427427
mxnet_inference_latest_py_version,
428428
):
429-
endpoint_name = "test-mxnet-deploy-model-serverless-{}".format(sagemaker_timestamp())
429+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model-serverless")
430430

431431
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
432432
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -465,7 +465,7 @@ def test_async_fit(
465465
mxnet_inference_latest_py_version,
466466
cpu_instance_type,
467467
):
468-
endpoint_name = "test-mxnet-attach-deploy-{}".format(sagemaker_timestamp())
468+
endpoint_name = unique_name_from_base("test-mxnet-attach-deploy")
469469

470470
with timeout(minutes=5):
471471
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")

tests/integ/test_pytorch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from sagemaker.pytorch.model import PyTorchModel
2121
from sagemaker.pytorch.processing import PyTorchProcessor
2222
from sagemaker.serverless import ServerlessInferenceConfig
23-
from sagemaker.utils import sagemaker_timestamp
23+
from sagemaker.utils import unique_name_from_base
2424
from tests.integ import (
2525
test_region,
2626
DATA_DIR,
@@ -130,7 +130,7 @@ def test_framework_processing_job_with_deps(
130130
def test_fit_deploy(
131131
pytorch_training_job_with_latest_infernce_version, sagemaker_session, cpu_instance_type
132132
):
133-
endpoint_name = "test-pytorch-sync-fit-attach-deploy{}".format(sagemaker_timestamp())
133+
endpoint_name = unique_name_from_base("test-pytorch-sync-fit-attach-deploy")
134134
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
135135
estimator = PyTorch.attach(
136136
pytorch_training_job_with_latest_infernce_version, sagemaker_session=sagemaker_session
@@ -180,7 +180,7 @@ def test_deploy_model(
180180
pytorch_inference_latest_version,
181181
pytorch_inference_latest_py_version,
182182
):
183-
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
183+
endpoint_name = unique_name_from_base("test-pytorch-deploy-model")
184184

185185
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
186186
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -210,7 +210,7 @@ def test_deploy_packed_model_with_entry_point_name(
210210
pytorch_inference_latest_version,
211211
pytorch_inference_latest_py_version,
212212
):
213-
endpoint_name = "test-pytorch-deploy-model-{}".format(sagemaker_timestamp())
213+
endpoint_name = unique_name_from_base("test-pytorch-deploy-model")
214214

215215
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
216216
model_data = sagemaker_session.upload_data(path=PACKED_MODEL)
@@ -240,7 +240,7 @@ def test_deploy_model_with_accelerator(
240240
pytorch_eia_latest_version,
241241
pytorch_eia_latest_py_version,
242242
):
243-
endpoint_name = "test-pytorch-deploy-eia-{}".format(sagemaker_timestamp())
243+
endpoint_name = unique_name_from_base("test-pytorch-deploy-eia")
244244
model_data = sagemaker_session.upload_data(path=EIA_MODEL)
245245
pytorch = PyTorchModel(
246246
model_data,
@@ -272,7 +272,7 @@ def test_deploy_model_with_serverless_inference_config(
272272
pytorch_inference_latest_version,
273273
pytorch_inference_latest_py_version,
274274
):
275-
endpoint_name = "test-pytorch-deploy-model-serverless-{}".format(sagemaker_timestamp())
275+
endpoint_name = unique_name_from_base("test-pytorch-deploy-model-serverless")
276276

277277
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
278278
desc = sagemaker_session.sagemaker_client.describe_training_job(

0 commit comments

Comments
 (0)