Skip to content

Commit f7b8dd1

Browse files
fix: add tag fixes for pipelines and experiments
1 parent 607f85a commit f7b8dd1

File tree

14 files changed

+179
-83
lines changed

14 files changed

+179
-83
lines changed

src/sagemaker/experiments/experiment.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import time
1717

18+
from botocore.exceptions import ClientError
19+
1820
from sagemaker.apiutils import _base_types
1921
from sagemaker.experiments.trial import _Trial
2022
from sagemaker.experiments.trial_component import _TrialComponent
@@ -154,17 +156,21 @@ def _load_or_create(
154156
Returns:
155157
experiments.experiment._Experiment: A SageMaker `_Experiment` object
156158
"""
157-
sagemaker_client = sagemaker_session.sagemaker_client
158159
try:
159-
experiment = _Experiment.load(experiment_name, sagemaker_session)
160-
except sagemaker_client.exceptions.ResourceNotFound:
161160
experiment = _Experiment.create(
162161
experiment_name=experiment_name,
163162
display_name=display_name,
164163
description=description,
165164
tags=tags,
166165
sagemaker_session=sagemaker_session,
167166
)
167+
except ClientError as ce:
168+
error_code = ce.response["Error"]["Code"]
169+
error_message = ce.response["Error"]["Message"]
170+
if not (error_code == "ValidationException" and "already exists" in error_message):
171+
raise ce
172+
# already exists
173+
experiment = _Experiment.load(experiment_name, sagemaker_session)
168174
return experiment
169175

170176
def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None):

src/sagemaker/experiments/trial.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"""Contains the Trial class."""
1414
from __future__ import absolute_import
1515

16+
from botocore.exceptions import ClientError
17+
1618
from sagemaker.apiutils import _base_types
1719
from sagemaker.experiments import _api_types
1820
from sagemaker.experiments.trial_component import _TrialComponent
@@ -268,8 +270,20 @@ def _load_or_create(
268270
Returns:
269271
experiments.trial._Trial: A SageMaker `_Trial` object
270272
"""
271-
sagemaker_client = sagemaker_session.sagemaker_client
272273
try:
274+
trial = _Trial.create(
275+
experiment_name=experiment_name,
276+
trial_name=trial_name,
277+
display_name=display_name,
278+
tags=tags,
279+
sagemaker_session=sagemaker_session,
280+
)
281+
except ClientError as ce:
282+
error_code = ce.response["Error"]["Code"]
283+
error_message = ce.response["Error"]["Message"]
284+
if not (error_code == "ValidationException" and "already exists" in error_message):
285+
raise ce
286+
# already exists
273287
trial = _Trial.load(trial_name, sagemaker_session)
274288
if trial.experiment_name != experiment_name: # pylint: disable=no-member
275289
raise ValueError(
@@ -278,12 +292,4 @@ def _load_or_create(
278292
trial.experiment_name # pylint: disable=no-member
279293
)
280294
)
281-
except sagemaker_client.exceptions.ResourceNotFound:
282-
trial = _Trial.create(
283-
experiment_name=experiment_name,
284-
trial_name=trial_name,
285-
display_name=display_name,
286-
tags=tags,
287-
sagemaker_session=sagemaker_session,
288-
)
289295
return trial

src/sagemaker/experiments/trial_component.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import time
1717

18+
from botocore.exceptions import ClientError
19+
1820
from sagemaker.apiutils import _base_types
1921
from sagemaker.experiments import _api_types
2022
from sagemaker.experiments._api_types import TrialComponentSearchResult
@@ -326,16 +328,20 @@ def _load_or_create(
326328
experiments.trial_component._TrialComponent: A SageMaker `_TrialComponent` object.
327329
bool: A boolean variable indicating whether the trail component already exists
328330
"""
329-
sagemaker_client = sagemaker_session.sagemaker_client
330331
is_existed = False
331332
try:
332-
run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
333-
is_existed = True
334-
except sagemaker_client.exceptions.ResourceNotFound:
335333
run_tc = _TrialComponent.create(
336334
trial_component_name=trial_component_name,
337335
display_name=display_name,
338336
tags=tags,
339337
sagemaker_session=sagemaker_session,
340338
)
339+
except ClientError as ce:
340+
error_code = ce.response["Error"]["Code"]
341+
error_message = ce.response["Error"]["Message"]
342+
if not (error_code == "ValidationException" and "already exists" in error_message):
343+
raise ce
344+
# already exists
345+
run_tc = _TrialComponent.load(trial_component_name, sagemaker_session)
346+
is_existed = True
341347
return run_tc, is_existed

src/sagemaker/session.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3224,7 +3224,7 @@ def create_model_package_from_containers(
32243224
def submit(request):
32253225
if model_package_group_name is not None:
32263226
_create_resource(
3227-
self.sagemaker_client.create_model_package_group(
3227+
lambda: self.sagemaker_client.create_model_package_group(
32283228
ModelPackageGroupName=request["ModelPackageGroupName"]
32293229
)
32303230
)
@@ -5448,8 +5448,9 @@ def _deployment_entity_exists(describe_fn):
54485448

54495449

54505450
def _create_resource(create_fn):
5451-
"""Call create function and while doing so accepts/passes the resource already exists exception.
5452-
Throws an exception if any exception other than resource already exists.
5451+
"""Call create function and accepts/pass when resource already exists.
5452+
5453+
This is a helper function to use an existing resource if found when creating.
54535454
54545455
Args:
54555456
create_fn: Create resource function.
@@ -5823,9 +5824,9 @@ def _wait_until_training_done(callable_fn, desc, poll=5):
58235824
job_desc, finished = callable_fn(job_desc)
58245825
except botocore.exceptions.ClientError as err:
58255826
# For initial 5 mins we accept/pass AccessDeniedException.
5826-
# The reason is to await tag propagation to avoid false AccessDenied claims for an access
5827-
# policy based on resource tags, The caveat here is for true AccessDenied cases the routine
5828-
# will fail after 5 mins
5827+
# The reason is to await tag propagation to avoid false AccessDenied claims for an
5828+
# access policy based on resource tags, The caveat here is for true AccessDenied
5829+
# cases the routine will fail after 5 mins
58295830
if err.response["Error"]["Code"] == "AccessDeniedException" and elapsed_time <= 300:
58305831
LOGGER.warning(
58315832
"Received AccessDeniedException. This could mean the IAM role does not "
@@ -5834,8 +5835,7 @@ def _wait_until_training_done(callable_fn, desc, poll=5):
58345835
"continuing to wait for tag propagation.."
58355836
)
58365837
continue
5837-
else:
5838-
raise err
5838+
raise err
58395839
return job_desc
58405840

58415841

@@ -5861,8 +5861,7 @@ def _wait_until(callable_fn, poll=5):
58615861
"continuing to wait for tag propagation.."
58625862
)
58635863
continue
5864-
else:
5865-
raise err
5864+
raise err
58665865
return result
58675866

58685867

src/sagemaker/utils.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -604,12 +604,17 @@ def retries(
604604
)
605605

606606

607-
def retry_with_backoff(callable_func, num_attempts=8):
607+
def retry_with_backoff(callable_func, num_attempts=8, botocore_client_error_code=None):
608608
"""Retry with backoff until maximum attempts are reached
609609
610610
Args:
611611
callable_func (callable): The callable function to retry.
612-
num_attempts (int): The maximum number of attempts to retry.
612+
num_attempts (int): The maximum number of attempts to retry.(Default: 8)
613+
botocore_client_error_code (str): The specific Botocore ClientError exception error code
614+
on which to retry on.
615+
If provided other exceptions will be raised directly w/o retry.
616+
If not provided, retry on any exception.
617+
(Default: None)
613618
"""
614619
if num_attempts < 1:
615620
raise ValueError(
@@ -619,7 +624,15 @@ def retry_with_backoff(callable_func, num_attempts=8):
619624
try:
620625
return callable_func()
621626
except Exception as ex: # pylint: disable=broad-except
622-
if i == num_attempts - 1:
627+
if not botocore_client_error_code or (
628+
botocore_client_error_code
629+
and isinstance(ex, botocore.exceptions.ClientError)
630+
and ex.response["Error"]["Code"] # pylint: disable=no-member
631+
== botocore_client_error_code
632+
):
633+
if i == num_attempts - 1:
634+
raise ex
635+
else:
623636
raise ex
624637
logger.error("Retrying in attempt %s, due to %s", (i + 1), str(ex))
625638
time.sleep(2**i)

src/sagemaker/workflow/pipeline.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from sagemaker import s3
2727
from sagemaker._studio import _append_project_tags
2828
from sagemaker.session import Session
29+
from sagemaker.utils import retry_with_backoff
2930
from sagemaker.workflow.callback_step import CallbackOutput, CallbackStep
3031
from sagemaker.workflow.lambda_step import LambdaOutput, LambdaStep
3132
from sagemaker.workflow.entities import (
@@ -306,7 +307,12 @@ def start(
306307
update_args(kwargs, PipelineParameters=parameters)
307308
return self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs)
308309
update_args(kwargs, PipelineParameters=format_start_parameters(parameters))
309-
response = self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs)
310+
311+
# retry on AccessDeniedException to cover case of tag propagation delay
312+
response = retry_with_backoff(
313+
lambda: self.sagemaker_session.sagemaker_client.start_pipeline_execution(**kwargs),
314+
botocore_client_error_code="AccessDeniedException",
315+
)
310316
return _PipelineExecution(
311317
arn=response["PipelineExecutionArn"],
312318
sagemaker_session=self.sagemaker_session,

tests/integ/test_inference_pipeline.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from sagemaker.predictor import Predictor
2929
from sagemaker.serializers import JSONSerializer
3030
from sagemaker.sparkml.model import SparkMLModel
31-
from sagemaker.utils import sagemaker_timestamp
31+
from sagemaker.utils import unique_name_from_base
3232

3333
SPARKML_DATA_PATH = os.path.join(DATA_DIR, "sparkml_model")
3434
XGBOOST_DATA_PATH = os.path.join(DATA_DIR, "xgboost_model")
@@ -60,7 +60,7 @@ def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type
6060
path=os.path.join(XGBOOST_DATA_PATH, "xgb_model.tar.gz"),
6161
key_prefix="integ-test-data/xgboost/model",
6262
)
63-
batch_job_name = "test-inference-pipeline-batch-{}".format(sagemaker_timestamp())
63+
batch_job_name = unique_name_from_base("test-inference-pipeline-batch")
6464
sparkml_model = SparkMLModel(
6565
model_data=sparkml_model_data,
6666
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
@@ -99,7 +99,7 @@ def test_inference_pipeline_batch_transform(sagemaker_session, cpu_instance_type
9999
def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
100100
sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
101101
xgboost_data_path = os.path.join(DATA_DIR, "xgboost_model")
102-
endpoint_name = "test-inference-pipeline-deploy-{}".format(sagemaker_timestamp())
102+
endpoint_name = unique_name_from_base("test-inference-pipeline-deploy")
103103
sparkml_model_data = sagemaker_session.upload_data(
104104
path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"),
105105
key_prefix="integ-test-data/sparkml/model",
@@ -156,7 +156,7 @@ def test_inference_pipeline_model_deploy_and_update_endpoint(
156156
):
157157
sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
158158
xgboost_data_path = os.path.join(DATA_DIR, "xgboost_model")
159-
endpoint_name = "test-inference-pipeline-deploy-{}".format(sagemaker_timestamp())
159+
endpoint_name = unique_name_from_base("test-inference-pipeline-deploy")
160160
sparkml_model_data = sagemaker_session.upload_data(
161161
path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"),
162162
key_prefix="integ-test-data/sparkml/model",

tests/integ/test_mxnet.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sagemaker.mxnet.model import MXNetModel
2525
from sagemaker.mxnet.processing import MXNetProcessor
2626
from sagemaker.serverless import ServerlessInferenceConfig
27-
from sagemaker.utils import sagemaker_timestamp
27+
from sagemaker.utils import unique_name_from_base
2828
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
2929
from tests.integ.kms_utils import get_or_create_kms_key
3030
from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name
@@ -98,7 +98,7 @@ def test_framework_processing_job_with_deps(
9898

9999
@pytest.mark.release
100100
def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type):
101-
endpoint_name = "test-mxnet-attach-deploy-{}".format(sagemaker_timestamp())
101+
endpoint_name = unique_name_from_base("test-mxnet-attach-deploy")
102102

103103
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
104104
estimator = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session)
@@ -165,7 +165,7 @@ def test_deploy_model(
165165
mxnet_inference_latest_py_version,
166166
cpu_instance_type,
167167
):
168-
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
168+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model")
169169

170170
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
171171
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -200,7 +200,7 @@ def test_register_model_package(
200200
mxnet_inference_latest_py_version,
201201
cpu_instance_type,
202202
):
203-
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
203+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model")
204204

205205
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
206206
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -216,7 +216,7 @@ def test_register_model_package(
216216
sagemaker_session=sagemaker_session,
217217
framework_version=mxnet_inference_latest_version,
218218
)
219-
model_package_name = "register-model-package-{}".format(sagemaker_timestamp())
219+
model_package_name = unique_name_from_base("register-model-package")
220220
model_pkg = model.register(
221221
content_types=["application/json"],
222222
response_types=["application/json"],
@@ -239,13 +239,13 @@ def test_register_model_package_versioned(
239239
mxnet_inference_latest_py_version,
240240
cpu_instance_type,
241241
):
242-
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
242+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model")
243243

244244
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
245245
desc = sagemaker_session.sagemaker_client.describe_training_job(
246246
TrainingJobName=mxnet_training_job
247247
)
248-
model_package_group_name = "register-model-package-{}".format(sagemaker_timestamp())
248+
model_package_group_name = unique_name_from_base("register-model-package")
249249
sagemaker_session.sagemaker_client.create_model_package_group(
250250
ModelPackageGroupName=model_package_group_name
251251
)
@@ -287,7 +287,7 @@ def test_deploy_model_with_tags_and_kms(
287287
mxnet_inference_latest_py_version,
288288
cpu_instance_type,
289289
):
290-
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
290+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model")
291291

292292
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
293293
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -347,7 +347,7 @@ def test_deploy_model_and_update_endpoint(
347347
cpu_instance_type,
348348
alternative_cpu_instance_type,
349349
):
350-
endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())
350+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model")
351351

352352
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
353353
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -395,7 +395,7 @@ def test_deploy_model_with_accelerator(
395395
mxnet_eia_latest_py_version,
396396
cpu_instance_type,
397397
):
398-
endpoint_name = "test-mxnet-deploy-model-ei-{}".format(sagemaker_timestamp())
398+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model-ei")
399399

400400
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
401401
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -426,7 +426,7 @@ def test_deploy_model_with_serverless_inference_config(
426426
mxnet_inference_latest_version,
427427
mxnet_inference_latest_py_version,
428428
):
429-
endpoint_name = "test-mxnet-deploy-model-serverless-{}".format(sagemaker_timestamp())
429+
endpoint_name = unique_name_from_base("test-mxnet-deploy-model-serverless")
430430

431431
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
432432
desc = sagemaker_session.sagemaker_client.describe_training_job(
@@ -465,7 +465,7 @@ def test_async_fit(
465465
mxnet_inference_latest_py_version,
466466
cpu_instance_type,
467467
):
468-
endpoint_name = "test-mxnet-attach-deploy-{}".format(sagemaker_timestamp())
468+
endpoint_name = unique_name_from_base("test-mxnet-attach-deploy")
469469

470470
with timeout(minutes=5):
471471
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")

0 commit comments

Comments
 (0)