@@ -56,6 +56,41 @@ def _check_access_error(self, err: ClientError):
56
56
"https://docs.aws.amazon.com/sagemaker/latest/dg/modelcollections-permissions.html"
57
57
)
58
58
59
+ def _add_model_group (self , model_package_group , tag_rule_key , tag_rule_value ):
60
+ """To add a model package group to a collection
61
+
62
+ Args:
63
+ model_package_group (str): The name of the model package group
64
+ tag_rule_key (str): The tag key of the corresponing collection to be added into
65
+ tag_rule_value (str): The tag value of the corresponing collection to be added into
66
+ """
67
+ model_group_details = self .sagemaker_session .sagemaker_client .describe_model_package_group (
68
+ ModelPackageGroupName = model_package_group
69
+ )
70
+ self .sagemaker_session .sagemaker_client .add_tags (
71
+ ResourceArn = model_group_details ["ModelPackageGroupArn" ],
72
+ Tags = [
73
+ {
74
+ "Key" : tag_rule_key ,
75
+ "Value" : tag_rule_value ,
76
+ }
77
+ ],
78
+ )
79
+
80
+ def _remove_model_group (self , model_package_group , tag_rule_key ):
81
+ """To remove a model package group from a collection
82
+
83
+ Args:
84
+ model_package_group (str): The name of the model package group
85
+ tag_rule_key (str): The tag key of the corresponing collection to be removed from
86
+ """
87
+ model_group_details = self .sagemaker_session .sagemaker_client .describe_model_package_group (
88
+ ModelPackageGroupName = model_package_group
89
+ )
90
+ self .sagemaker_session .sagemaker_client .delete_tags (
91
+ ResourceArn = model_group_details ["ModelPackageGroupArn" ], TagKeys = [tag_rule_key ]
92
+ )
93
+
59
94
def create (self , collection_name : str , parent_collection_name : str = None ):
60
95
"""Creates a collection
61
96
@@ -65,38 +100,22 @@ def create(self, collection_name: str, parent_collection_name: str = None):
65
100
To be None if the collection is to be created on the root level
66
101
"""
67
102
68
- tag_rule_key = f"sagemaker:collection-path:{ time .time ()} "
103
+ tag_rule_key = f"sagemaker:collection-path:{ int ( time .time () * 1000 )} "
69
104
tags_on_collection = {
70
105
"sagemaker:collection" : "true" ,
71
106
"sagemaker:collection-path:root" : "true" ,
72
107
}
73
108
tag_rule_values = [collection_name ]
74
109
75
110
if parent_collection_name is not None :
76
- try :
77
- group_query = self .sagemaker_session .get_resource_group_query (
78
- group = parent_collection_name
79
- )
80
- except ClientError as e :
81
- error_code = e .response ["Error" ]["Code" ]
82
-
83
- if error_code == "NotFoundException" :
84
- raise ValueError (f"Cannot find collection: { parent_collection_name } " )
85
- self ._check_access_error (err = e )
86
- raise
87
- if group_query .get ("GroupQuery" ):
88
- parent_tag_rule_query = json .loads (
89
- group_query ["GroupQuery" ].get ("ResourceQuery" , {}).get ("Query" , "" )
90
- )
91
- parent_tag_rule = parent_tag_rule_query .get ("TagFilters" , [])[0 ]
92
- if not parent_tag_rule :
93
- raise "Invalid parent_collection_name"
94
- parent_tag_value = parent_tag_rule ["Values" ][0 ]
95
- tags_on_collection = {
96
- parent_tag_rule ["Key" ]: parent_tag_value ,
97
- "sagemaker:collection" : "true" ,
98
- }
99
- tag_rule_values = [f"{ parent_tag_value } /{ collection_name } " ]
111
+ parent_tag_rules = self ._get_collection_tag_rule (collection_name = parent_collection_name )
112
+ parent_tag_rule_key = parent_tag_rules ["tag_rule_key" ]
113
+ parent_tag_value = parent_tag_rules ["tag_rule_value" ]
114
+ tags_on_collection = {
115
+ parent_tag_rule_key : parent_tag_value ,
116
+ "sagemaker:collection" : "true" ,
117
+ }
118
+ tag_rule_values = [f"{ parent_tag_value } /{ collection_name } " ]
100
119
try :
101
120
resource_filters = [
102
121
"AWS::SageMaker::ModelPackageGroup" ,
@@ -122,19 +141,17 @@ def create(self, collection_name: str, parent_collection_name: str = None):
122
141
"Name" : collection_create_response ["Group" ]["Name" ],
123
142
"Arn" : collection_create_response ["Group" ]["GroupArn" ],
124
143
}
125
-
126
144
except ClientError as e :
127
145
message = e .response ["Error" ]["Message" ]
128
146
error_code = e .response ["Error" ]["Code" ]
129
147
130
148
if error_code == "BadRequestException" and "group already exists" in message :
131
149
raise ValueError ("Collection with the given name already exists" )
132
-
133
150
self ._check_access_error (err = e )
134
151
raise
135
152
136
153
def delete (self , collections : List [str ]):
137
- """Deletes a lits of collection
154
+ """Deletes a list of collection.
138
155
139
156
Args:
140
157
collections (List[str]): List of collections to be deleted
@@ -152,6 +169,8 @@ def delete(self, collections: List[str]):
152
169
"Values" : ["AWS::ResourceGroups::Group" , "AWS::SageMaker::ModelPackageGroup" ],
153
170
},
154
171
]
172
+
173
+ # loops over the list of collection and deletes one at a time.
155
174
for collection in collections :
156
175
try :
157
176
collection_details = self .sagemaker_session .list_group_resources (
@@ -180,3 +199,264 @@ def delete(self, collections: List[str]):
180
199
"deleted_collections" : deleted_collection ,
181
200
"delete_collection_failures" : delete_collection_failures ,
182
201
}
202
+
203
+ def _get_collection_tag_rule (self , collection_name : str ):
204
+ """Returns the tag rule key and value for a collection"""
205
+
206
+ if collection_name is not None :
207
+ try :
208
+ group_query = self .sagemaker_session .get_resource_group_query (group = collection_name )
209
+ except ClientError as e :
210
+ error_code = e .response ["Error" ]["Code" ]
211
+
212
+ if error_code == "NotFoundException" :
213
+ raise ValueError (f"Cannot find collection: { collection_name } " )
214
+ self ._check_access_error (err = e )
215
+ raise
216
+ if group_query .get ("GroupQuery" ):
217
+ tag_rule_query = json .loads (
218
+ group_query ["GroupQuery" ].get ("ResourceQuery" , {}).get ("Query" , "" )
219
+ )
220
+ tag_rule = tag_rule_query .get ("TagFilters" , [])[0 ]
221
+ if not tag_rule :
222
+ raise "Unsupported parent_collection_name"
223
+ tag_rule_value = tag_rule ["Values" ][0 ]
224
+ tag_rule_key = tag_rule ["Key" ]
225
+
226
+ return {
227
+ "tag_rule_key" : tag_rule_key ,
228
+ "tag_rule_value" : tag_rule_value ,
229
+ }
230
+ raise ValueError ("Collection name is required" )
231
+
232
+ def add_model_groups (self , collection_name : str , model_groups : List [str ]):
233
+ """To add list of model package groups to a collection
234
+
235
+ Args:
236
+ collection_name (str): The name of the collection
237
+ model_groups List[str]: Model pckage group names list to be added into the collection
238
+ """
239
+ if len (model_groups ) > 10 :
240
+ raise Exception ("Model groups can have a maximum length of 10" )
241
+ tag_rules = self ._get_collection_tag_rule (collection_name = collection_name )
242
+ tag_rule_key = tag_rules ["tag_rule_key" ]
243
+ tag_rule_value = tag_rules ["tag_rule_value" ]
244
+
245
+ add_groups_success = []
246
+ add_groups_failure = []
247
+ if tag_rule_key is not None and tag_rule_value is not None :
248
+ for model_group in model_groups :
249
+ try :
250
+ self ._add_model_group (
251
+ model_package_group = model_group ,
252
+ tag_rule_key = tag_rule_key ,
253
+ tag_rule_value = tag_rule_value ,
254
+ )
255
+ add_groups_success .append (model_group )
256
+ except ClientError as e :
257
+ self ._check_access_error (err = e )
258
+ message = e .response ["Error" ]["Message" ]
259
+ add_groups_failure .append (
260
+ {
261
+ "model_group" : model_group ,
262
+ "failure_reason" : message ,
263
+ }
264
+ )
265
+ return {
266
+ "added_groups" : add_groups_success ,
267
+ "failure" : add_groups_failure ,
268
+ }
269
+
270
+ def remove_model_groups (self , collection_name : str , model_groups : List [str ]):
271
+ """To remove list of model package groups from a collection
272
+
273
+ Args:
274
+ collection_name (str): The name of the collection
275
+ model_groups List[str]: Model package group names list to be removed
276
+ """
277
+
278
+ if len (model_groups ) > 10 :
279
+ raise Exception ("Model groups can have a maximum length of 10" )
280
+ tag_rules = self ._get_collection_tag_rule (collection_name = collection_name )
281
+
282
+ tag_rule_key = tag_rules ["tag_rule_key" ]
283
+ tag_rule_value = tag_rules ["tag_rule_value" ]
284
+
285
+ remove_groups_success = []
286
+ remove_groups_failure = []
287
+ if tag_rule_key is not None and tag_rule_value is not None :
288
+ for model_group in model_groups :
289
+ try :
290
+ self ._remove_model_group (
291
+ model_package_group = model_group ,
292
+ tag_rule_key = tag_rule_key ,
293
+ )
294
+ remove_groups_success .append (model_group )
295
+ except ClientError as e :
296
+ self ._check_access_error (err = e )
297
+ message = e .response ["Error" ]["Message" ]
298
+ remove_groups_failure .append (
299
+ {
300
+ "model_group" : model_group ,
301
+ "failure_reason" : message ,
302
+ }
303
+ )
304
+ return {
305
+ "removed_groups" : remove_groups_success ,
306
+ "failure" : remove_groups_failure ,
307
+ }
308
+
309
+ def move_model_group (
310
+ self , source_collection_name : str , model_group : str , destination_collection_name : str
311
+ ):
312
+ """To move a model package group from one collection to another
313
+
314
+ Args:
315
+ source_collection_name (str): Collection name of the source
316
+ model_group (str): Model package group names which is to be moved
317
+ destination_collection_name (str): Collection name of the destination
318
+ """
319
+ remove_details = self .remove_model_groups (
320
+ collection_name = source_collection_name , model_groups = [model_group ]
321
+ )
322
+ if len (remove_details ["failure" ]) == 1 :
323
+ raise Exception (remove_details ["failure" ][0 ]["failure" ])
324
+
325
+ added_details = self .add_model_groups (
326
+ collection_name = destination_collection_name , model_groups = [model_group ]
327
+ )
328
+
329
+ if len (added_details ["failure" ]) == 1 :
330
+ # adding the model group back to the source collection in case of an add failure
331
+ self .add_model_groups (
332
+ collection_name = source_collection_name , model_groups = [model_group ]
333
+ )
334
+ raise Exception (added_details ["failure" ][0 ]["failure" ])
335
+
336
+ return {
337
+ "moved_success" : model_group ,
338
+ }
339
+
340
+ def _convert_tag_collection_response (self , tag_collections : List [str ]):
341
+ """Converts collection response from tag api to collection list response
342
+
343
+ Args:
344
+ tag_collections List[dict]: Collections list response from tag api
345
+ """
346
+ collection_details = []
347
+ for collection in tag_collections :
348
+ collection_arn = collection ["ResourceARN" ]
349
+ collection_name = collection_arn .split ("group/" )[1 ]
350
+ collection_details .append (
351
+ {
352
+ "Name" : collection_name ,
353
+ "Arn" : collection_arn ,
354
+ "Type" : "Collection" ,
355
+ }
356
+ )
357
+ return collection_details
358
+
359
+ def _convert_group_resource_response (
360
+ self , group_resource_details : List [dict ], is_model_group : bool = False
361
+ ):
362
+ """Converts collection response from resource group api to collection list response
363
+
364
+ Args:
365
+ group_resource_details (List[dict]): Collections list response from resource group api
366
+ is_model_group (bool): If the reponse is of collection or model group type
367
+ """
368
+ collection_details = []
369
+ if group_resource_details ["Resources" ]:
370
+ for resource_group in group_resource_details ["Resources" ]:
371
+ collection_arn = resource_group ["Identifier" ]["ResourceArn" ]
372
+ collection_name = collection_arn .split ("group/" )[1 ]
373
+ collection_details .append (
374
+ {
375
+ "Name" : collection_name ,
376
+ "Arn" : collection_arn ,
377
+ "Type" : resource_group ["Identifier" ]["ResourceType" ]
378
+ if is_model_group
379
+ else "Collection" ,
380
+ }
381
+ )
382
+ return collection_details
383
+
384
+ def _get_full_list_resource (self , collection_name , collection_filter ):
385
+ """Iterating to the full resource group list and returns appended paginated response
386
+
387
+ Args:
388
+ collection_name (str): Name of the collection to get the details
389
+ collection_filter (dict): Filter details to be passed to get the resource list
390
+
391
+ """
392
+ list_group_response = self .sagemaker_session .list_group_resources (
393
+ group = collection_name , filters = collection_filter
394
+ )
395
+ next_token = list_group_response .get ("NextToken" )
396
+ while next_token is not None :
397
+
398
+ paginated_group_response = self .sagemaker_session .list_group_resources (
399
+ group = collection_name ,
400
+ filters = collection_filter ,
401
+ next_token = next_token ,
402
+ )
403
+ list_group_response ["Resources" ] = (
404
+ list_group_response ["Resources" ] + paginated_group_response ["Resources" ]
405
+ )
406
+ list_group_response ["ResourceIdentifiers" ] = (
407
+ list_group_response ["ResourceIdentifiers" ]
408
+ + paginated_group_response ["ResourceIdentifiers" ]
409
+ )
410
+ next_token = paginated_group_response .get ("NextToken" )
411
+
412
+ return list_group_response
413
+
414
+ def list_collection (self , collection_name : str = None ):
415
+ """To all list the collections and content of the collections
416
+
417
+ In case there is no collection_name, it lists all the collections on the root level
418
+
419
+ Args:
420
+ collection_name (str): The name of the collection to list the contents of
421
+ """
422
+ collection_content = []
423
+ if collection_name is None :
424
+ tag_filters = [
425
+ {
426
+ "Key" : "sagemaker:collection-path:root" ,
427
+ "Values" : ["true" ],
428
+ },
429
+ ]
430
+ resource_type_filters = ["resource-groups:group" ]
431
+ tag_collections = self .sagemaker_session .get_tagging_resources (
432
+ tag_filters = tag_filters , resource_type_filters = resource_type_filters
433
+ )
434
+
435
+ return self ._convert_tag_collection_response (tag_collections )
436
+
437
+ collection_filter = [
438
+ {
439
+ "Name" : "resource-type" ,
440
+ "Values" : ["AWS::ResourceGroups::Group" ],
441
+ },
442
+ ]
443
+ list_group_response = self ._get_full_list_resource (
444
+ collection_name = collection_name , collection_filter = collection_filter
445
+ )
446
+ collection_content = self ._convert_group_resource_response (list_group_response )
447
+
448
+ collection_filter = [
449
+ {
450
+ "Name" : "resource-type" ,
451
+ "Values" : ["AWS::SageMaker::ModelPackageGroup" ],
452
+ },
453
+ ]
454
+ list_group_response = self ._get_full_list_resource (
455
+ collection_name = collection_name , collection_filter = collection_filter
456
+ )
457
+
458
+ collection_content = collection_content + self ._convert_group_resource_response (
459
+ list_group_response , True
460
+ )
461
+
462
+ return collection_content
0 commit comments