@@ -56,6 +56,41 @@ def _check_access_error(self, err: ClientError):
56
56
"https://docs.aws.amazon.com/sagemaker/latest/dg/modelcollections-permissions.html"
57
57
)
58
58
59
+ def _add_model_group (self , model_package_group , tag_rule_key , tag_rule_value ):
60
+ """To add a model package group to a collection
61
+
62
+ Args:
63
+ model_package_group (str): The name of the model package group
64
+ tag_rule_key (str): The tag key of the corresponing collection to be added into
65
+ tag_rule_value (str): The tag value of the corresponing collection to be added into
66
+ """
67
+ model_group_details = self .sagemaker_session .sagemaker_client .describe_model_package_group (
68
+ ModelPackageGroupName = model_package_group
69
+ )
70
+ self .sagemaker_session .sagemaker_client .add_tags (
71
+ ResourceArn = model_group_details ["ModelPackageGroupArn" ],
72
+ Tags = [
73
+ {
74
+ "Key" : tag_rule_key ,
75
+ "Value" : tag_rule_value ,
76
+ }
77
+ ],
78
+ )
79
+
80
+ def _remove_model_group (self , model_package_group , tag_rule_key ):
81
+ """To remove a model package group from a collection
82
+
83
+ Args:
84
+ model_package_group (str): The name of the model package group
85
+ tag_rule_key (str): The tag key of the corresponing collection to be removed from
86
+ """
87
+ model_group_details = self .sagemaker_session .sagemaker_client .describe_model_package_group (
88
+ ModelPackageGroupName = model_package_group
89
+ )
90
+ self .sagemaker_session .sagemaker_client .delete_tags (
91
+ ResourceArn = model_group_details ["ModelPackageGroupArn" ], TagKeys = [tag_rule_key ]
92
+ )
93
+
59
94
def create (self , collection_name : str , parent_collection_name : str = None ):
60
95
"""Creates a collection
61
96
@@ -65,38 +100,22 @@ def create(self, collection_name: str, parent_collection_name: str = None):
65
100
To be None if the collection is to be created on the root level
66
101
"""
67
102
68
- tag_rule_key = f"sagemaker:collection-path:{ time .time ()} "
103
+ tag_rule_key = f"sagemaker:collection-path:{ int ( time .time () * 1000 )} "
69
104
tags_on_collection = {
70
105
"sagemaker:collection" : "true" ,
71
106
"sagemaker:collection-path:root" : "true" ,
72
107
}
73
108
tag_rule_values = [collection_name ]
74
109
75
110
if parent_collection_name is not None :
76
- try :
77
- group_query = self .sagemaker_session .get_resource_group_query (
78
- group = parent_collection_name
79
- )
80
- except ClientError as e :
81
- error_code = e .response ["Error" ]["Code" ]
82
-
83
- if error_code == "NotFoundException" :
84
- raise ValueError (f"Cannot find collection: { parent_collection_name } " )
85
- self ._check_access_error (err = e )
86
- raise
87
- if group_query .get ("GroupQuery" ):
88
- parent_tag_rule_query = json .loads (
89
- group_query ["GroupQuery" ].get ("ResourceQuery" , {}).get ("Query" , "" )
90
- )
91
- parent_tag_rule = parent_tag_rule_query .get ("TagFilters" , [])[0 ]
92
- if not parent_tag_rule :
93
- raise "Invalid parent_collection_name"
94
- parent_tag_value = parent_tag_rule ["Values" ][0 ]
95
- tags_on_collection = {
96
- parent_tag_rule ["Key" ]: parent_tag_value ,
97
- "sagemaker:collection" : "true" ,
98
- }
99
- tag_rule_values = [f"{ parent_tag_value } /{ collection_name } " ]
111
+ parent_tag_rules = self ._get_collection_tag_rule (collection_name = parent_collection_name )
112
+ parent_tag_rule_key = parent_tag_rules ["tag_rule_key" ]
113
+ parent_tag_value = parent_tag_rules ["tag_rule_value" ]
114
+ tags_on_collection = {
115
+ parent_tag_rule_key : parent_tag_value ,
116
+ "sagemaker:collection" : "true" ,
117
+ }
118
+ tag_rule_values = [f"{ parent_tag_value } /{ collection_name } " ]
100
119
try :
101
120
resource_filters = [
102
121
"AWS::SageMaker::ModelPackageGroup" ,
@@ -122,7 +141,6 @@ def create(self, collection_name: str, parent_collection_name: str = None):
122
141
"Name" : collection_create_response ["Group" ]["Name" ],
123
142
"Arn" : collection_create_response ["Group" ]["GroupArn" ],
124
143
}
125
-
126
144
except ClientError as e :
127
145
message = e .response ["Error" ]["Message" ]
128
146
error_code = e .response ["Error" ]["Code" ]
@@ -134,7 +152,7 @@ def create(self, collection_name: str, parent_collection_name: str = None):
134
152
raise
135
153
136
154
def delete (self , collections : List [str ]):
137
- """Deletes a lits of collection
155
+ """Deletes a list of collection.
138
156
139
157
Args:
140
158
collections (List[str]): List of collections to be deleted
@@ -152,6 +170,8 @@ def delete(self, collections: List[str]):
152
170
"Values" : ["AWS::ResourceGroups::Group" , "AWS::SageMaker::ModelPackageGroup" ],
153
171
},
154
172
]
173
+
174
+ # loops over the list of collection and deletes one at a time.
155
175
for collection in collections :
156
176
try :
157
177
collection_details = self .sagemaker_session .list_group_resources (
@@ -180,3 +200,140 @@ def delete(self, collections: List[str]):
180
200
"deleted_collections" : deleted_collection ,
181
201
"delete_collection_failures" : delete_collection_failures ,
182
202
}
203
+
204
+ def _get_collection_tag_rule (self , collection_name : str ):
205
+ """Returns the tag rule key and value for a collection"""
206
+
207
+ if collection_name is not None :
208
+ try :
209
+ group_query = self .sagemaker_session .get_resource_group_query (group = collection_name )
210
+ except ClientError as e :
211
+ error_code = e .response ["Error" ]["Code" ]
212
+
213
+ if error_code == "NotFoundException" :
214
+ raise ValueError (f"Cannot find collection: { collection_name } " )
215
+ self ._check_access_error (err = e )
216
+ raise
217
+ if group_query .get ("GroupQuery" ):
218
+ tag_rule_query = json .loads (
219
+ group_query ["GroupQuery" ].get ("ResourceQuery" , {}).get ("Query" , "" )
220
+ )
221
+ tag_rule = tag_rule_query .get ("TagFilters" , [])[0 ]
222
+ if not tag_rule :
223
+ raise "Unsupported parent_collection_name"
224
+ tag_rule_value = tag_rule ["Values" ][0 ]
225
+ tag_rule_key = tag_rule ["Key" ]
226
+
227
+ return {
228
+ "tag_rule_key" : tag_rule_key ,
229
+ "tag_rule_value" : tag_rule_value ,
230
+ }
231
+ raise ValueError ("Collection name is required" )
232
+
233
+ def add_model_groups (self , collection_name : str , model_groups : List [str ]):
234
+ """To add list of model package groups to a collection
235
+
236
+ Args:
237
+ collection_name (str): The name of the collection
238
+ model_groups List[str]: Model pckage group names list to be added into the collection
239
+ """
240
+ if len (model_groups ) > 10 :
241
+ raise Exception ("Model groups can have a maximum length of 10" )
242
+ tag_rules = self ._get_collection_tag_rule (collection_name = collection_name )
243
+ tag_rule_key = tag_rules ["tag_rule_key" ]
244
+ tag_rule_value = tag_rules ["tag_rule_value" ]
245
+
246
+ add_groups_success = []
247
+ add_groups_failure = []
248
+ if tag_rule_key is not None and tag_rule_value is not None :
249
+ for model_group in model_groups :
250
+ try :
251
+ self ._add_model_group (
252
+ model_package_group = model_group ,
253
+ tag_rule_key = tag_rule_key ,
254
+ tag_rule_value = tag_rule_value ,
255
+ )
256
+ add_groups_success .append (model_group )
257
+ except ClientError as e :
258
+ self ._check_access_error (err = e )
259
+ message = e .response ["Error" ]["Message" ]
260
+ add_groups_failure .append (
261
+ {
262
+ "model_group" : model_group ,
263
+ "failure_reason" : message ,
264
+ }
265
+ )
266
+ return {
267
+ "added_groups" : add_groups_success ,
268
+ "failure" : add_groups_failure ,
269
+ }
270
+
271
+ def remove_model_groups (self , collection_name : str , model_groups : List [str ]):
272
+ """To remove list of model package groups from a collection
273
+
274
+ Args:
275
+ collection_name (str): The name of the collection
276
+ model_groups List[str]: Model package group names list to be removed
277
+ """
278
+
279
+ if len (model_groups ) > 10 :
280
+ raise Exception ("Model groups can have a maximum length of 10" )
281
+ tag_rules = self ._get_collection_tag_rule (collection_name = collection_name )
282
+
283
+ tag_rule_key = tag_rules ["tag_rule_key" ]
284
+ tag_rule_value = tag_rules ["tag_rule_value" ]
285
+
286
+ remove_groups_success = []
287
+ remove_groups_failure = []
288
+ if tag_rule_key is not None and tag_rule_value is not None :
289
+ for model_group in model_groups :
290
+ try :
291
+ self ._remove_model_group (
292
+ model_package_group = model_group ,
293
+ tag_rule_key = tag_rule_key ,
294
+ )
295
+ remove_groups_success .append (model_group )
296
+ except ClientError as e :
297
+ self ._check_access_error (err = e )
298
+ message = e .response ["Error" ]["Message" ]
299
+ remove_groups_failure .append (
300
+ {
301
+ "model_group" : model_group ,
302
+ "failure_reason" : message ,
303
+ }
304
+ )
305
+ return {
306
+ "removed_groups" : remove_groups_success ,
307
+ "failure" : remove_groups_failure ,
308
+ }
309
+
310
+ def move_model_group (
311
+ self , source_collection_name : str , model_group : str , destination_collection_name : str
312
+ ):
313
+ """To move a model package group from one collection to another
314
+
315
+ Args:
316
+ source_collection_name (str): Collection name of the source
317
+ model_group (str): Model package group names which is to be moved
318
+ destination_collection_name (str): Collection name of the destination
319
+ """
320
+ remove_details = self .remove_model_groups (
321
+ collection_name = source_collection_name , model_groups = [model_group ]
322
+ )
323
+ if len (remove_details ["failure" ]) == 1 :
324
+ raise Exception (remove_details ["failure" ][0 ]["failure" ])
325
+
326
+ added_details = self .add_model_groups (
327
+ collection_name = destination_collection_name , model_groups = [model_group ]
328
+ )
329
+
330
+ if len (added_details ["failure" ]) == 1 :
331
+ # adding the model group back to the source collection in case of an add failure
332
+ self .add_model_groups (
333
+ collection_name = source_collection_name , model_groups = [model_group ]
334
+ )
335
+ raise Exception (added_details ["failure" ][0 ]["failure" ])
336
+
337
+ return {
338
+ "moved_success" : model_group ,
339
+ }
0 commit comments