Skip to content

Commit 98c7af0

Browse files
author
Keshav Chandak
committed
Feat: list model collection
1 parent bc0e6a4 commit 98c7af0

File tree

3 files changed

+192
-3
lines changed

3 files changed

+192
-3
lines changed

src/sagemaker/collection.py

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ def create(self, collection_name: str, parent_collection_name: str = None):
147147

148148
if error_code == "BadRequestException" and "group already exists" in message:
149149
raise ValueError("Collection with the given name already exists")
150-
151150
self._check_access_error(err=e)
152151
raise
153152

@@ -337,3 +336,127 @@ def move_model_group(
337336
return {
338337
"moved_success": model_group,
339338
}
339+
340+
def _convert_tag_collection_response(self, tag_collections: List[str]):
341+
"""Converts collection response from tag api to collection list response
342+
343+
Args:
344+
tag_collections List[dict]: Collections list response from tag api
345+
"""
346+
collection_details = []
347+
for collection in tag_collections:
348+
collection_arn = collection["ResourceARN"]
349+
collection_name = collection_arn.split("group/")[1]
350+
collection_details.append(
351+
{
352+
"Name": collection_name,
353+
"Arn": collection_arn,
354+
"Type": "Collection",
355+
}
356+
)
357+
return collection_details
358+
359+
def _convert_group_resource_response(
360+
self, group_resource_details: List[dict], is_model_group: bool = False
361+
):
362+
"""Converts collection response from resource group api to collection list response
363+
364+
Args:
365+
group_resource_details (List[dict]): Collections list response from resource group api
366+
is_model_group (bool): If the reponse is of collection or model group type
367+
"""
368+
collection_details = []
369+
if group_resource_details["Resources"]:
370+
for resource_group in group_resource_details["Resources"]:
371+
collection_arn = resource_group["Identifier"]["ResourceArn"]
372+
collection_name = collection_arn.split("group/")[1]
373+
collection_details.append(
374+
{
375+
"Name": collection_name,
376+
"Arn": collection_arn,
377+
"Type": resource_group["Identifier"]["ResourceType"]
378+
if is_model_group
379+
else "Collection",
380+
}
381+
)
382+
return collection_details
383+
384+
def _get_full_list_resource(self, collection_name, collection_filter):
385+
"""Iterating to the full resource group list and returns appended paginated response
386+
387+
Args:
388+
collection_name (str): Name of the collection to get the details
389+
collection_filter (dict): Filter details to be passed to get the resource list
390+
391+
"""
392+
list_group_response = self.sagemaker_session.list_group_resources(
393+
group=collection_name, filters=collection_filter
394+
)
395+
next_token = list_group_response.get("NextToken")
396+
while next_token is not None:
397+
398+
paginated_group_response = self.sagemaker_session.list_group_resources(
399+
group=collection_name,
400+
filters=collection_filter,
401+
next_token=next_token,
402+
)
403+
list_group_response["Resources"] = (
404+
list_group_response["Resources"] + paginated_group_response["Resources"]
405+
)
406+
list_group_response["ResourceIdentifiers"] = (
407+
list_group_response["ResourceIdentifiers"]
408+
+ paginated_group_response["ResourceIdentifiers"]
409+
)
410+
next_token = paginated_group_response.get("NextToken")
411+
412+
return list_group_response
413+
414+
def list_collection(self, collection_name: str = None):
415+
"""To all list the collections and content of the collections
416+
417+
In case there is no collection_name, it lists all the collections on the root level
418+
419+
Args:
420+
collection_name (str): The name of the collection to list the contents of
421+
"""
422+
collection_content = []
423+
if collection_name is None:
424+
tag_filters = [
425+
{
426+
"Key": "sagemaker:collection-path:root",
427+
"Values": ["true"],
428+
},
429+
]
430+
resource_type_filters = ["resource-groups:group"]
431+
tag_collections = self.sagemaker_session.get_tagging_resources(
432+
tag_filters=tag_filters, resource_type_filters=resource_type_filters
433+
)
434+
435+
return self._convert_tag_collection_response(tag_collections)
436+
437+
collection_filter = [
438+
{
439+
"Name": "resource-type",
440+
"Values": ["AWS::ResourceGroups::Group"],
441+
},
442+
]
443+
list_group_response = self._get_full_list_resource(
444+
collection_name=collection_name, collection_filter=collection_filter
445+
)
446+
collection_content = self._convert_group_resource_response(list_group_response)
447+
448+
collection_filter = [
449+
{
450+
"Name": "resource-type",
451+
"Values": ["AWS::SageMaker::ModelPackageGroup"],
452+
},
453+
]
454+
list_group_response = self._get_full_list_resource(
455+
collection_name=collection_name, collection_filter=collection_filter
456+
)
457+
458+
collection_content = collection_content + self._convert_group_resource_response(
459+
list_group_response, True
460+
)
461+
462+
return collection_content

src/sagemaker/session.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def __init__(
205205
self.s3_resource = None
206206
self.s3_client = None
207207
self.resource_groups_client = None
208+
self.resource_group_tagging_client = None
208209
self.config = None
209210
self.lambda_client = None
210211
self.settings = settings
@@ -3962,7 +3963,7 @@ def delete_model(self, model_name):
39623963
LOGGER.info("Deleting model with name: %s", model_name)
39633964
self.sagemaker_client.delete_model(ModelName=model_name)
39643965

3965-
def list_group_resources(self, group, filters):
3966+
def list_group_resources(self, group, filters, next_token: str = ""):
39663967
"""To list group resources with given filters
39673968
39683969
Args:
@@ -3972,7 +3973,9 @@ def list_group_resources(self, group, filters):
39723973
self.resource_groups_client = self.resource_groups_client or self.boto_session.client(
39733974
"resource-groups"
39743975
)
3975-
return self.resource_groups_client.list_group_resources(Group=group, Filters=filters)
3976+
return self.resource_groups_client.list_group_resources(
3977+
Group=group, Filters=filters, NextToken=next_token, MaxResults=3
3978+
)
39763979

39773980
def delete_resource_group(self, group):
39783981
"""To delete a resource group
@@ -3996,6 +3999,39 @@ def get_resource_group_query(self, group):
39963999
)
39974000
return self.resource_groups_client.get_group_query(Group=group)
39984001

4002+
def get_tagging_resources(self, tag_filters, resource_type_filters):
4003+
"""To list the complete resources for a particular resource group tag
4004+
4005+
tag_filters: filters for the tag
4006+
resource_type_filters: resource filter for the tag
4007+
"""
4008+
self.resource_group_tagging_client = (
4009+
self.resource_group_tagging_client
4010+
or self.boto_session.client("resourcegroupstaggingapi")
4011+
)
4012+
resource_list = []
4013+
4014+
try:
4015+
resource_tag_response = self.resource_group_tagging_client.get_resources(
4016+
TagFilters=tag_filters, ResourceTypeFilters=resource_type_filters
4017+
)
4018+
4019+
resource_list = resource_list + resource_tag_response["ResourceTagMappingList"]
4020+
4021+
next_token = resource_tag_response.get("PaginationToken")
4022+
while next_token is not None and next_token != "":
4023+
resource_tag_response = self.resource_group_tagging_client.get_resources(
4024+
TagFilters=tag_filters,
4025+
ResourceTypeFilters=resource_type_filters,
4026+
NextToken=next_token,
4027+
)
4028+
resource_list = resource_list + resource_tag_response["ResourceTagMappingList"]
4029+
next_token = resource_tag_response.get("PaginationToken")
4030+
4031+
return resource_list
4032+
except ClientError as error:
4033+
raise error
4034+
39994035
def create_group(self, name, resource_query, tags):
40004036
"""To create a AWS Resource Group
40014037

tests/integ/test_collection.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,3 +166,33 @@ def test_move_model_groups_in_collection_success(sagemaker_session):
166166
sagemaker_session.sagemaker_client.delete_model_package_group(
167167
ModelPackageGroupName=model_group_name
168168
)
169+
170+
171+
def test_list_collection_success(sagemaker_session):
172+
model_group_name = unique_name_from_base("test-model-group")
173+
sagemaker_session.sagemaker_client.create_model_package_group(
174+
ModelPackageGroupName=model_group_name
175+
)
176+
collection = Collection(sagemaker_session)
177+
collection_name = unique_name_from_base("test-collection")
178+
collection.create(collection_name)
179+
model_groups = []
180+
model_groups.append(model_group_name)
181+
collection.add_model_groups(collection_name=collection_name, model_groups=model_groups)
182+
child_collection_name = unique_name_from_base("test-collection")
183+
collection.create(parent_collection_name=collection_name, collection_name=child_collection_name)
184+
root_collections = collection.list_collection()
185+
is_collection_found = False
186+
for root_collection in root_collections:
187+
if root_collection["Name"] == collection_name:
188+
is_collection_found = True
189+
assert is_collection_found
190+
191+
collection_content = collection.list_collection(collection_name)
192+
assert len(collection_content) == 2
193+
194+
collection.remove_model_groups(collection_name=collection_name, model_groups=model_groups)
195+
collection.delete([child_collection_name, collection_name])
196+
sagemaker_session.sagemaker_client.delete_model_package_group(
197+
ModelPackageGroupName=model_group_name
198+
)

0 commit comments

Comments
 (0)