Skip to content

Commit 7fea237

Browse files
authored
Update hyperparameter tuning/analytics docstrings (#215)
1 parent 5c00695 commit 7fea237

File tree

7 files changed

+272
-75
lines changed

7 files changed

+272
-75
lines changed

doc/analytics.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
Analytics
2+
---------
3+
4+
.. autoclass:: sagemaker.analytics.AnalyticsMetricsBase
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:
8+
9+
.. autoclass:: sagemaker.analytics.HyperparameterTuningJobAnalytics
10+
:members:
11+
:undoc-members:
12+
:show-inheritance:
13+
14+
.. autoclass:: sagemaker.analytics.TrainingJobAnalytics
15+
:members:
16+
:undoc-members:
17+
:show-inheritance:

doc/index.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ Amazon SageMaker Python SDK is an open source library for training and deploying
44

55
With the SDK, you can train and deploy models using popular deep learning frameworks: **Apache MXNet** and **TensorFlow**. You can also train and deploy models with **algorithms provided by Amazon**, these are scalable implementations of core machine learning algorithms that are optimized for SageMaker and GPU training. If you have **your own algorithms** built into SageMaker-compatible Docker containers, you can train and host models using these as well.
66

7-
Here you'll find API docs for SageMaker Python SDK. The project home-page is in Github: https://github.com/aws/sagemaker-python-sdk, there you can find the SDK source, installation instructions and a general overview of the library there.
7+
Here you'll find API docs for SageMaker Python SDK. The project home-page is in Github: https://github.com/aws/sagemaker-python-sdk, there you can find the SDK source, installation instructions and a general overview of the library there.
88

99
Overview
1010
----------
@@ -14,9 +14,11 @@ The SageMaker Python SDK consists of a few primary interfaces:
1414
:maxdepth: 2
1515

1616
estimators
17+
tuner
1718
predictors
1819
session
1920
model
21+
analytics
2022

2123
MXNet
2224
----------

doc/tuner.rst

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
HyperparameterTuner
2+
-------------------
3+
4+
.. autoclass:: sagemaker.tuner.HyperparameterTuner
5+
:members:
6+
:undoc-members:
7+
:show-inheritance:
8+
9+
.. autoclass:: sagemaker.tuner.ContinuousParameter
10+
:members:
11+
:undoc-members:
12+
:show-inheritance:
13+
14+
.. autoclass:: sagemaker.tuner.IntegerParameter
15+
:members:
16+
:undoc-members:
17+
:show-inheritance:
18+
19+
.. autoclass:: sagemaker.tuner.CategoricalParameter
20+
:members:
21+
:undoc-members:
22+
:show-inheritance:

src/sagemaker/analytics.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -64,25 +64,24 @@ def _fetch_dataframe(self):
6464
pass
6565

6666
def clear_cache(self):
67-
"""Clears the object of all local caches of API methods, so
67+
"""Clear the object of all local caches of API methods, so
6868
that the next time any properties are accessed they will be refreshed from
6969
the service.
7070
"""
7171
self._dataframe = None
7272

7373

7474
class HyperparameterTuningJobAnalytics(AnalyticsMetricsBase):
75-
"""Fetches results about this tuning job and makes them accessible for analytics.
75+
"""Fetch results about a hyperparameter tuning job and make them accessible for analytics.
7676
"""
7777

7878
def __init__(self, hyperparameter_tuning_job_name, sagemaker_session=None):
79-
"""Initialize an ``HyperparameterTuningJobAnalytics`` instance.
79+
"""Initialize a ``HyperparameterTuningJobAnalytics`` instance.
8080
8181
Args:
82-
hyperparameter_tuning_job_name (str): name of the HyperparameterTuningJob to
83-
analyze.
82+
hyperparameter_tuning_job_name (str): name of the HyperparameterTuningJob to analyze.
8483
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
85-
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
84+
Amazon SageMaker APIs and any other AWS services needed. If not specified, one is created
8685
using the default AWS configuration chain.
8786
"""
8887
sagemaker_session = sagemaker_session or Session()
@@ -100,16 +99,16 @@ def __repr__(self):
10099
return "<sagemaker.HyperparameterTuningJobAnalytics for %s>" % self.name
101100

102101
def clear_cache(self):
103-
"""Clears the object of all local caches of API methods.
102+
"""Clear the object of all local caches of API methods.
104103
"""
105104
super(HyperparameterTuningJobAnalytics, self).clear_cache()
106105
self._tuning_job_describe_result = None
107106
self._training_job_summaries = None
108107

109108
def _fetch_dataframe(self):
110-
"""Returns a pandas dataframe with all the training jobs, their
111-
hyperparameters, results, and metadata about the training jobs.
112-
Includes a column to indicate that any job was the best seen so far.
109+
"""Return a pandas dataframe with all the training jobs, along with their
110+
hyperparameters, results, and metadata. This also includes a column to indicate
111+
if a training job was the best seen so far.
113112
"""
114113
def reshape(training_summary):
115114
# Helper method to reshape a single training job summary into a dataframe record
@@ -139,8 +138,8 @@ def reshape(training_summary):
139138

140139
@property
141140
def tuning_ranges(self):
142-
"""A dict describing the ranges of all tuned hyperparameters.
143-
Dict's key is the name of the hyper param. Dict's value is the range.
141+
"""A dictionary describing the ranges of all tuned hyperparameters.
142+
The keys are the names of the hyperparameter, and the values are the ranges.
144143
"""
145144
out = {}
146145
for _, ranges in self.description()['HyperParameterTuningJobConfig']['ParameterRanges'].items():
@@ -149,10 +148,13 @@ def tuning_ranges(self):
149148
return out
150149

151150
def description(self, force_refresh=False):
152-
"""Response to DescribeHyperParameterTuningJob
151+
"""Call ``DescribeHyperParameterTuningJob`` for the hyperparameter tuning job.
153152
154153
Args:
155154
force_refresh (bool): Set to True to fetch the latest data from SageMaker API.
155+
156+
Returns:
157+
dict: The Amazon SageMaker response for ``DescribeHyperParameterTuningJob``.
156158
"""
157159
if force_refresh:
158160
self.clear_cache()
@@ -163,10 +165,13 @@ def description(self, force_refresh=False):
163165
return self._tuning_job_describe_result
164166

165167
def training_job_summaries(self, force_refresh=False):
166-
"""A list of everything (paginated) from ListTrainingJobsForTuningJob
168+
"""A (paginated) list of everything from ``ListTrainingJobsForTuningJob``.
167169
168170
Args:
169171
force_refresh (bool): Set to True to fetch the latest data from SageMaker API.
172+
173+
Returns:
174+
dict: The Amazon SageMaker response for ``ListTrainingJobsForTuningJob``.
170175
"""
171176
if force_refresh:
172177
self.clear_cache()
@@ -191,19 +196,19 @@ def training_job_summaries(self, force_refresh=False):
191196

192197

193198
class TrainingJobAnalytics(AnalyticsMetricsBase):
194-
"""Fetches training curve data from CloudWatch Metrics for a specific training job.
199+
"""Fetch training curve data from CloudWatch Metrics for a specific training job.
195200
"""
196201

197202
CLOUDWATCH_NAMESPACE = '/aws/sagemaker/HyperParameterTuningJobs'
198203

199204
def __init__(self, training_job_name, metric_names, sagemaker_session=None):
200-
"""Initialize an ``TrainingJobAnalytics`` instance.
205+
"""Initialize a ``TrainingJobAnalytics`` instance.
201206
202207
Args:
203208
training_job_name (str): name of the TrainingJob to analyze.
204209
metric_names (list): string names of all the metrics to collect for this training job
205210
sagemaker_session (sagemaker.session.Session): Session object which manages interactions with
206-
Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one
211+
Amazon SageMaker APIs and any other AWS services needed. If not specified, one is specified
207212
using the default AWS configuration chain.
208213
"""
209214
sagemaker_session = sagemaker_session or Session()
@@ -223,7 +228,7 @@ def __repr__(self):
223228
return "<sagemaker.TrainingJobAnalytics for %s>" % self.name
224229

225230
def clear_cache(self):
226-
"""Clears the object of all local caches of API methods, so
231+
"""Clear the object of all local caches of API methods, so
227232
that the next time any properties are accessed they will be refreshed from
228233
the service.
229234
"""
@@ -232,7 +237,7 @@ def clear_cache(self):
232237
self._time_interval = self._determine_timeinterval()
233238

234239
def _determine_timeinterval(self):
235-
"""Returns a dict with two datetime objects, start_time and end_time
240+
"""Return a dictionary with two datetime objects, start_time and end_time,
236241
covering the interval of the training job
237242
"""
238243
description = self._sage_client.describe_training_job(TrainingJobName=self.name)
@@ -249,7 +254,7 @@ def _fetch_dataframe(self):
249254
return pd.DataFrame(self._data)
250255

251256
def _fetch_metric(self, metric_name):
252-
"""Fetches all the values of a named metric, and adds them to _data
257+
"""Fetch all the values of a named metric, and add them to _data
253258
"""
254259
request = {
255260
'Namespace': self.CLOUDWATCH_NAMESPACE,
@@ -284,7 +289,7 @@ def _fetch_metric(self, metric_name):
284289
self._add_single_metric(elapsed_seconds, metric_name, value)
285290

286291
def _add_single_metric(self, timestamp, metric_name, value):
287-
"""Stores a single metric in the _data dict which can be
292+
"""Store a single metric in the _data dict which can be
288293
converted to a dataframe.
289294
"""
290295
# note that this method is built this way to make it possible to

src/sagemaker/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -319,7 +319,7 @@ def delete_endpoint(self):
319319

320320
@property
321321
def training_job_analytics(self):
322-
"""Returns a TrainingJobAnalytics object for the current training job.
322+
"""Return a ``TrainingJobAnalytics`` object for the current training job.
323323
"""
324324
if self._current_job_name is None:
325325
raise ValueError('Estimator is not associated with a TrainingJob')

src/sagemaker/session.py

Lines changed: 33 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,12 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
222222
job_name (str): Name of the training job being created.
223223
output_config (dict): The S3 URI where you want to store the training results and optional KMS key ID.
224224
resource_config (dict): Contains values for ResourceConfig:
225+
225226
* instance_count (int): Number of EC2 instances to use for training.
226227
The key in resource_config is 'InstanceCount'.
227228
* instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
228229
The key in resource_config is 'InstanceType'.
230+
229231
hyperparameters (dict): Hyperparameters for model training. The hyperparameters are made accessible as
230232
a dict[str, str] to the training code on SageMaker. For convenience, this accepts other types for
231233
keys and values, but ``str()`` will be called to convert them before training.
@@ -269,22 +271,28 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
269271
270272
Args:
271273
job_name (str): Name of the tuning job being created.
272-
strategy (str): Strategy to be used.
273-
objective_type (str): Minimize/Maximize
274-
objective_metric_name (str): Name of the metric to use when evaluating training job.
275-
max_jobs (int): Maximum total number of jobs to start.
276-
max_parallel_jobs (int): Maximum number of parallel jobs to start.
277-
parameter_ranges (dict): Parameter ranges in a dictionary of types: Continuous, Integer, Categorical
278-
static_hyperparameters (dict): Hyperparameters for model training. The hyperparameters are made accessible
279-
as a dict[str, str] to the training code on SageMaker. For convenience, this accepts other types for
280-
keys and values, but ``str()`` will be called to convert them before training.
274+
strategy (str): Strategy to be used for hyperparameter estimations.
275+
objective_type (str): The type of the objective metric for evaluating training jobs. This value can be
276+
either 'Minimize' or 'Maximize'.
277+
objective_metric_name (str): Name of the metric for evaluating training jobs.
278+
max_jobs (int): Maximum total number of training jobs to start for the hyperparameter tuning job.
279+
max_parallel_jobs (int): Maximum number of parallel training jobs to start.
280+
parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can be one of three types:
281+
Continuous, Integer, or Categorical.
282+
static_hyperparameters (dict): Hyperparameters for model training. These hyperparameters remain
283+
unchanged across all of the training jobs for the hyperparameter tuning job. The hyperparameters are
284+
made accessible as a dictionary for the training code on SageMaker.
281285
image (str): Docker image containing training code.
282286
input_mode (str): The input mode that the algorithm supports. Valid modes:
283287
284288
* 'File' - Amazon SageMaker copies the training dataset from the S3 location to
285289
a directory in the Docker container.
286290
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe.
287-
metric_definitions (list[dict]): Metrics definition with 'name' and 'regex' keys.
291+
292+
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
293+
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
294+
the regular expression used to extract the metric from the logs. This should be defined only for
295+
hyperparameter tuning jobs that don't use an Amazon algorithm.
288296
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
289297
that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
290298
You must grant sufficient permissions to this role.
@@ -293,11 +301,15 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
293301
https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
294302
output_config (dict): The S3 URI where you want to store the training results and optional KMS key ID.
295303
resource_config (dict): Contains values for ResourceConfig:
296-
instance_count (int): Number of EC2 instances to use for training.
297-
instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
298-
stop_condition (dict): Defines when training shall finish. Contains entries that can be understood by the
299-
service like ``MaxRuntimeInSeconds``.
300-
tags (list[dict]): List of tags for labeling the tuning job.
304+
305+
* instance_count (int): Number of EC2 instances to use for training.
306+
The key in resource_config is 'InstanceCount'.
307+
* instance_type (str): Type of EC2 instance to use for training, for example, 'ml.c4.xlarge'.
308+
The key in resource_config is 'InstanceType'.
309+
310+
stop_condition (dict): When training should finish, e.g. ``MaxRuntimeInSeconds``.
311+
tags (list[dict]): List of tags for labeling the tuning job. For more, see
312+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
301313
"""
302314
tune_request = {
303315
'HyperParameterTuningJobName': job_name,
@@ -338,10 +350,13 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
338350
self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request)
339351

340352
def stop_tuning_job(self, name):
341-
"""Attempts to stop tuning job on Amazon SageMaker with specified name.
353+
"""Stop the Amazon SageMaker hyperparameter tuning job with the specified name.
342354
343355
Args:
344-
name: Name of Amazon SageMaker tuning job.
356+
name (str): Name of the Amazon SageMaker hyperparameter tuning job.
357+
358+
Raises:
359+
ClientError: If an error occurs while trying to stop the hyperparameter tuning job.
345360
"""
346361
try:
347362
LOGGER.info('Stopping tuning job: {}'.format(name))
@@ -491,7 +506,7 @@ def wait_for_job(self, job, poll=5):
491506
return desc
492507

493508
def wait_for_tuning_job(self, job, poll=5):
494-
"""Wait for an Amazon SageMaker tuning job to complete.
509+
"""Wait for an Amazon SageMaker hyperparameter tuning job to complete.
495510
496511
Args:
497512
job (str): Name of the tuning job to wait for.

0 commit comments

Comments
 (0)