Skip to content

Commit 191b9c5

Browse files
author
Keshav Chandak
committed
bugfix: Added check for the presence of model package group before creating one
1 parent bbbb76b commit 191b9c5

File tree

2 files changed

+130
-4
lines changed

2 files changed

+130
-4
lines changed

src/sagemaker/session.py

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4347,11 +4347,57 @@ def submit(request):
43474347
if model_package_group_name is not None and not model_package_group_name.startswith(
43484348
"arn:"
43494349
):
4350-
_create_resource(
4351-
lambda: self.sagemaker_client.create_model_package_group(
4352-
ModelPackageGroupName=request["ModelPackageGroupName"]
4350+
is_model_package_group_present = False
4351+
try:
4352+
model_package_groups_response = self.search(
4353+
resource="ModelPackageGroup",
4354+
search_expression={
4355+
"Filters": [
4356+
{
4357+
"Name": "ModelPackageGroupName",
4358+
"Value": request["ModelPackageGroupName"],
4359+
"Operator": "Equals",
4360+
}
4361+
],
4362+
}
4363+
)
4364+
if len(model_package_groups_response.get("Results")) > 0:
4365+
is_model_package_group_present = True
4366+
except:
4367+
model_package_groups = []
4368+
model_package_groups_response = self.sagemaker_client.list_model_package_groups(
4369+
NameContains=request["ModelPackageGroupName"],
4370+
)
4371+
model_package_groups = (
4372+
model_package_groups
4373+
+ model_package_groups_response["ModelPackageGroupSummaryList"]
4374+
)
4375+
next_token = model_package_groups_response.get("NextToken")
4376+
4377+
while next_token is not None and next_token != "":
4378+
model_package_groups_response = self.sagemaker_client.list_model_package_groups(
4379+
NameContains=request["ModelPackageGroupName"], NextToken=next_token
4380+
)
4381+
model_package_groups = (
4382+
model_package_groups
4383+
+ model_package_groups_response["ModelPackageGroupSummaryList"]
4384+
)
4385+
next_token = model_package_groups_response.get("NextToken")
4386+
4387+
filtered_model_package_group = list(
4388+
filter(
4389+
lambda mpg: mpg.get("ModelPackageGroupName")
4390+
== request["ModelPackageGroupName"],
4391+
model_package_groups,
4392+
)
4393+
)
4394+
is_model_package_group_present = len(filtered_model_package_group) > 0
4395+
if not is_model_package_group_present:
4396+
_create_resource(
4397+
lambda: self.sagemaker_client.create_model_package_group(
4398+
ModelPackageGroupName=request["ModelPackageGroupName"]
4399+
)
43534400
)
4354-
)
43554401
if "SourceUri" in request and request["SourceUri"] is not None:
43564402
# Remove inference spec from request if the
43574403
# given source uri can lead to auto-population of it

tests/unit/test_session.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5006,6 +5006,9 @@ def test_create_model_package_with_sagemaker_config_injection(sagemaker_session)
50065006
domain = "COMPUTER_VISION"
50075007
task = "IMAGE_CLASSIFICATION"
50085008
sample_payload_url = "s3://test-bucket/model"
5009+
sagemaker_session.sagemaker_client.search.return_value = {
5010+
"Results": []
5011+
}
50095012
sagemaker_session.create_model_package_from_containers(
50105013
containers=containers,
50115014
content_types=content_types,
@@ -5094,6 +5097,10 @@ def test_create_model_package_from_containers_with_source_uri_and_inference_spec
50945097
skip_model_validation = "All"
50955098
source_uri = "dummy-source-uri"
50965099

5100+
sagemaker_session.sagemaker_client.search.return_value = {
5101+
"Results": []
5102+
}
5103+
50975104
created_versioned_mp_arn = (
50985105
"arn:aws:sagemaker:us-west-2:123456789123:model-package/unit-test-package-version/1"
50995106
)
@@ -5149,6 +5156,9 @@ def test_create_model_package_from_containers_with_source_uri_for_unversioned_mp
51495156
approval_status = ("Approved",)
51505157
skip_model_validation = "All"
51515158
source_uri = "dummy-source-uri"
5159+
sagemaker_session.sagemaker_client.search.return_value = {
5160+
"Results": []
5161+
}
51525162

51535163
with pytest.raises(
51545164
ValueError,
@@ -5221,6 +5231,10 @@ def test_create_model_package_from_containers_with_source_uri_set_to_mp(sagemake
52215231
return_value={"ModelPackageArn": created_versioned_mp_arn}
52225232
)
52235233

5234+
sagemaker_session.sagemaker_client.search.return_value = {
5235+
"Results": []
5236+
}
5237+
52245238
sagemaker_session.create_model_package_from_containers(
52255239
model_package_group_name=model_package_group_name,
52265240
containers=containers,
@@ -5443,6 +5457,9 @@ def test_create_model_package_from_containers_without_instance_types(sagemaker_s
54435457
approval_status = ("Approved",)
54445458
description = "description"
54455459
customer_metadata_properties = {"key1": "value1"}
5460+
sagemaker_session.sagemaker_client.search.return_value = {
5461+
"Results": []
5462+
}
54465463
sagemaker_session.create_model_package_from_containers(
54475464
containers=containers,
54485465
content_types=content_types,
@@ -5510,6 +5527,9 @@ def test_create_model_package_from_containers_with_one_instance_types(
55105527
approval_status = ("Approved",)
55115528
description = "description"
55125529
customer_metadata_properties = {"key1": "value1"}
5530+
sagemaker_session.sagemaker_client.search.return_value = {
5531+
"Results": []
5532+
}
55135533
sagemaker_session.create_model_package_from_containers(
55145534
containers=containers,
55155535
content_types=content_types,
@@ -7183,3 +7203,63 @@ def test_delete_hub_content_reference(sagemaker_session):
71837203
}
71847204

71857205
sagemaker_session.sagemaker_client.delete_hub_content_reference.assert_called_with(**request)
7206+
7207+
def test_create_model_package_from_containers_to_create_mpg_if_not_present_without_search(sagemaker_session):
7208+
sagemaker_session.sagemaker_client.search.side_effect = Exception()
7209+
sagemaker_session.sagemaker_client.search.return_value = {}
7210+
sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [{
7211+
"ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg"}],
7212+
"NextToken": "NextToken",
7213+
},
7214+
{
7215+
"ModelPackageGroupSummaryList": [{"ModelPackageGroupName": "mock-mpg-test"}]
7216+
}]
7217+
sagemaker_session.create_model_package_from_containers(
7218+
source_uri="mock-source-uri", model_package_group_name="mock-mpg"
7219+
)
7220+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7221+
sagemaker_session.create_model_package_from_containers(
7222+
source_uri="mock-source-uri",
7223+
model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg",
7224+
)
7225+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7226+
sagemaker_session.sagemaker_client.list_model_package_groups.side_effect = [{
7227+
"ModelPackageGroupSummaryList": []
7228+
}]
7229+
sagemaker_session.create_model_package_from_containers(
7230+
source_uri="mock-source-uri", model_package_group_name="mock-mpg"
7231+
)
7232+
sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with(
7233+
ModelPackageGroupName="mock-mpg"
7234+
)
7235+
7236+
def test_create_model_package_from_containers_to_create_mpg_if_not_present(sagemaker_session):
7237+
# with search api
7238+
sagemaker_session.sagemaker_client.search.return_value = {
7239+
"Results": [
7240+
{
7241+
"ModelPackageGroup": {
7242+
"ModelPackageGroupName": "mock-mpg",
7243+
"ModelPackageGroupArn": "arn:aws:sagemaker:us-west-2:123456789012:model-package-group/mock-mpg",
7244+
}
7245+
}
7246+
]
7247+
}
7248+
sagemaker_session.create_model_package_from_containers(
7249+
source_uri="mock-source-uri", model_package_group_name="mock-mpg"
7250+
)
7251+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7252+
sagemaker_session.create_model_package_from_containers(
7253+
source_uri="mock-source-uri",
7254+
model_package_group_name="arn:aws:sagemaker:us-east-1:215995503607:model-package-group/mock-mpg",
7255+
)
7256+
sagemaker_session.sagemaker_client.create_model_package_group.assert_not_called()
7257+
sagemaker_session.sagemaker_client.search.return_value = {
7258+
"Results": []
7259+
}
7260+
sagemaker_session.create_model_package_from_containers(
7261+
source_uri="mock-source-uri", model_package_group_name="mock-mpg"
7262+
)
7263+
sagemaker_session.sagemaker_client.create_model_package_group.assert_called_with(
7264+
ModelPackageGroupName="mock-mpg"
7265+
)

0 commit comments

Comments
 (0)