Skip to content

Commit bc0e6a4

Browse files
author
Keshav Chandak
committed
Feat: Add/Remove model package group from collection
1 parent 2fc3ed5 commit bc0e6a4

File tree

3 files changed

+341
-28
lines changed

3 files changed

+341
-28
lines changed

src/sagemaker/collection.py

Lines changed: 184 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,41 @@ def _check_access_error(self, err: ClientError):
5656
"https://docs.aws.amazon.com/sagemaker/latest/dg/modelcollections-permissions.html"
5757
)
5858

59+
def _add_model_group(self, model_package_group, tag_rule_key, tag_rule_value):
60+
"""To add a model package group to a collection
61+
62+
Args:
63+
model_package_group (str): The name of the model package group
64+
tag_rule_key (str): The tag key of the corresponing collection to be added into
65+
tag_rule_value (str): The tag value of the corresponing collection to be added into
66+
"""
67+
model_group_details = self.sagemaker_session.sagemaker_client.describe_model_package_group(
68+
ModelPackageGroupName=model_package_group
69+
)
70+
self.sagemaker_session.sagemaker_client.add_tags(
71+
ResourceArn=model_group_details["ModelPackageGroupArn"],
72+
Tags=[
73+
{
74+
"Key": tag_rule_key,
75+
"Value": tag_rule_value,
76+
}
77+
],
78+
)
79+
80+
def _remove_model_group(self, model_package_group, tag_rule_key):
81+
"""To remove a model package group from a collection
82+
83+
Args:
84+
model_package_group (str): The name of the model package group
85+
tag_rule_key (str): The tag key of the corresponing collection to be removed from
86+
"""
87+
model_group_details = self.sagemaker_session.sagemaker_client.describe_model_package_group(
88+
ModelPackageGroupName=model_package_group
89+
)
90+
self.sagemaker_session.sagemaker_client.delete_tags(
91+
ResourceArn=model_group_details["ModelPackageGroupArn"], TagKeys=[tag_rule_key]
92+
)
93+
5994
def create(self, collection_name: str, parent_collection_name: str = None):
6095
"""Creates a collection
6196
@@ -65,38 +100,22 @@ def create(self, collection_name: str, parent_collection_name: str = None):
65100
To be None if the collection is to be created on the root level
66101
"""
67102

68-
tag_rule_key = f"sagemaker:collection-path:{time.time()}"
103+
tag_rule_key = f"sagemaker:collection-path:{int(time.time() * 1000)}"
69104
tags_on_collection = {
70105
"sagemaker:collection": "true",
71106
"sagemaker:collection-path:root": "true",
72107
}
73108
tag_rule_values = [collection_name]
74109

75110
if parent_collection_name is not None:
76-
try:
77-
group_query = self.sagemaker_session.get_resource_group_query(
78-
group=parent_collection_name
79-
)
80-
except ClientError as e:
81-
error_code = e.response["Error"]["Code"]
82-
83-
if error_code == "NotFoundException":
84-
raise ValueError(f"Cannot find collection: {parent_collection_name}")
85-
self._check_access_error(err=e)
86-
raise
87-
if group_query.get("GroupQuery"):
88-
parent_tag_rule_query = json.loads(
89-
group_query["GroupQuery"].get("ResourceQuery", {}).get("Query", "")
90-
)
91-
parent_tag_rule = parent_tag_rule_query.get("TagFilters", [])[0]
92-
if not parent_tag_rule:
93-
raise "Invalid parent_collection_name"
94-
parent_tag_value = parent_tag_rule["Values"][0]
95-
tags_on_collection = {
96-
parent_tag_rule["Key"]: parent_tag_value,
97-
"sagemaker:collection": "true",
98-
}
99-
tag_rule_values = [f"{parent_tag_value}/{collection_name}"]
111+
parent_tag_rules = self._get_collection_tag_rule(collection_name=parent_collection_name)
112+
parent_tag_rule_key = parent_tag_rules["tag_rule_key"]
113+
parent_tag_value = parent_tag_rules["tag_rule_value"]
114+
tags_on_collection = {
115+
parent_tag_rule_key: parent_tag_value,
116+
"sagemaker:collection": "true",
117+
}
118+
tag_rule_values = [f"{parent_tag_value}/{collection_name}"]
100119
try:
101120
resource_filters = [
102121
"AWS::SageMaker::ModelPackageGroup",
@@ -122,7 +141,6 @@ def create(self, collection_name: str, parent_collection_name: str = None):
122141
"Name": collection_create_response["Group"]["Name"],
123142
"Arn": collection_create_response["Group"]["GroupArn"],
124143
}
125-
126144
except ClientError as e:
127145
message = e.response["Error"]["Message"]
128146
error_code = e.response["Error"]["Code"]
@@ -134,7 +152,7 @@ def create(self, collection_name: str, parent_collection_name: str = None):
134152
raise
135153

136154
def delete(self, collections: List[str]):
137-
"""Deletes a lits of collection
155+
"""Deletes a list of collection.
138156
139157
Args:
140158
collections (List[str]): List of collections to be deleted
@@ -152,6 +170,8 @@ def delete(self, collections: List[str]):
152170
"Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"],
153171
},
154172
]
173+
174+
# loops over the list of collection and deletes one at a time.
155175
for collection in collections:
156176
try:
157177
collection_details = self.sagemaker_session.list_group_resources(
@@ -180,3 +200,140 @@ def delete(self, collections: List[str]):
180200
"deleted_collections": deleted_collection,
181201
"delete_collection_failures": delete_collection_failures,
182202
}
203+
204+
def _get_collection_tag_rule(self, collection_name: str):
205+
"""Returns the tag rule key and value for a collection"""
206+
207+
if collection_name is not None:
208+
try:
209+
group_query = self.sagemaker_session.get_resource_group_query(group=collection_name)
210+
except ClientError as e:
211+
error_code = e.response["Error"]["Code"]
212+
213+
if error_code == "NotFoundException":
214+
raise ValueError(f"Cannot find collection: {collection_name}")
215+
self._check_access_error(err=e)
216+
raise
217+
if group_query.get("GroupQuery"):
218+
tag_rule_query = json.loads(
219+
group_query["GroupQuery"].get("ResourceQuery", {}).get("Query", "")
220+
)
221+
tag_rule = tag_rule_query.get("TagFilters", [])[0]
222+
if not tag_rule:
223+
raise "Unsupported parent_collection_name"
224+
tag_rule_value = tag_rule["Values"][0]
225+
tag_rule_key = tag_rule["Key"]
226+
227+
return {
228+
"tag_rule_key": tag_rule_key,
229+
"tag_rule_value": tag_rule_value,
230+
}
231+
raise ValueError("Collection name is required")
232+
233+
def add_model_groups(self, collection_name: str, model_groups: List[str]):
234+
"""To add list of model package groups to a collection
235+
236+
Args:
237+
collection_name (str): The name of the collection
238+
model_groups List[str]: Model pckage group names list to be added into the collection
239+
"""
240+
if len(model_groups) > 10:
241+
raise Exception("Model groups can have a maximum length of 10")
242+
tag_rules = self._get_collection_tag_rule(collection_name=collection_name)
243+
tag_rule_key = tag_rules["tag_rule_key"]
244+
tag_rule_value = tag_rules["tag_rule_value"]
245+
246+
add_groups_success = []
247+
add_groups_failure = []
248+
if tag_rule_key is not None and tag_rule_value is not None:
249+
for model_group in model_groups:
250+
try:
251+
self._add_model_group(
252+
model_package_group=model_group,
253+
tag_rule_key=tag_rule_key,
254+
tag_rule_value=tag_rule_value,
255+
)
256+
add_groups_success.append(model_group)
257+
except ClientError as e:
258+
self._check_access_error(err=e)
259+
message = e.response["Error"]["Message"]
260+
add_groups_failure.append(
261+
{
262+
"model_group": model_group,
263+
"failure_reason": message,
264+
}
265+
)
266+
return {
267+
"added_groups": add_groups_success,
268+
"failure": add_groups_failure,
269+
}
270+
271+
def remove_model_groups(self, collection_name: str, model_groups: List[str]):
272+
"""To remove list of model package groups from a collection
273+
274+
Args:
275+
collection_name (str): The name of the collection
276+
model_groups List[str]: Model package group names list to be removed
277+
"""
278+
279+
if len(model_groups) > 10:
280+
raise Exception("Model groups can have a maximum length of 10")
281+
tag_rules = self._get_collection_tag_rule(collection_name=collection_name)
282+
283+
tag_rule_key = tag_rules["tag_rule_key"]
284+
tag_rule_value = tag_rules["tag_rule_value"]
285+
286+
remove_groups_success = []
287+
remove_groups_failure = []
288+
if tag_rule_key is not None and tag_rule_value is not None:
289+
for model_group in model_groups:
290+
try:
291+
self._remove_model_group(
292+
model_package_group=model_group,
293+
tag_rule_key=tag_rule_key,
294+
)
295+
remove_groups_success.append(model_group)
296+
except ClientError as e:
297+
self._check_access_error(err=e)
298+
message = e.response["Error"]["Message"]
299+
remove_groups_failure.append(
300+
{
301+
"model_group": model_group,
302+
"failure_reason": message,
303+
}
304+
)
305+
return {
306+
"removed_groups": remove_groups_success,
307+
"failure": remove_groups_failure,
308+
}
309+
310+
def move_model_group(
311+
self, source_collection_name: str, model_group: str, destination_collection_name: str
312+
):
313+
"""To move a model package group from one collection to another
314+
315+
Args:
316+
source_collection_name (str): Collection name of the source
317+
model_group (str): Model package group names which is to be moved
318+
destination_collection_name (str): Collection name of the destination
319+
"""
320+
remove_details = self.remove_model_groups(
321+
collection_name=source_collection_name, model_groups=[model_group]
322+
)
323+
if len(remove_details["failure"]) == 1:
324+
raise Exception(remove_details["failure"][0]["failure"])
325+
326+
added_details = self.add_model_groups(
327+
collection_name=destination_collection_name, model_groups=[model_group]
328+
)
329+
330+
if len(added_details["failure"]) == 1:
331+
# adding the model group back to the source collection in case of an add failure
332+
self.add_model_groups(
333+
collection_name=source_collection_name, model_groups=[model_group]
334+
)
335+
raise Exception(added_details["failure"][0]["failure"])
336+
337+
return {
338+
"moved_success": model_group,
339+
}

tests/integ/test_collection.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,109 @@ def test_create_collection_nested_success(sagemaker_session):
6060
delete_response = collection.delete([child_collection_name, collection_name])
6161
assert len(delete_response["deleted_collections"]) == 2
6262
assert len(delete_response["delete_collection_failures"]) == 0
63+
64+
65+
def test_add_remove_model_groups_in_collection_success(sagemaker_session):
66+
model_group_name = unique_name_from_base("test-model-group")
67+
sagemaker_session.sagemaker_client.create_model_package_group(
68+
ModelPackageGroupName=model_group_name
69+
)
70+
collection = Collection(sagemaker_session)
71+
collection_name = unique_name_from_base("test-collection")
72+
collection.create(collection_name)
73+
model_groups = []
74+
model_groups.append(model_group_name)
75+
add_response = collection.add_model_groups(
76+
collection_name=collection_name, model_groups=model_groups
77+
)
78+
collection_filter = [
79+
{
80+
"Name": "resource-type",
81+
"Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"],
82+
},
83+
]
84+
collection_details = sagemaker_session.list_group_resources(
85+
group=collection_name, filters=collection_filter
86+
)
87+
88+
assert len(add_response["failure"]) == 0
89+
assert len(add_response["added_groups"]) == 1
90+
assert len(collection_details["Resources"]) == 1
91+
92+
remove_response = collection.remove_model_groups(
93+
collection_name=collection_name, model_groups=model_groups
94+
)
95+
collection_details = sagemaker_session.list_group_resources(
96+
group=collection_name, filters=collection_filter
97+
)
98+
assert len(remove_response["failure"]) == 0
99+
assert len(remove_response["removed_groups"]) == 1
100+
assert len(collection_details["Resources"]) == 0
101+
102+
delete_response = collection.delete([collection_name])
103+
assert len(delete_response["deleted_collections"]) == 1
104+
sagemaker_session.sagemaker_client.delete_model_package_group(
105+
ModelPackageGroupName=model_group_name
106+
)
107+
108+
109+
def test_move_model_groups_in_collection_success(sagemaker_session):
110+
model_group_name = unique_name_from_base("test-model-group")
111+
sagemaker_session.sagemaker_client.create_model_package_group(
112+
ModelPackageGroupName=model_group_name
113+
)
114+
collection = Collection(sagemaker_session)
115+
source_collection_name = unique_name_from_base("test-collection-source")
116+
destination_collection_name = unique_name_from_base("test-collection-destination")
117+
collection.create(source_collection_name)
118+
collection.create(destination_collection_name)
119+
model_groups = []
120+
model_groups.append(model_group_name)
121+
add_response = collection.add_model_groups(
122+
collection_name=source_collection_name, model_groups=model_groups
123+
)
124+
collection_filter = [
125+
{
126+
"Name": "resource-type",
127+
"Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"],
128+
},
129+
]
130+
collection_details = sagemaker_session.list_group_resources(
131+
group=source_collection_name, filters=collection_filter
132+
)
133+
134+
assert len(add_response["failure"]) == 0
135+
assert len(add_response["added_groups"]) == 1
136+
assert len(collection_details["Resources"]) == 1
137+
138+
move_response = collection.move_model_group(
139+
source_collection_name=source_collection_name,
140+
model_group=model_group_name,
141+
destination_collection_name=destination_collection_name,
142+
)
143+
144+
assert move_response["moved_success"] == model_group_name
145+
146+
collection_details = sagemaker_session.list_group_resources(
147+
group=destination_collection_name, filters=collection_filter
148+
)
149+
150+
assert len(collection_details["Resources"]) == 1
151+
152+
collection_details = sagemaker_session.list_group_resources(
153+
group=source_collection_name, filters=collection_filter
154+
)
155+
assert len(collection_details["Resources"]) == 0
156+
157+
remove_response = collection.remove_model_groups(
158+
collection_name=destination_collection_name, model_groups=model_groups
159+
)
160+
161+
assert len(remove_response["failure"]) == 0
162+
assert len(remove_response["removed_groups"]) == 1
163+
164+
delete_response = collection.delete([source_collection_name, destination_collection_name])
165+
assert len(delete_response["deleted_collections"]) == 2
166+
sagemaker_session.sagemaker_client.delete_model_package_group(
167+
ModelPackageGroupName=model_group_name
168+
)

0 commit comments

Comments
 (0)