Skip to content

Commit 7963385

Browse files
fix: tag permission issue - remove describe before create
1 parent 4046ae4 commit 7963385

File tree

4 files changed

+250
-90
lines changed

4 files changed

+250
-90
lines changed

src/sagemaker/session.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3203,14 +3203,11 @@ def create_model_package_from_containers(
32033203

32043204
def submit(request):
32053205
if model_package_group_name is not None:
3206-
try:
3207-
self.sagemaker_client.describe_model_package_group(
3208-
ModelPackageGroupName=request["ModelPackageGroupName"]
3209-
)
3210-
except ClientError:
3206+
_create_resource(
32113207
self.sagemaker_client.create_model_package_group(
32123208
ModelPackageGroupName=request["ModelPackageGroupName"]
32133209
)
3210+
)
32143211
return self.sagemaker_client.create_model_package(**request)
32153212

32163213
return self._intercept_create_request(
@@ -3898,42 +3895,40 @@ def endpoint_from_model_data(
38983895
name = name or name_from_image(image_uri)
38993896
model_vpc_config = vpc_utils.sanitize(model_vpc_config)
39003897

3901-
if _deployment_entity_exists(
3902-
lambda: self.sagemaker_client.describe_endpoint(EndpointName=name)
3903-
):
3904-
raise ValueError(
3905-
'Endpoint with name "{}" already exists; please pick a different name.'.format(name)
3906-
)
3898+
primary_container = container_def(
3899+
image_uri=image_uri,
3900+
model_data_url=model_s3_location,
3901+
env=model_environment_vars,
3902+
)
39073903

3908-
if not _deployment_entity_exists(
3909-
lambda: self.sagemaker_client.describe_model(ModelName=name)
3910-
):
3911-
primary_container = container_def(
3912-
image_uri=image_uri,
3913-
model_data_url=model_s3_location,
3914-
env=model_environment_vars,
3915-
)
3916-
self.create_model(
3917-
name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config
3918-
)
3904+
self.create_model(
3905+
name=name, role=role, container_defs=primary_container, vpc_config=model_vpc_config
3906+
)
39193907

39203908
data_capture_config_dict = None
39213909
if data_capture_config is not None:
39223910
data_capture_config_dict = data_capture_config._to_request_dict()
39233911

3924-
if not _deployment_entity_exists(
3925-
lambda: self.sagemaker_client.describe_endpoint_config(EndpointConfigName=name)
3926-
):
3927-
self.create_endpoint_config(
3912+
_create_resource(
3913+
lambda: self.create_endpoint_config(
39283914
name=name,
39293915
model_name=name,
39303916
initial_instance_count=initial_instance_count,
39313917
instance_type=instance_type,
39323918
accelerator_type=accelerator_type,
39333919
data_capture_config_dict=data_capture_config_dict,
39343920
)
3921+
)
3922+
3923+
# to make change backwards compatible
3924+
response = _create_resource(
3925+
lambda: self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
3926+
)
3927+
if not response:
3928+
raise ValueError(
3929+
'Endpoint with name "{}" already exists; please pick a different name.'.format(name)
3930+
)
39353931

3936-
self.create_endpoint(endpoint_name=name, config_name=name, wait=wait)
39373932
return name
39383933

39393934
def endpoint_from_production_variants(
@@ -5432,6 +5427,34 @@ def _deployment_entity_exists(describe_fn):
54325427
return False
54335428

54345429

5430+
def _create_resource(create_fn):
5431+
"""Call create function and while doing so accepts/passes the resource already exists exception.
5432+
Throws an exception if any exception other than resource already exists.
5433+
5434+
Args:
5435+
create_fn: Create resource function.
5436+
5437+
Returns:
5438+
(bool): True if new resource was created, False if resource already exists.
5439+
"""
5440+
try:
5441+
create_fn()
5442+
# create function succeeded, resource does not exist already
5443+
return True
5444+
except ClientError as ce:
5445+
error_code = ce.response["Error"]["Code"]
5446+
error_message = ce.response["Error"]["Message"]
5447+
already_exists_exceptions = ["ValidationException", "ResourceInUse"]
5448+
already_exists_msg_patterns = ["Cannot create already existing", "already exists"]
5449+
if not (
5450+
error_code in already_exists_exceptions
5451+
and any(p in error_message for p in already_exists_msg_patterns)
5452+
):
5453+
raise ce
5454+
# no new resource created as resource already exists
5455+
return False
5456+
5457+
54355458
def _train_done(sagemaker_client, job_name, last_desc):
54365459
"""Placeholder docstring"""
54375460
in_progress_statuses = ["InProgress", "Created"]

src/sagemaker/workflow/pipeline.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -241,20 +241,19 @@ def upsert(
241241
Returns:
242242
response dict from service
243243
"""
244-
exists = True
245244
try:
246-
self.describe()
247-
except ClientError as e:
248-
err = e.response.get("Error", {})
249-
if err.get("Code", None) == "ResourceNotFound":
250-
exists = False
251-
else:
252-
raise e
253-
254-
if not exists:
255245
response = self.create(role_arn, description, tags, parallelism_config)
256-
else:
246+
except ClientError as ce:
247+
error_code = ce.response["Error"]["Code"]
248+
error_message = ce.response["Error"]["Message"]
249+
if not (
250+
error_code == "ValidationException"
251+
and "already exists" in error_message
252+
):
253+
raise ce
254+
# already exists
257255
response = self.update(role_arn, description)
256+
# add new tags to existing resource
258257
if tags is not None:
259258
old_tags = self.sagemaker_session.sagemaker_client.list_tags(
260259
ResourceArn=response["PipelineArn"]

tests/unit/sagemaker/workflow/test_pipeline.py

Lines changed: 101 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@
3333
from sagemaker.workflow.step_collections import StepCollection
3434
from tests.unit.sagemaker.workflow.helpers import ordered, CustomStep
3535
from sagemaker.local.local_session import LocalSession
36+
from botocore.exceptions import ClientError
37+
3638

3739

3840
@pytest.fixture
@@ -173,10 +175,17 @@ def test_large_pipeline_update(sagemaker_session_mock, role_arn):
173175
)
174176

175177

176-
def test_pipeline_upsert(sagemaker_session_mock, role_arn):
177-
sagemaker_session_mock.sagemaker_client.describe_pipeline.return_value = {
178-
"PipelineArn": "pipeline-arn"
179-
}
178+
def test_pipeline_upsert_resource_already_exists(sagemaker_session_mock, role_arn):
179+
180+
# case 1: resource already exists
181+
def _raise_does_already_exists_client_error(**kwargs):
182+
response = {"Error": {"Code": "ValidationException", "Message": "Resource already exists."}}
183+
raise ClientError(error_response=response, operation_name="create_pipeline")
184+
185+
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(
186+
name="create_pipeline", side_effect=_raise_does_already_exists_client_error
187+
)
188+
180189
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
181190
"PipelineArn": "pipeline-arn"
182191
}
@@ -197,9 +206,12 @@ def test_pipeline_upsert(sagemaker_session_mock, role_arn):
197206
]
198207
pipeline.upsert(role_arn=role_arn, tags=tags)
199208

200-
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_not_called()
209+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with(
210+
PipelineName="MyPipeline", RoleArn=role_arn, PipelineDefinition=pipeline.definition(),
211+
Tags=tags
212+
)
201213

202-
assert sagemaker_session_mock.sagemaker_client.update_pipeline.called_with(
214+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_called_once_with(
203215
PipelineName="MyPipeline", PipelineDefinition=pipeline.definition(), RoleArn=role_arn
204216
)
205217
assert sagemaker_session_mock.sagemaker_client.list_tags.called_with(
@@ -211,6 +223,89 @@ def test_pipeline_upsert(sagemaker_session_mock, role_arn):
211223
ResourceArn="mock_pipeline_arn", Tags=tags
212224
)
213225

226+
def test_pipeline_upsert_create_unexpected_failure(sagemaker_session_mock, role_arn):
227+
228+
# case 2: unexpected failure on create
229+
def _raise_unexpected_client_error(**kwargs):
230+
response = {
231+
"Error": {"Code": "ValidationException", "Message": "Name does not satisfy expression."}
232+
}
233+
raise ClientError(error_response=response, operation_name="foo")
234+
235+
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(
236+
name="create_pipeline", side_effect=_raise_unexpected_client_error
237+
)
238+
239+
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
240+
"PipelineArn": "pipeline-arn"
241+
}
242+
sagemaker_session_mock.sagemaker_client.list_tags.return_value = {
243+
"Tags": [{"Key": "dummy", "Value": "dummy_tag"}]
244+
}
245+
246+
tags = [
247+
{"Key": "foo", "Value": "abc"},
248+
{"Key": "bar", "Value": "xyz"},
249+
]
250+
251+
pipeline = Pipeline(
252+
name="MyPipeline",
253+
parameters=[],
254+
steps=[],
255+
sagemaker_session=sagemaker_session_mock,
256+
)
257+
258+
with pytest.raises(ClientError):
259+
pipeline.upsert(role_arn=role_arn, tags=tags)
260+
261+
262+
263+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with(
264+
PipelineName="MyPipeline", RoleArn=role_arn, PipelineDefinition=pipeline.definition(),
265+
Tags=tags
266+
)
267+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_not_called()
268+
sagemaker_session_mock.sagemaker_client.list_tags.assert_not_called()
269+
sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called()
270+
271+
def test_pipeline_upsert_resourse_doesnt_exist(sagemaker_session_mock, role_arn):
272+
273+
# case 3: resource does not exist
274+
sagemaker_session_mock.sagemaker_client.create_pipeline = Mock(name="create_pipeline")
275+
276+
sagemaker_session_mock.sagemaker_client.update_pipeline.return_value = {
277+
"PipelineArn": "pipeline-arn"
278+
}
279+
sagemaker_session_mock.sagemaker_client.list_tags.return_value = {
280+
"Tags": [{"Key": "dummy", "Value": "dummy_tag"}]
281+
}
282+
283+
tags = [
284+
{"Key": "foo", "Value": "abc"},
285+
{"Key": "bar", "Value": "xyz"},
286+
]
287+
288+
pipeline = Pipeline(
289+
name="MyPipeline",
290+
parameters=[],
291+
steps=[],
292+
sagemaker_session=sagemaker_session_mock,
293+
)
294+
295+
try:
296+
pipeline.upsert(role_arn=role_arn, tags=tags)
297+
except ClientError:
298+
assert False, f"Unexpected ClientError raised"
299+
300+
sagemaker_session_mock.sagemaker_client.create_pipeline.assert_called_once_with(
301+
PipelineName="MyPipeline", RoleArn=role_arn, PipelineDefinition=pipeline.definition(),
302+
Tags=tags
303+
)
304+
305+
sagemaker_session_mock.sagemaker_client.update_pipeline.assert_not_called()
306+
sagemaker_session_mock.sagemaker_client.list_tags.assert_not_called()
307+
sagemaker_session_mock.sagemaker_client.add_tags.assert_not_called()
308+
214309

215310
def test_pipeline_delete(sagemaker_session_mock):
216311
pipeline = Pipeline(

0 commit comments

Comments
 (0)