Skip to content

Commit a8b462f

Browse files
ajaykarpurknakad
authored and committed
Add support for SageMaker Debugger [WIP] (#247)
1 parent 2a86ac6 commit a8b462f

25 files changed

+7348
-5710
lines changed

buildspec.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ version: 0.2
33
phases:
44
build:
55
commands:
6+
# TODO-reinvent-2019 [akarpur]: Remove this (adding internal boto models)
7+
- aws configure add-model --service-model file://./tests/data/boto_models/sagemaker/2017-07-24/normal.json --service-name sagemaker
8+
69
- IGNORE_COVERAGE=-
710

811
# run integration tests

src/sagemaker/debugger.py

Lines changed: 376 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,376 @@
1+
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Amazon SageMaker Debugger is a service that provides full visibility
14+
into the training of machine learning (ML) models, enabling customers
15+
to automatically detect several classes of errors. Customers can configure
16+
Debugger when starting their training jobs by specifying debug level, models,
17+
and location where debug output will be stored. Optionally, customers can
18+
also specify custom error conditions that they want to be alerted on.
19+
Debugger automatically collects model specific data, monitors for errors,
20+
and alerts when it detects errors during training.
21+
"""
22+
from __future__ import absolute_import
23+
import smdebug_rulesconfig as rule_configs # noqa: F401 # pylint: disable=unused-import
24+
25+
26+
class Rule(object):
    """Rules analyze tensors emitted during the training of a model. They
    monitor conditions that are critical for the success of a training job.

    For example, they can detect whether gradients are getting too large or
    too small or if a model is being overfit. Debugger comes pre-packaged with
    certain built-in rules (created using the Rule.sagemaker classmethod).
    You can use these rules or write your own rules using the Amazon SageMaker
    Debugger APIs. You can also analyze raw tensor data without using rules in,
    for example, an Amazon SageMaker notebook, using Debugger's full set of APIs.
    """

    def __init__(
        self,
        name,
        image_uri,
        instance_type,
        container_local_path,
        s3_output_path,
        volume_size_in_gb,
        rule_parameters,
        collections_to_save,
    ):
        """Do not use this initialization method. Instead, use either the
        ``Rule.sagemaker`` or ``Rule.custom`` class method.

        Initialize a ``Rule`` instance. The Rule analyzes tensors emitted
        during the training of a model and monitors conditions that are critical
        for the success of a training job.

        Args:
            name (str): The name of the debugger rule.
            image_uri (str): The URI of the image to be used by the debugger rule.
            instance_type (str): Type of EC2 instance to use, for example,
                'ml.c4.xlarge'.
            container_local_path (str): The path in the container.
            s3_output_path (str): The location in S3 to store the output.
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data.
            rule_parameters (dict): A dictionary of parameters for the rule.
            collections_to_save ([sagemaker.debugger.CollectionConfig]): A list
                of CollectionConfig objects to be saved. Stored on the
                instance as ``collection_configs``.
        """
        self.name = name
        self.instance_type = instance_type
        self.container_local_path = container_local_path
        self.s3_output_path = s3_output_path
        self.volume_size_in_gb = volume_size_in_gb
        self.rule_parameters = rule_parameters
        self.collection_configs = collections_to_save
        self.image_uri = image_uri

    @classmethod
    def sagemaker(
        cls,
        base_config,
        name=None,
        instance_type=None,
        container_local_path=None,
        s3_output_path=None,
        volume_size_in_gb=None,
        other_trials_s3_input_paths=None,
        rule_parameters=None,
        collections_to_save=None,
    ):
        """Initialize a ``Rule`` instance for a built-in SageMaker Debugging
        Rule. The Rule analyzes tensors emitted during the training of a model
        and monitors conditions that are critical for the success of a training
        job.

        Args:
            base_config (dict): This is the base rule config returned from the
                built-in list of rules. For example, 'rule_configs.dead_relu()'.
                It must contain a "DebugRuleConfiguration" entry.
            name (str): The name of the debugger rule. If one is not provided,
                the name of the base_config will be used.
            instance_type (str): Type of EC2 instance to use, for example,
                'ml.c4.xlarge'. If one is not provided, 't3.medium' is
                currently used (see the TODO in the body).
            container_local_path (str): The path in the container.
            s3_output_path (str): The location in S3 to store the output.
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data.
            other_trials_s3_input_paths ([str]): S3 input paths for other trials.
            rule_parameters (dict): A dictionary of parameters for the rule.
                If provided (and non-empty), it is used instead of the
                "RuleParameters" of the base_config.
            collections_to_save ([sagemaker.debugger.CollectionConfig]): A list
                of CollectionConfig objects to be saved. If not provided, the
                collections described by the base_config are used.

        Returns:
            sagemaker.debugger.Rule: The instance of the built-in Rule.

        Raises:
            KeyError: If base_config has no "DebugRuleConfiguration" entry.
        """
        rule_config = base_config["DebugRuleConfiguration"]

        # Build the merged parameters in a fresh dict. ``dict.update`` returns
        # None, so its result must never be used as the merged value, and a
        # fresh dict keeps the caller's dictionaries unmodified.
        merged_rule_params = {}
        if other_trials_s3_input_paths is not None:
            merged_rule_params["other_trials_s3_input_paths"] = other_trials_s3_input_paths
        merged_rule_params.update(rule_parameters or rule_config.get("RuleParameters") or {})

        base_config_collections = [
            CollectionConfig(
                name=config.get("CollectionName"),
                parameters=config.get("CollectionParameters", {}),
            )
            for config in base_config.get("CollectionConfigurations", [])
        ]

        return cls(
            name=name or rule_config.get("RuleConfigurationName"),
            image_uri="DEFAULT_RULE_EVALUATOR_IMAGE",
            instance_type=instance_type or "t3.medium",
            # TODO-reinvent-2019 [akarpur]: Remove t3.medium from line above,
            # uncomment line below when 1P package updated
            # or base_config["DebugRuleConfiguration"].get("InstanceType"),
            container_local_path=container_local_path,
            s3_output_path=s3_output_path,
            volume_size_in_gb=volume_size_in_gb,
            rule_parameters=merged_rule_params,
            collections_to_save=collections_to_save or base_config_collections,
        )

    @classmethod
    def custom(
        cls,
        name,
        image_uri,
        instance_type,
        source=None,
        rule_to_invoke=None,
        container_local_path=None,
        s3_output_path=None,
        volume_size_in_gb=None,
        other_trials_s3_input_paths=None,
        rule_parameters=None,
        collections_to_save=None,
    ):
        """Initialize a ``Rule`` instance for a custom rule. The Rule
        analyzes tensors emitted during the training of a model and
        monitors conditions that are critical for the success of a
        training job.

        Args:
            name (str): The name of the debugger rule.
            image_uri (str): The URI of the image to be used by the debugger rule.
            instance_type (str): Type of EC2 instance to use, for example,
                'ml.c4.xlarge'.
            source (str): A source file containing a rule to invoke. If provided,
                you must also provide rule_to_invoke.
            rule_to_invoke (str): The name of the rule to invoke within the source.
                If provided, you must also provide source.
            container_local_path (str): The path in the container.
            s3_output_path (str): The location in S3 to store the output.
            volume_size_in_gb (int): Size in GB of the EBS volume
                to use for storing data.
            other_trials_s3_input_paths ([str]): S3 input paths for other trials.
            rule_parameters (dict): A dictionary of parameters for the rule.
                These are applied last, so they take precedence over the
                entries derived from the other arguments.
            collections_to_save ([sagemaker.debugger.CollectionConfig]): A list
                of CollectionConfig objects to be saved.

        Returns:
            sagemaker.debugger.Rule: The instance of the custom Rule.

        Raises:
            ValueError: If only one of source and rule_to_invoke is provided.
        """
        if bool(source) ^ bool(rule_to_invoke):
            raise ValueError(
                "If you provide a source, you must also provide a rule to invoke (and vice versa)."
            )

        # Accumulate everything into one fresh dict. ``dict.update`` returns
        # None, so the merge is done in place and the dict itself is passed on.
        merged_rule_params = {}
        if source is not None and rule_to_invoke is not None:
            merged_rule_params["source_s3_uri"] = source
            merged_rule_params["rule_to_invoke"] = rule_to_invoke
        if other_trials_s3_input_paths is not None:
            merged_rule_params["other_trials_s3_input_paths"] = other_trials_s3_input_paths
        merged_rule_params.update(rule_parameters or {})

        return cls(
            name=name,
            image_uri=image_uri,
            instance_type=instance_type,
            container_local_path=container_local_path,
            s3_output_path=s3_output_path,
            volume_size_in_gb=volume_size_in_gb,
            rule_parameters=merged_rule_params,
            collections_to_save=collections_to_save or [],
        )

    def to_debugger_rule_config_dict(self):
        """Generates a request dictionary using the parameters provided
        when initializing the object.

        Returns:
            dict: A portion of an API request as a dictionary.

        Raises:
            RuntimeError: If the instance type or the volume size is not set.
        """
        if self.instance_type is None or self.volume_size_in_gb is None:
            raise RuntimeError(
                "Cannot create a dictionary if the instance type and volume size are not provided.\n"
                "Please set the instance type and volume size for this Rule object."
            )

        debugger_rule_config_request = {
            "RuleConfigurationName": self.name,
            "RuleEvaluatorImage": self.image_uri,
            "InstanceType": self.instance_type,
            "VolumeSizeInGB": self.volume_size_in_gb,
        }

        if self.container_local_path is not None:
            debugger_rule_config_request["LocalPath"] = self.container_local_path

        if self.s3_output_path is not None:
            debugger_rule_config_request["S3OutputPath"] = self.s3_output_path

        # Empty or missing rule parameters are left out of the request.
        if self.rule_parameters:
            debugger_rule_config_request["RuleParameters"] = self.rule_parameters

        return debugger_rule_config_request
247+
248+
249+
class DebuggerHookConfig(object):
    """Options that customize how debugging information is emitted."""

    def __init__(
        self,
        s3_output_path,
        container_local_path=None,
        hook_parameters=None,
        collection_configs=None,
    ):
        """Create a ``DebuggerHookConfig`` holding the given options.

        Args:
            s3_output_path (str): The location in S3 to store the output.
            container_local_path (str): The path in the container.
            hook_parameters (dict): A dictionary of parameters.
            collection_configs ([sagemaker.debugger.CollectionConfig]): A list
                of CollectionConfig objects to be provided to the API.
        """
        self.collection_configs = collection_configs
        self.hook_parameters = hook_parameters
        self.container_local_path = container_local_path
        self.s3_output_path = s3_output_path

    def to_request_dict(self):
        """Build the request dictionary from the values given at
        initialization. Optional values that are None are left out.

        Returns:
            dict: A portion of an API request as a dictionary.
        """
        request = {"S3OutputPath": self.s3_output_path}

        # Optional scalar entries, in the order they appear in the request.
        optional_entries = (
            ("LocalPath", self.container_local_path),
            ("HookParameters", self.hook_parameters),
        )
        for request_key, value in optional_entries:
            if value is not None:
                request[request_key] = value

        configs = self.collection_configs
        if configs is not None:
            # Each collection config serializes itself.
            request["CollectionConfigurations"] = [
                config.to_request_dict() for config in configs
            ]

        return request
298+
299+
300+
class TensorBoardOutputConfig(object):
    """Options that customize debugging visualization using TensorBoard."""

    def __init__(self, s3_output_path, container_local_path=None):
        """Create a ``TensorBoardOutputConfig`` holding the given options.

        Args:
            s3_output_path (str): The location in S3 to store the output.
            container_local_path (str): The path in the container.
        """
        self.container_local_path = container_local_path
        self.s3_output_path = s3_output_path

    def to_request_dict(self):
        """Build the request dictionary from the values given at
        initialization. The local path is left out when it is None.

        Returns:
            dict: A portion of an API request as a dictionary.
        """
        request = {"S3OutputPath": self.s3_output_path}
        local_path = self.container_local_path
        if local_path is None:
            return request

        request["LocalPath"] = local_path
        return request
329+
330+
331+
class CollectionConfig(object):
    """CollectionConfig object for SageMaker Debugger."""

    def __init__(self, name, parameters=None):
        """Initialize a ``CollectionConfig`` object.

        Args:
            name (str): The name of the collection configuration.
            parameters (dict): The parameters for the collection
                configuration. Optional; defaults to None.
        """
        self.name = name
        self.parameters = parameters

    def __eq__(self, other):
        """Compare name and parameters with another ``CollectionConfig``.

        Raises:
            TypeError: If ``other`` is not a ``CollectionConfig``.
        """
        if not isinstance(other, CollectionConfig):
            raise TypeError(
                "CollectionConfig is only comparable with other CollectionConfig objects."
            )

        return self.name == other.name and self.parameters == other.parameters

    def __ne__(self, other):
        """Inverse of ``__eq__``; defined explicitly because this file also
        targets Python 2 (``from __future__ import absolute_import``), where
        ``__ne__`` is not derived from ``__eq__``.

        Raises:
            TypeError: If ``other`` is not a ``CollectionConfig``.
        """
        if not isinstance(other, CollectionConfig):
            raise TypeError(
                "CollectionConfig is only comparable with other CollectionConfig objects."
            )

        return self.name != other.name or self.parameters != other.parameters

    def __hash__(self):
        """Hash on the name and the sorted parameter items."""
        # ``parameters`` may be None; treat it like an empty dict so hashing
        # does not raise AttributeError.
        return hash((self.name, tuple(sorted((self.parameters or {}).items()))))

    def to_request_dict(self):
        """Generates a request dictionary using the parameters provided
        when initializing the object.

        Returns:
            dict: A portion of an API request as a dictionary. The
                "CollectionParameters" entry is only present when
                parameters were provided.
        """
        collection_config_request = {"CollectionName": self.name}

        # Do not emit an explicit None for an unset parameter map.
        if self.parameters is not None:
            collection_config_request["CollectionParameters"] = self.parameters

        return collection_config_request

0 commit comments

Comments
 (0)