Skip to content

Commit b3524d2

Browse files
author
Keshav Chandak
committed
Feat: Create/Delete Model Collection
1 parent ce2fb78 commit b3524d2

File tree

4 files changed

+379
-0
lines changed

4 files changed

+379
-0
lines changed

src/sagemaker/collection.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
14+
"""This module contains code related to Amazon SageMaker Collection.
15+
16+
These classes provide features to create and maintain collections.
17+
"""
18+
19+
from __future__ import absolute_import
20+
import json
21+
import time
22+
from typing import List
23+
24+
25+
from botocore.exceptions import ClientError
26+
from sagemaker.session import Session
27+
28+
29+
class Collection(object):
    """Sets up Amazon SageMaker Collection."""

    def __init__(self, sagemaker_session=None):
        """Initializes a Collection instance.

        The collection provides a logical grouping for model groups.

        Args:
            sagemaker_session (sagemaker.session.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using
                the default AWS configuration chain.
        """
        self.sagemaker_session = sagemaker_session or Session()

    def _check_access_error(self, err: "ClientError"):
        """Raise a descriptive error when ``err`` is an access-denied error.

        Any other error code is ignored, so callers decide how to handle it.

        Args:
            err: The client error that needs to be checked.

        Raises:
            Exception: If the error code is ``AccessDeniedException``.
        """
        error_code = err.response["Error"]["Code"]
        if error_code == "AccessDeniedException":
            raise Exception(
                f"{error_code}: This account needs to attach a custom policy "
                "to the user role to gain access to Collections. Refer - "
                "https://docs.aws.amazon.com/sagemaker/latest/dg/modelcollections-permissions.html"
            ) from err

    @staticmethod
    def _parent_tag_rule(group_query):
        """Extract the key and first value of the first tag filter of a group query.

        Args:
            group_query (dict): The ``GroupQuery`` entry of a get-group-query response.

        Returns:
            tuple: ``(key, value)`` of the first tag filter.

        Raises:
            ValueError: If the query is missing, is not valid JSON, or does not
                contain a tag filter with a key and at least one value.
        """
        raw_query = group_query.get("ResourceQuery", {}).get("Query", "")
        try:
            tag_rule = json.loads(raw_query).get("TagFilters", [])[0]
            return tag_rule["Key"], tag_rule["Values"][0]
        except (ValueError, TypeError, AttributeError, LookupError) as error:
            # json.JSONDecodeError is a ValueError; LookupError covers both a
            # missing key and an empty filter/value list.
            raise ValueError("Invalid parent_collection_name") from error

    def create(self, collection_name: str, parent_collection_name: str = None):
        """Creates a collection.

        Args:
            collection_name (str): The name of the collection to be created.
            parent_collection_name (str): The name of the parent collection.
                To be None if the collection is to be created on the root level.

        Returns:
            dict: The ``Name`` and ``Arn`` of the created collection.

        Raises:
            ValueError: If the parent collection cannot be found, if its group
                query holds no usable tag filter, or if a collection with the
                given name already exists.
            Exception: If the call is rejected with ``AccessDeniedException``.
            botocore.exceptions.ClientError: For any other client error.
        """
        # The timestamp makes the tag key of every collection distinct.
        tag_rule_key = f"sagemaker:collection-path:{time.time()}"
        tags_on_collection = {
            "sagemaker:collection": "true",
            "sagemaker:collection-path:root": "true",
        }
        tag_rule_values = [collection_name]

        if parent_collection_name is not None:
            try:
                group_query = self.sagemaker_session.get_resource_group_query(
                    group=parent_collection_name
                )
            except ClientError as e:
                error_code = e.response["Error"]["Code"]

                if error_code == "NotFoundException":
                    raise ValueError(f"Cannot find collection: {parent_collection_name}") from e
                self._check_access_error(err=e)
                raise
            # Without a GroupQuery in the response the root-level tags are kept.
            if group_query.get("GroupQuery"):
                parent_tag_key, parent_tag_value = self._parent_tag_rule(
                    group_query["GroupQuery"]
                )
                # Tagging the child with the parent's rule makes it a member
                # of the parent's group query.
                tags_on_collection = {
                    parent_tag_key: parent_tag_value,
                    "sagemaker:collection": "true",
                }
                tag_rule_values = [f"{parent_tag_value}/{collection_name}"]

        resource_filters = [
            "AWS::SageMaker::ModelPackageGroup",
            "AWS::ResourceGroups::Group",
        ]
        tag_filters = [
            {
                "Key": tag_rule_key,
                "Values": tag_rule_values,
            }
        ]
        resource_query = {
            "Query": json.dumps(
                {"ResourceTypeFilters": resource_filters, "TagFilters": tag_filters}
            ),
            "Type": "TAG_FILTERS_1_0",
        }
        try:
            collection_create_response = self.sagemaker_session.create_group(
                collection_name, resource_query, tags_on_collection
            )
        except ClientError as e:
            message = e.response["Error"]["Message"]
            error_code = e.response["Error"]["Code"]

            if error_code == "BadRequestException" and "group already exists" in message:
                raise ValueError("Collection with the given name already exists") from e

            self._check_access_error(err=e)
            raise

        return {
            "Name": collection_create_response["Group"]["Name"],
            "Arn": collection_create_response["Group"]["GroupArn"],
        }

    def delete(self, collections: List[str]):
        """Deletes a list of collections.

        Args:
            collections (List[str]): List of collections to be deleted.
                Only deletes a collection if it is empty.

        Returns:
            dict: ``deleted_collections`` lists the names that were deleted and
                ``delete_collection_failures`` lists a ``collection``/``message``
                entry for every collection that could not be deleted.

        Raises:
            ValueError: If more than 10 collections are passed.
            Exception: If a call is rejected with ``AccessDeniedException``.
        """
        if len(collections) > 10:
            raise ValueError("Can delete upto 10 collections at a time")

        delete_collection_failures = []
        deleted_collection = []
        collection_filter = [
            {
                "Name": "resource-type",
                "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"],
            },
        ]
        for collection in collections:
            try:
                collection_details = self.sagemaker_session.list_group_resources(
                    group=collection, filters=collection_filter
                )
            except ClientError as e:
                self._check_access_error(err=e)
                delete_collection_failures.append(
                    {"collection": collection, "message": e.response["Error"]["Message"]}
                )
                continue

            # A collection that still lists resources is reported, not deleted.
            if collection_details.get("Resources"):
                delete_collection_failures.append(
                    {"collection": collection, "message": "Validation error: Collection not empty"}
                )
                continue

            try:
                self.sagemaker_session.delete_resource_group(group=collection)
            except ClientError as e:
                self._check_access_error(err=e)
                delete_collection_failures.append(
                    {"collection": collection, "message": e.response["Error"]["Message"]}
                )
            else:
                deleted_collection.append(collection)

        return {
            "deleted_collections": deleted_collection,
            "delete_collection_failures": delete_collection_failures,
        }

src/sagemaker/session.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ def __init__(
204204
self._default_bucket_name_override = default_bucket
205205
self.s3_resource = None
206206
self.s3_client = None
207+
self.resource_groups_client = None
207208
self.config = None
208209
self.lambda_client = None
209210
self.settings = settings
@@ -3961,6 +3962,57 @@ def delete_model(self, model_name):
39613962
LOGGER.info("Deleting model with name: %s", model_name)
39623963
self.sagemaker_client.delete_model(ModelName=model_name)
39633964

3965+
def list_group_resources(self, group, filters):
    """List the resources of a resource group that match the given filters.

    Args:
        group (str): The name or the ARN of the group.
        filters (list): Filters that need to be applied to the list operation.

    Returns:
        The response of the resource groups client's ``list_group_resources`` call.
    """
    # Create the resource groups client on first use and keep it for later calls.
    if not self.resource_groups_client:
        self.resource_groups_client = self.boto_session.client("resource-groups")
    client = self.resource_groups_client
    return client.list_group_resources(Group=group, Filters=filters)
3976+
3977+
def delete_resource_group(self, group):
    """Delete a resource group.

    Args:
        group (str): The name or the ARN of the resource group to delete.

    Returns:
        The response of the resource groups client's ``delete_group`` call.
    """
    # Create the resource groups client on first use and keep it for later calls.
    if not self.resource_groups_client:
        self.resource_groups_client = self.boto_session.client("resource-groups")
    client = self.resource_groups_client
    return client.delete_group(Group=group)
3987+
3988+
def get_resource_group_query(self, group):
    """Fetch the group query of an AWS Resource Group.

    Args:
        group (str): The name or the ARN of the resource group to query.

    Returns:
        The response of the resource groups client's ``get_group_query`` call.
    """
    # Create the resource groups client on first use and keep it for later calls.
    if not self.resource_groups_client:
        self.resource_groups_client = self.boto_session.client("resource-groups")
    client = self.resource_groups_client
    return client.get_group_query(Group=group)
3998+
3999+
def create_group(self, name, resource_query, tags):
    """Create an AWS Resource Group.

    Args:
        name (str): The name of the group, which is also the identifier of the group.
        resource_query (dict): The resource query that determines
            which AWS resources are members of this group.
        tags (dict): The tags to be attached to the Resource Group.

    Returns:
        The response of the resource groups client's ``create_group`` call.
    """
    # Create the resource groups client on first use and keep it for later calls.
    if not self.resource_groups_client:
        self.resource_groups_client = self.boto_session.client("resource-groups")
    client = self.resource_groups_client
    return client.create_group(Name=name, ResourceQuery=resource_query, Tags=tags)
4015+
39644016
def list_tags(self, resource_arn, max_results=50):
39654017
"""List the tags given an Amazon Resource Name.
39664018

tests/integ/test_collection.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
from sagemaker.utils import unique_name_from_base
16+
from sagemaker.collection import Collection
17+
18+
19+
def test_create_collection_root_success(sagemaker_session):
    """Create a root-level collection, list its resources, then delete it."""
    name = unique_name_from_base("test-collection")
    collection = Collection(sagemaker_session)
    collection.create(name)

    resource_type_filter = {
        "Name": "resource-type",
        "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"],
    }
    listing = sagemaker_session.list_group_resources(group=name, filters=[resource_type_filter])
    assert listing["ResponseMetadata"]["HTTPStatusCode"] == 200

    # The collection holds nothing, so the delete is expected to succeed.
    outcome = collection.delete([name])
    assert len(outcome["deleted_collections"]) == 1
    assert len(outcome["delete_collection_failures"]) == 0
36+
37+
38+
def test_create_collection_nested_success(sagemaker_session):
    """Create a collection with one child, check both can be listed, then delete them."""
    collection = Collection(sagemaker_session)
    collection_name = unique_name_from_base("test-collection")
    child_collection_name = unique_name_from_base("test-collection-2")
    collection.create(collection_name)
    collection.create(collection_name=child_collection_name, parent_collection_name=collection_name)
    collection_filter = [
        {
            "Name": "resource-type",
            "Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"],
        },
    ]
    collection_details = sagemaker_session.list_group_resources(
        group=collection_name, filters=collection_filter
    )
    # has one child i.e child collection
    assert len(collection_details["Resources"]) == 1

    collection_details = sagemaker_session.list_group_resources(
        group=child_collection_name, filters=collection_filter
    )
    # The status code was previously evaluated without being asserted.
    assert collection_details["ResponseMetadata"]["HTTPStatusCode"] == 200
    # The child is listed first so it is deleted before its parent is checked for emptiness.
    delete_response = collection.delete([child_collection_name, collection_name])
    assert len(delete_response["deleted_collections"]) == 2
    assert len(delete_response["delete_collection_failures"]) == 0

tests/unit/test_collection.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pytest
16+
import json
17+
from mock import Mock
18+
19+
from sagemaker.collection import Collection
20+
21+
# Region reported by the mocked sessions.
REGION = "us-west-2"
# Name of the collection the tests create and delete.
COLLECTION_NAME = "test-collection"

_TAG_FILTER = {
    "Key": "sagemaker:collection-path:1676120428.4811652",
    "Values": ["test-collection-k"],
}
# Resource query that is JSON-encoded into the canned create response below.
QUERY = {
    "ResourceTypeFilters": ["AWS::SageMaker::ModelPackageGroup", "AWS::ResourceGroups::Group"],
    "TagFilters": [_TAG_FILTER],
}

_GROUP = {
    "GroupArn": f"arn:aws:resource-groups:us-west-2:205984106344:group/{COLLECTION_NAME}",
    "Name": COLLECTION_NAME,
}
_RESOURCE_QUERY = {
    "Type": "TAG_FILTERS_1_0",
    "Query": json.dumps(QUERY),
}
# Canned response returned by the mocked ``create_group`` call.
CREATE_COLLECTION_RESPONSE = {
    "Group": _GROUP,
    "ResourceQuery": _RESOURCE_QUERY,
    "Tags": {"sagemaker:collection-path:root": "true"},
}
40+
41+
42+
@pytest.fixture()
def sagemaker_session():
    """Provide a mocked session whose resource-group calls return canned values."""
    mocked_session = Mock(
        name="sagemaker_session",
        boto_session=Mock(name="boto_session", region_name=REGION),
        boto_region_name=REGION,
        config=None,
        local_mode=False,
    )

    # Stub the three session methods that Collection.create/delete call in these tests.
    mocked_session.create_group = Mock(
        name="create_collection", return_value=CREATE_COLLECTION_RESPONSE
    )
    mocked_session.delete_resource_group = Mock(name="delete_resource_group", return_value=True)
    # An empty listing makes a collection look empty, so it can be deleted.
    mocked_session.list_group_resources = Mock(name="list_group_resources", return_value={})

    return mocked_session
60+
61+
62+
def test_create_collection_success(sagemaker_session):
    """Creating a root-level collection returns the name and ARN from the response."""
    collection = Collection(sagemaker_session)
    create_response = collection.create(collection_name=COLLECTION_NAME)
    # Compare by value: ``is`` only checks that both names refer to the same object.
    assert create_response["Name"] == COLLECTION_NAME
    assert create_response["Arn"] is not None
67+
68+
69+
def test_delete_collection_success(sagemaker_session):
    """An empty collection is deleted and no failure is reported."""
    outcome = Collection(sagemaker_session).delete(collections=[COLLECTION_NAME])
    deleted = outcome["deleted_collections"]
    failures = outcome["delete_collection_failures"]
    assert len(deleted) == 1
    assert len(failures) == 0
74+
75+
76+
def test_delete_collection_failure_when_collection_is_not_empty(sagemaker_session):
    """A collection that still lists a resource is reported as a failure, not deleted."""
    collection = Collection(sagemaker_session)
    # Replace the default empty listing with one that contains a resource.
    non_empty_listing = {"Resources": [{}]}
    sagemaker_session.list_group_resources = Mock(
        name="list_group_resources", return_value=non_empty_listing
    )
    outcome = collection.delete(collections=[COLLECTION_NAME])
    deleted = outcome["deleted_collections"]
    failures = outcome["delete_collection_failures"]
    assert len(deleted) == 0
    assert len(failures) == 1

0 commit comments

Comments
 (0)