23
23
import smdebug_rulesconfig as rule_configs # noqa: F401 # pylint: disable=unused-import
24
24
25
25
26
# Name of the ECR repository that holds the debugger rule container images.
RULES_ECR_REPO_NAME = "sagemaker-debugger-rules"

# For each supported AWS region, maps the ECR repository name to the registry
# ID that is placed at the front of the image URI for that region.
SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP = {
    "eu-north-1": {RULES_ECR_REPO_NAME: "314864569078"},
    "me-south-1": {RULES_ECR_REPO_NAME: "986000313247"},
    "ap-south-1": {RULES_ECR_REPO_NAME: "904829902805"},
    "eu-west-3": {RULES_ECR_REPO_NAME: "447278800020"},
    "us-east-2": {RULES_ECR_REPO_NAME: "915447279597"},
    "eu-west-1": {RULES_ECR_REPO_NAME: "929884845733"},
    "eu-central-1": {RULES_ECR_REPO_NAME: "482524230118"},
    "sa-east-1": {RULES_ECR_REPO_NAME: "818342061345"},
    "ap-east-1": {RULES_ECR_REPO_NAME: "199566480951"},
    "us-east-1": {RULES_ECR_REPO_NAME: "503895931360"},
    "ap-northeast-2": {RULES_ECR_REPO_NAME: "578805364391"},
    "eu-west-2": {RULES_ECR_REPO_NAME: "250201462417"},
    "ap-northeast-1": {RULES_ECR_REPO_NAME: "430734990657"},
    "us-west-2": {RULES_ECR_REPO_NAME: "895741380848"},
    "us-west-1": {RULES_ECR_REPO_NAME: "685455198987"},
    "ap-southeast-1": {RULES_ECR_REPO_NAME: "972752614525"},
    "ap-southeast-2": {RULES_ECR_REPO_NAME: "184798709955"},
    "ca-central-1": {RULES_ECR_REPO_NAME: "519511493484"},
}


def get_rule_container_image_uri(region):
    """Return the debugger rule container image URI for the given AWS region.

    Args:
        region (str): AWS Region name, for example 'us-west-2'. It must be a
            key of ``SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP``.

    Returns:
        str: The image URI, formatted as
            '<registry_id>.dkr.ecr.<region>.amazonaws.com/<repo_name>:latest'.

    Raises:
        ValueError: If ``region`` has no entry in
            ``SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP``.
    """
    region_registries = SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP.get(region)
    if region_registries is None:
        # Fail with an actionable message instead of the AttributeError that
        # chaining .get() on a missing (None) entry would produce.
        raise ValueError(
            "Unsupported region for SageMaker rule containers: {!r}. "
            "Supported regions: {}.".format(
                region, ", ".join(sorted(SAGEMAKER_RULE_CONTAINERS_ACCOUNTS_MAP))
            )
        )

    registry_id = region_registries[RULES_ECR_REPO_NAME]
    return "{}.dkr.ecr.{}.amazonaws.com/{}:latest".format(
        registry_id, region, RULES_ECR_REPO_NAME
    )
62
+
63
+
26
64
class Rule (object ):
27
65
"""Rules analyze tensors emitted during the training of a model. They
28
66
monitor conditions that are critical for the success of a training job.
@@ -40,7 +78,7 @@ def __init__(
40
78
name ,
41
79
image_uri ,
42
80
instance_type ,
43
- container_local_path ,
81
+ container_local_output_path ,
44
82
s3_output_path ,
45
83
volume_size_in_gb ,
46
84
rule_parameters ,
@@ -58,7 +96,7 @@ def __init__(
58
96
image_uri (str): The URI of the image to be used by the debugger rule.
59
97
instance_type (str): Type of EC2 instance to use, for example,
60
98
'ml.c4.xlarge'.
61
- container_local_path (str): The path in the container .
99
+ container_local_output_path (str): The local path to store the Rule output .
62
100
s3_output_path (str): The location in S3 to store the output.
63
101
volume_size_in_gb (int): Size in GB of the EBS volume
64
102
to use for storing data.
@@ -68,7 +106,7 @@ def __init__(
68
106
"""
69
107
self .name = name
70
108
self .instance_type = instance_type
71
- self .container_local_path = container_local_path
109
+ self .container_local_output_path = container_local_output_path
72
110
self .s3_output_path = s3_output_path
73
111
self .volume_size_in_gb = volume_size_in_gb
74
112
self .rule_parameters = rule_parameters
@@ -80,10 +118,8 @@ def sagemaker(
80
118
cls ,
81
119
base_config ,
82
120
name = None ,
83
- instance_type = None ,
84
- container_local_path = None ,
121
+ container_local_output_path = None ,
85
122
s3_output_path = None ,
86
- volume_size_in_gb = None ,
87
123
other_trials_s3_input_paths = None ,
88
124
rule_parameters = None ,
89
125
collections_to_save = None ,
@@ -98,13 +134,8 @@ def sagemaker(
98
134
built-in list of rules. For example, 'rule_configs.dead_relu()'.
99
135
name (str): The name of the debugger rule. If one is not provided,
100
136
the name of the base_config will be used.
101
- instance_type (str): Type of EC2 instance to use, for example,
102
- 'ml.c4.xlarge'. If one is not provided, the instance type from
103
- the base_config will be used.
104
- container_local_path (str): The path in the container.
137
+ container_local_output_path (str): The path in the container.
105
138
s3_output_path (str): The location in S3 to store the output.
106
- volume_size_in_gb (int): Size in GB of the EBS volume
107
- to use for storing data.
108
139
other_trials_s3_input_paths ([str]): S3 input paths for other trials.
109
140
rule_parameters (dict): A dictionary of parameters for the rule.
110
141
collections_to_save ([sagemaker.debugger.CollectionConfig]): A list
@@ -113,9 +144,22 @@ def sagemaker(
113
144
Returns:
114
145
sagemaker.debugger.Rule: The instance of the built-in Rule.
115
146
"""
116
- other_trials_params = {}
147
+ merged_rule_params = {}
148
+
149
+ if rule_parameters is not None and rule_parameters .get ("rule_to_invoke" ) is not None :
150
+ raise RuntimeError (
151
+ """You cannot provide a 'rule_to_invoke' for SageMaker rules.
152
+ Please either remove the rule_to_invoke or use a custom rule.
153
+ """
154
+ )
155
+
117
156
if other_trials_s3_input_paths is not None :
118
- other_trials_params ["other_trials_s3_input_paths" ] = other_trials_s3_input_paths
157
+ for index , s3_input_path in enumerate (other_trials_s3_input_paths ):
158
+ merged_rule_params ["other_trial_{}" .format (str (index ))] = s3_input_path
159
+
160
+ default_rule_params = base_config ["DebugRuleConfiguration" ].get ("RuleParameters" , {})
161
+ merged_rule_params .update (default_rule_params )
162
+ merged_rule_params .update (rule_parameters or {})
119
163
120
164
base_config_collections = []
121
165
for config in base_config .get ("CollectionConfigurations" , []):
@@ -133,16 +177,11 @@ def sagemaker(
133
177
return cls (
134
178
name = name or base_config ["DebugRuleConfiguration" ].get ("RuleConfigurationName" ),
135
179
image_uri = "DEFAULT_RULE_EVALUATOR_IMAGE" ,
136
- instance_type = instance_type or "t3.medium" ,
137
- # TODO-reinvent-2019 [akarpur]: Remove t3.medium from line above,
138
- # uncomment line below when 1P package updated
139
- # or base_config["DebugRuleConfiguration"].get("InstanceType"),
140
- container_local_path = container_local_path ,
180
+ instance_type = None ,
181
+ container_local_output_path = container_local_output_path ,
141
182
s3_output_path = s3_output_path ,
142
- volume_size_in_gb = volume_size_in_gb ,
143
- rule_parameters = other_trials_params .update (
144
- rule_parameters or base_config ["DebugRuleConfiguration" ].get ("RuleParameters" , {})
145
- ),
183
+ volume_size_in_gb = None ,
184
+ rule_parameters = merged_rule_params ,
146
185
collections_to_save = collections_to_save or base_config_collections ,
147
186
)
148
187
@@ -154,7 +193,7 @@ def custom(
154
193
instance_type ,
155
194
source = None ,
156
195
rule_to_invoke = None ,
157
- container_local_path = None ,
196
+ container_local_output_path = None ,
158
197
s3_output_path = None ,
159
198
volume_size_in_gb = None ,
160
199
other_trials_s3_input_paths = None ,
@@ -175,7 +214,7 @@ def custom(
175
214
you must also provide rule_to_invoke.
176
215
rule_to_invoke (str): The name of the rule to invoke within the source.
177
216
If provided, you must also provide source.
178
- container_local_path (str): The path in the container.
217
+ container_local_output_path (str): The path in the container.
179
218
s3_output_path (str): The location in S3 to store the output.
180
219
volume_size_in_gb (int): Size in GB of the EBS volume
181
220
to use for storing data.
@@ -192,25 +231,28 @@ def custom(
192
231
"If you provide a source, you must also provide a rule to invoke (and vice versa)."
193
232
)
194
233
195
- source_params = {}
234
+ merged_rule_params = {}
235
+
196
236
if source is not None and rule_to_invoke is not None :
197
- source_params ["source_s3_uri" ] = source
198
- source_params ["rule_to_invoke" ] = rule_to_invoke
237
+ merged_rule_params ["source_s3_uri" ] = source
238
+ merged_rule_params ["rule_to_invoke" ] = rule_to_invoke
199
239
200
240
other_trials_params = {}
201
241
if other_trials_s3_input_paths is not None :
202
- other_trials_params ["other_trials_s3_input_paths" ] = other_trials_s3_input_paths
242
+ for index , s3_input_path in enumerate (other_trials_s3_input_paths ):
243
+ other_trials_params ["other_trial_{}" .format (str (index ))] = s3_input_path
203
244
204
- combined_rule_params = source_params .update (other_trials_params ) or {}
245
+ merged_rule_params .update (other_trials_params )
246
+ merged_rule_params .update (rule_parameters or {})
205
247
206
248
return cls (
207
249
name = name ,
208
250
image_uri = image_uri ,
209
251
instance_type = instance_type ,
210
- container_local_path = container_local_path ,
252
+ container_local_output_path = container_local_output_path ,
211
253
s3_output_path = s3_output_path ,
212
254
volume_size_in_gb = volume_size_in_gb ,
213
- rule_parameters = combined_rule_params . update ( rule_parameters or {}) ,
255
+ rule_parameters = merged_rule_params ,
214
256
collections_to_save = collections_to_save or [],
215
257
)
216
258
@@ -221,21 +263,19 @@ def to_debugger_rule_config_dict(self):
221
263
Returns:
222
264
dict: A portion of an API request as a dictionary.
223
265
"""
224
- if self .instance_type is None or self .volume_size_in_gb is None :
225
- raise RuntimeError (
226
- """Cannot create a dictionary if the instance type and volume size are not provided.
227
- Please set the instance type and volume size for this Rule object."""
228
- )
229
-
230
266
debugger_rule_config_request = {
231
267
"RuleConfigurationName" : self .name ,
232
268
"RuleEvaluatorImage" : self .image_uri ,
233
- "InstanceType" : self .instance_type ,
234
- "VolumeSizeInGB" : self .volume_size_in_gb ,
235
269
}
236
270
237
- if self .container_local_path is not None :
238
- debugger_rule_config_request ["LocalPath" ] = self .container_local_path
271
+ if self .instance_type is not None :
272
+ debugger_rule_config_request ["InstanceType" ] = self .instance_type
273
+
274
+ if self .volume_size_in_gb is not None :
275
+ debugger_rule_config_request ["VolumeSizeInGB" ] = self .volume_size_in_gb
276
+
277
+ if self .container_local_output_path is not None :
278
+ debugger_rule_config_request ["LocalPath" ] = self .container_local_output_path
239
279
240
280
if self .s3_output_path is not None :
241
281
debugger_rule_config_request ["S3OutputPath" ] = self .s3_output_path
@@ -254,7 +294,7 @@ class DebuggerHookConfig(object):
254
294
def __init__ (
255
295
self ,
256
296
s3_output_path ,
257
- container_local_path = None ,
297
+ container_local_output_path = None ,
258
298
hook_parameters = None ,
259
299
collection_configs = None ,
260
300
):
@@ -264,13 +304,13 @@ def __init__(
264
304
265
305
Args:
266
306
s3_output_path (str): The location in S3 to store the output.
267
- container_local_path (str): The path in the container.
307
+ container_local_output_path (str): The path in the container.
268
308
hook_parameters (dict): A dictionary of parameters.
269
309
collection_configs ([sagemaker.debugger.CollectionConfig]): A list
270
310
of CollectionConfig objects to be provided to the API.
271
311
"""
272
312
self .s3_output_path = s3_output_path
273
- self .container_local_path = container_local_path
313
+ self .container_local_output_path = container_local_output_path
274
314
self .hook_parameters = hook_parameters
275
315
self .collection_configs = collection_configs
276
316
@@ -283,8 +323,8 @@ def to_request_dict(self):
283
323
"""
284
324
debugger_hook_config_request = {"S3OutputPath" : self .s3_output_path }
285
325
286
- if self .container_local_path is not None :
287
- debugger_hook_config_request ["LocalPath" ] = self .container_local_path
326
+ if self .container_local_output_path is not None :
327
+ debugger_hook_config_request ["LocalPath" ] = self .container_local_output_path
288
328
289
329
if self .hook_parameters is not None :
290
330
debugger_hook_config_request ["HookParameters" ] = self .hook_parameters
@@ -301,17 +341,17 @@ class TensorBoardOutputConfig(object):
301
341
"""TensorBoardOutputConfig provides options to customize
302
342
debugging visualization using TensorBoard."""
303
343
304
- def __init__ (self , s3_output_path , container_local_path = None ):
344
+ def __init__ (self , s3_output_path , container_local_output_path = None ):
305
345
"""Initialize an instance of TensorBoardOutputConfig.
306
346
TensorBoardOutputConfig provides options to customize
307
347
debugging visualization using TensorBoard.
308
348
309
349
Args:
310
350
s3_output_path (str): The location in S3 to store the output.
311
- container_local_path (str): The path in the container.
351
+ container_local_output_path (str): The path in the container.
312
352
"""
313
353
self .s3_output_path = s3_output_path
314
- self .container_local_path = container_local_path
354
+ self .container_local_output_path = container_local_output_path
315
355
316
356
def to_request_dict (self ):
317
357
"""Generates a request dictionary using the parameters provided
@@ -322,16 +362,16 @@ def to_request_dict(self):
322
362
"""
323
363
tensorboard_output_config_request = {"S3OutputPath" : self .s3_output_path }
324
364
325
- if self .container_local_path is not None :
326
- tensorboard_output_config_request ["LocalPath" ] = self .container_local_path
365
+ if self .container_local_output_path is not None :
366
+ tensorboard_output_config_request ["LocalPath" ] = self .container_local_output_path
327
367
328
368
return tensorboard_output_config_request
329
369
330
370
331
371
class CollectionConfig (object ):
332
372
"""CollectionConfig object for SageMaker Debugger."""
333
373
334
- def __init__ (self , name , parameters ):
374
+ def __init__ (self , name , parameters = None ):
335
375
"""Initialize a ``CollectionConfig`` object.
336
376
337
377
Args:
@@ -359,7 +399,7 @@ def __ne__(self, other):
359
399
return self .name != other .name or self .parameters != other .parameters
360
400
361
401
def __hash__ (self ):
362
- return hash ((self .name , tuple (sorted (self .parameters .items ()))))
402
+ return hash ((self .name , tuple (sorted (( self .parameters or {}) .items ()))))
363
403
364
404
def to_request_dict (self ):
365
405
"""Generates a request dictionary using the parameters provided
@@ -368,9 +408,9 @@ def to_request_dict(self):
368
408
Returns:
369
409
dict: A portion of an API request as a dictionary.
370
410
"""
371
- collection_config_request = {
372
- "CollectionName" : self . name ,
373
- "CollectionParameters" : self . parameters ,
374
- }
411
+ collection_config_request = {"CollectionName" : self . name }
412
+
413
+ if self . parameters is not None :
414
+ collection_config_request [ "CollectionParameters" ] = self . parameters
375
415
376
416
return collection_config_request
0 commit comments