Skip to content

Commit 66fe78d

Browse files
keshav-chandakKeshav Chandak
andauthored
Feature/list collection (#3781)
Co-authored-by: Keshav Chandak <[email protected]>
1 parent c4669c2 commit 66fe78d

File tree

4 files changed

+533
-31
lines changed

4 files changed

+533
-31
lines changed

src/sagemaker/collection.py

Lines changed: 308 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,41 @@ def _check_access_error(self, err: ClientError):
5656
"https://docs.aws.amazon.com/sagemaker/latest/dg/modelcollections-permissions.html"
5757
)
5858

59+
def _add_model_group(self, model_package_group, tag_rule_key, tag_rule_value):
60+
"""To add a model package group to a collection
61+
62+
Args:
63+
model_package_group (str): The name of the model package group
64+
tag_rule_key (str): The tag key of the corresponing collection to be added into
65+
tag_rule_value (str): The tag value of the corresponing collection to be added into
66+
"""
67+
model_group_details = self.sagemaker_session.sagemaker_client.describe_model_package_group(
68+
ModelPackageGroupName=model_package_group
69+
)
70+
self.sagemaker_session.sagemaker_client.add_tags(
71+
ResourceArn=model_group_details["ModelPackageGroupArn"],
72+
Tags=[
73+
{
74+
"Key": tag_rule_key,
75+
"Value": tag_rule_value,
76+
}
77+
],
78+
)
79+
80+
def _remove_model_group(self, model_package_group, tag_rule_key):
81+
"""To remove a model package group from a collection
82+
83+
Args:
84+
model_package_group (str): The name of the model package group
85+
tag_rule_key (str): The tag key of the corresponing collection to be removed from
86+
"""
87+
model_group_details = self.sagemaker_session.sagemaker_client.describe_model_package_group(
88+
ModelPackageGroupName=model_package_group
89+
)
90+
self.sagemaker_session.sagemaker_client.delete_tags(
91+
ResourceArn=model_group_details["ModelPackageGroupArn"], TagKeys=[tag_rule_key]
92+
)
93+
5994
def create(self, collection_name: str, parent_collection_name: str = None):
6095
"""Creates a collection
6196
@@ -65,38 +100,22 @@ def create(self, collection_name: str, parent_collection_name: str = None):
65100
To be None if the collection is to be created on the root level
66101
"""
67102

68-
tag_rule_key = f"sagemaker:collection-path:{time.time()}"
103+
tag_rule_key = f"sagemaker:collection-path:{int(time.time() * 1000)}"
69104
tags_on_collection = {
70105
"sagemaker:collection": "true",
71106
"sagemaker:collection-path:root": "true",
72107
}
73108
tag_rule_values = [collection_name]
74109

75110
if parent_collection_name is not None:
76-
try:
77-
group_query = self.sagemaker_session.get_resource_group_query(
78-
group=parent_collection_name
79-
)
80-
except ClientError as e:
81-
error_code = e.response["Error"]["Code"]
82-
83-
if error_code == "NotFoundException":
84-
raise ValueError(f"Cannot find collection: {parent_collection_name}")
85-
self._check_access_error(err=e)
86-
raise
87-
if group_query.get("GroupQuery"):
88-
parent_tag_rule_query = json.loads(
89-
group_query["GroupQuery"].get("ResourceQuery", {}).get("Query", "")
90-
)
91-
parent_tag_rule = parent_tag_rule_query.get("TagFilters", [])[0]
92-
if not parent_tag_rule:
93-
raise "Invalid parent_collection_name"
94-
parent_tag_value = parent_tag_rule["Values"][0]
95-
tags_on_collection = {
96-
parent_tag_rule["Key"]: parent_tag_value,
97-
"sagemaker:collection": "true",
98-
}
99-
tag_rule_values = [f"{parent_tag_value}/{collection_name}"]
111+
parent_tag_rules = self._get_collection_tag_rule(collection_name=parent_collection_name)
112+
parent_tag_rule_key = parent_tag_rules["tag_rule_key"]
113+
parent_tag_value = parent_tag_rules["tag_rule_value"]
114+
tags_on_collection = {
115+
parent_tag_rule_key: parent_tag_value,
116+
"sagemaker:collection": "true",
117+
}
118+
tag_rule_values = [f"{parent_tag_value}/{collection_name}"]
100119
try:
101120
resource_filters = [
102121
"AWS::SageMaker::ModelPackageGroup",
@@ -122,19 +141,17 @@ def create(self, collection_name: str, parent_collection_name: str = None):
122141
"Name": collection_create_response["Group"]["Name"],
123142
"Arn": collection_create_response["Group"]["GroupArn"],
124143
}
125-
126144
except ClientError as e:
127145
message = e.response["Error"]["Message"]
128146
error_code = e.response["Error"]["Code"]
129147

130148
if error_code == "BadRequestException" and "group already exists" in message:
131149
raise ValueError("Collection with the given name already exists")
132-
133150
self._check_access_error(err=e)
134151
raise
135152

136153
def delete(self, collections: List[str]):
137-
"""Deletes a lits of collection
154+
"""Deletes a list of collection.
138155
139156
Args:
140157
collections (List[str]): List of collections to be deleted
@@ -152,6 +169,8 @@ def delete(self, collections: List[str]):
152169
"Values": ["AWS::ResourceGroups::Group", "AWS::SageMaker::ModelPackageGroup"],
153170
},
154171
]
172+
173+
# loops over the list of collection and deletes one at a time.
155174
for collection in collections:
156175
try:
157176
collection_details = self.sagemaker_session.list_group_resources(
@@ -180,3 +199,264 @@ def delete(self, collections: List[str]):
180199
"deleted_collections": deleted_collection,
181200
"delete_collection_failures": delete_collection_failures,
182201
}
202+
203+
def _get_collection_tag_rule(self, collection_name: str):
204+
"""Returns the tag rule key and value for a collection"""
205+
206+
if collection_name is not None:
207+
try:
208+
group_query = self.sagemaker_session.get_resource_group_query(group=collection_name)
209+
except ClientError as e:
210+
error_code = e.response["Error"]["Code"]
211+
212+
if error_code == "NotFoundException":
213+
raise ValueError(f"Cannot find collection: {collection_name}")
214+
self._check_access_error(err=e)
215+
raise
216+
if group_query.get("GroupQuery"):
217+
tag_rule_query = json.loads(
218+
group_query["GroupQuery"].get("ResourceQuery", {}).get("Query", "")
219+
)
220+
tag_rule = tag_rule_query.get("TagFilters", [])[0]
221+
if not tag_rule:
222+
raise "Unsupported parent_collection_name"
223+
tag_rule_value = tag_rule["Values"][0]
224+
tag_rule_key = tag_rule["Key"]
225+
226+
return {
227+
"tag_rule_key": tag_rule_key,
228+
"tag_rule_value": tag_rule_value,
229+
}
230+
raise ValueError("Collection name is required")
231+
232+
def add_model_groups(self, collection_name: str, model_groups: List[str]):
233+
"""To add list of model package groups to a collection
234+
235+
Args:
236+
collection_name (str): The name of the collection
237+
model_groups List[str]: Model pckage group names list to be added into the collection
238+
"""
239+
if len(model_groups) > 10:
240+
raise Exception("Model groups can have a maximum length of 10")
241+
tag_rules = self._get_collection_tag_rule(collection_name=collection_name)
242+
tag_rule_key = tag_rules["tag_rule_key"]
243+
tag_rule_value = tag_rules["tag_rule_value"]
244+
245+
add_groups_success = []
246+
add_groups_failure = []
247+
if tag_rule_key is not None and tag_rule_value is not None:
248+
for model_group in model_groups:
249+
try:
250+
self._add_model_group(
251+
model_package_group=model_group,
252+
tag_rule_key=tag_rule_key,
253+
tag_rule_value=tag_rule_value,
254+
)
255+
add_groups_success.append(model_group)
256+
except ClientError as e:
257+
self._check_access_error(err=e)
258+
message = e.response["Error"]["Message"]
259+
add_groups_failure.append(
260+
{
261+
"model_group": model_group,
262+
"failure_reason": message,
263+
}
264+
)
265+
return {
266+
"added_groups": add_groups_success,
267+
"failure": add_groups_failure,
268+
}
269+
270+
def remove_model_groups(self, collection_name: str, model_groups: List[str]):
271+
"""To remove list of model package groups from a collection
272+
273+
Args:
274+
collection_name (str): The name of the collection
275+
model_groups List[str]: Model package group names list to be removed
276+
"""
277+
278+
if len(model_groups) > 10:
279+
raise Exception("Model groups can have a maximum length of 10")
280+
tag_rules = self._get_collection_tag_rule(collection_name=collection_name)
281+
282+
tag_rule_key = tag_rules["tag_rule_key"]
283+
tag_rule_value = tag_rules["tag_rule_value"]
284+
285+
remove_groups_success = []
286+
remove_groups_failure = []
287+
if tag_rule_key is not None and tag_rule_value is not None:
288+
for model_group in model_groups:
289+
try:
290+
self._remove_model_group(
291+
model_package_group=model_group,
292+
tag_rule_key=tag_rule_key,
293+
)
294+
remove_groups_success.append(model_group)
295+
except ClientError as e:
296+
self._check_access_error(err=e)
297+
message = e.response["Error"]["Message"]
298+
remove_groups_failure.append(
299+
{
300+
"model_group": model_group,
301+
"failure_reason": message,
302+
}
303+
)
304+
return {
305+
"removed_groups": remove_groups_success,
306+
"failure": remove_groups_failure,
307+
}
308+
309+
def move_model_group(
310+
self, source_collection_name: str, model_group: str, destination_collection_name: str
311+
):
312+
"""To move a model package group from one collection to another
313+
314+
Args:
315+
source_collection_name (str): Collection name of the source
316+
model_group (str): Model package group names which is to be moved
317+
destination_collection_name (str): Collection name of the destination
318+
"""
319+
remove_details = self.remove_model_groups(
320+
collection_name=source_collection_name, model_groups=[model_group]
321+
)
322+
if len(remove_details["failure"]) == 1:
323+
raise Exception(remove_details["failure"][0]["failure"])
324+
325+
added_details = self.add_model_groups(
326+
collection_name=destination_collection_name, model_groups=[model_group]
327+
)
328+
329+
if len(added_details["failure"]) == 1:
330+
# adding the model group back to the source collection in case of an add failure
331+
self.add_model_groups(
332+
collection_name=source_collection_name, model_groups=[model_group]
333+
)
334+
raise Exception(added_details["failure"][0]["failure"])
335+
336+
return {
337+
"moved_success": model_group,
338+
}
339+
340+
def _convert_tag_collection_response(self, tag_collections: List[str]):
341+
"""Converts collection response from tag api to collection list response
342+
343+
Args:
344+
tag_collections List[dict]: Collections list response from tag api
345+
"""
346+
collection_details = []
347+
for collection in tag_collections:
348+
collection_arn = collection["ResourceARN"]
349+
collection_name = collection_arn.split("group/")[1]
350+
collection_details.append(
351+
{
352+
"Name": collection_name,
353+
"Arn": collection_arn,
354+
"Type": "Collection",
355+
}
356+
)
357+
return collection_details
358+
359+
def _convert_group_resource_response(
360+
self, group_resource_details: List[dict], is_model_group: bool = False
361+
):
362+
"""Converts collection response from resource group api to collection list response
363+
364+
Args:
365+
group_resource_details (List[dict]): Collections list response from resource group api
366+
is_model_group (bool): If the reponse is of collection or model group type
367+
"""
368+
collection_details = []
369+
if group_resource_details["Resources"]:
370+
for resource_group in group_resource_details["Resources"]:
371+
collection_arn = resource_group["Identifier"]["ResourceArn"]
372+
collection_name = collection_arn.split("group/")[1]
373+
collection_details.append(
374+
{
375+
"Name": collection_name,
376+
"Arn": collection_arn,
377+
"Type": resource_group["Identifier"]["ResourceType"]
378+
if is_model_group
379+
else "Collection",
380+
}
381+
)
382+
return collection_details
383+
384+
def _get_full_list_resource(self, collection_name, collection_filter):
385+
"""Iterating to the full resource group list and returns appended paginated response
386+
387+
Args:
388+
collection_name (str): Name of the collection to get the details
389+
collection_filter (dict): Filter details to be passed to get the resource list
390+
391+
"""
392+
list_group_response = self.sagemaker_session.list_group_resources(
393+
group=collection_name, filters=collection_filter
394+
)
395+
next_token = list_group_response.get("NextToken")
396+
while next_token is not None:
397+
398+
paginated_group_response = self.sagemaker_session.list_group_resources(
399+
group=collection_name,
400+
filters=collection_filter,
401+
next_token=next_token,
402+
)
403+
list_group_response["Resources"] = (
404+
list_group_response["Resources"] + paginated_group_response["Resources"]
405+
)
406+
list_group_response["ResourceIdentifiers"] = (
407+
list_group_response["ResourceIdentifiers"]
408+
+ paginated_group_response["ResourceIdentifiers"]
409+
)
410+
next_token = paginated_group_response.get("NextToken")
411+
412+
return list_group_response
413+
414+
def list_collection(self, collection_name: str = None):
415+
"""To all list the collections and content of the collections
416+
417+
In case there is no collection_name, it lists all the collections on the root level
418+
419+
Args:
420+
collection_name (str): The name of the collection to list the contents of
421+
"""
422+
collection_content = []
423+
if collection_name is None:
424+
tag_filters = [
425+
{
426+
"Key": "sagemaker:collection-path:root",
427+
"Values": ["true"],
428+
},
429+
]
430+
resource_type_filters = ["resource-groups:group"]
431+
tag_collections = self.sagemaker_session.get_tagging_resources(
432+
tag_filters=tag_filters, resource_type_filters=resource_type_filters
433+
)
434+
435+
return self._convert_tag_collection_response(tag_collections)
436+
437+
collection_filter = [
438+
{
439+
"Name": "resource-type",
440+
"Values": ["AWS::ResourceGroups::Group"],
441+
},
442+
]
443+
list_group_response = self._get_full_list_resource(
444+
collection_name=collection_name, collection_filter=collection_filter
445+
)
446+
collection_content = self._convert_group_resource_response(list_group_response)
447+
448+
collection_filter = [
449+
{
450+
"Name": "resource-type",
451+
"Values": ["AWS::SageMaker::ModelPackageGroup"],
452+
},
453+
]
454+
list_group_response = self._get_full_list_resource(
455+
collection_name=collection_name, collection_filter=collection_filter
456+
)
457+
458+
collection_content = collection_content + self._convert_group_resource_response(
459+
list_group_response, True
460+
)
461+
462+
return collection_content

0 commit comments

Comments
 (0)