@@ -147,7 +147,6 @@ def create(self, collection_name: str, parent_collection_name: str = None):
147
147
148
148
if error_code == "BadRequestException" and "group already exists" in message :
149
149
raise ValueError ("Collection with the given name already exists" )
150
-
151
150
self ._check_access_error (err = e )
152
151
raise
153
152
@@ -337,3 +336,127 @@ def move_model_group(
337
336
return {
338
337
"moved_success" : model_group ,
339
338
}
339
+
340
+ def _convert_tag_collection_response (self , tag_collections : List [str ]):
341
+ """Converts collection response from tag api to collection list response
342
+
343
+ Args:
344
+ tag_collections List[dict]: Collections list response from tag api
345
+ """
346
+ collection_details = []
347
+ for collection in tag_collections :
348
+ collection_arn = collection ["ResourceARN" ]
349
+ collection_name = collection_arn .split ("group/" )[1 ]
350
+ collection_details .append (
351
+ {
352
+ "Name" : collection_name ,
353
+ "Arn" : collection_arn ,
354
+ "Type" : "Collection" ,
355
+ }
356
+ )
357
+ return collection_details
358
+
359
+ def _convert_group_resource_response (
360
+ self , group_resource_details : List [dict ], is_model_group : bool = False
361
+ ):
362
+ """Converts collection response from resource group api to collection list response
363
+
364
+ Args:
365
+ group_resource_details (List[dict]): Collections list response from resource group api
366
+ is_model_group (bool): If the reponse is of collection or model group type
367
+ """
368
+ collection_details = []
369
+ if group_resource_details ["Resources" ]:
370
+ for resource_group in group_resource_details ["Resources" ]:
371
+ collection_arn = resource_group ["Identifier" ]["ResourceArn" ]
372
+ collection_name = collection_arn .split ("group/" )[1 ]
373
+ collection_details .append (
374
+ {
375
+ "Name" : collection_name ,
376
+ "Arn" : collection_arn ,
377
+ "Type" : resource_group ["Identifier" ]["ResourceType" ]
378
+ if is_model_group
379
+ else "Collection" ,
380
+ }
381
+ )
382
+ return collection_details
383
+
384
+ def _get_full_list_resource (self , collection_name , collection_filter ):
385
+ """Iterating to the full resource group list and returns appended paginated response
386
+
387
+ Args:
388
+ collection_name (str): Name of the collection to get the details
389
+ collection_filter (dict): Filter details to be passed to get the resource list
390
+
391
+ """
392
+ list_group_response = self .sagemaker_session .list_group_resources (
393
+ group = collection_name , filters = collection_filter
394
+ )
395
+ next_token = list_group_response .get ("NextToken" )
396
+ while next_token is not None :
397
+
398
+ paginated_group_response = self .sagemaker_session .list_group_resources (
399
+ group = collection_name ,
400
+ filters = collection_filter ,
401
+ next_token = next_token ,
402
+ )
403
+ list_group_response ["Resources" ] = (
404
+ list_group_response ["Resources" ] + paginated_group_response ["Resources" ]
405
+ )
406
+ list_group_response ["ResourceIdentifiers" ] = (
407
+ list_group_response ["ResourceIdentifiers" ]
408
+ + paginated_group_response ["ResourceIdentifiers" ]
409
+ )
410
+ next_token = paginated_group_response .get ("NextToken" )
411
+
412
+ return list_group_response
413
+
414
+ def list_collection (self , collection_name : str = None ):
415
+ """To all list the collections and content of the collections
416
+
417
+ In case there is no collection_name, it lists all the collections on the root level
418
+
419
+ Args:
420
+ collection_name (str): The name of the collection to list the contents of
421
+ """
422
+ collection_content = []
423
+ if collection_name is None :
424
+ tag_filters = [
425
+ {
426
+ "Key" : "sagemaker:collection-path:root" ,
427
+ "Values" : ["true" ],
428
+ },
429
+ ]
430
+ resource_type_filters = ["resource-groups:group" ]
431
+ tag_collections = self .sagemaker_session .get_tagging_resources (
432
+ tag_filters = tag_filters , resource_type_filters = resource_type_filters
433
+ )
434
+
435
+ return self ._convert_tag_collection_response (tag_collections )
436
+
437
+ collection_filter = [
438
+ {
439
+ "Name" : "resource-type" ,
440
+ "Values" : ["AWS::ResourceGroups::Group" ],
441
+ },
442
+ ]
443
+ list_group_response = self ._get_full_list_resource (
444
+ collection_name = collection_name , collection_filter = collection_filter
445
+ )
446
+ collection_content = self ._convert_group_resource_response (list_group_response )
447
+
448
+ collection_filter = [
449
+ {
450
+ "Name" : "resource-type" ,
451
+ "Values" : ["AWS::SageMaker::ModelPackageGroup" ],
452
+ },
453
+ ]
454
+ list_group_response = self ._get_full_list_resource (
455
+ collection_name = collection_name , collection_filter = collection_filter
456
+ )
457
+
458
+ collection_content = collection_content + self ._convert_group_resource_response (
459
+ list_group_response , True
460
+ )
461
+
462
+ return collection_content
0 commit comments