Skip to content

Support MetricDefinitions for general training jobs #484

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Nov 16, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ CHANGELOG
* feature: HyperparameterTuner: add support for Automatic Model Tuning's Warm Start Jobs
* feature: HyperparameterTuner: Make input channels optional
* feature: Add support for Chainer 5.0
* feature: Estimator: add support for MetricDefinitions

1.14.2
======
Expand Down
19 changes: 19 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,25 @@ Here is an end to end example of how to use a SageMaker Estimator:
# Tears down the SageMaker endpoint
mxnet_estimator.delete_endpoint()

Training Metrics
~~~~~~~~~~~~~~~~
The SageMaker Python SDK allows you to specify a name and a regular expression for metrics you want to track for training.
A regular expression (regex) matches what is in the training algorithm logs, like a search function.
Here is an example of how to define metrics:

.. code:: python

# Configure a BYO (bring-your-own) Estimator with metric definitions (no training happens yet)
byo_estimator = Estimator(image_name=image_name,
role='SageMakerRole', train_instance_count=1,
train_instance_type='ml.c4.xlarge',
sagemaker_session=sagemaker_session,
metric_definitions=[{'Name': 'test:msd', 'Regex': r'#quality_metric: host=\S+, test msd <loss>=(\S+)'},
{'Name': 'test:ssd', 'Regex': r'#quality_metric: host=\S+, test ssd <loss>=(\S+)'}])

All Amazon SageMaker algorithms come with built-in support for metrics.
You can go to `the AWS documentation <https://docs.aws.amazon.com/sagemaker/latest/dg/algos.html>`__ for more details about built-in metrics of each Amazon SageMaker algorithm.

Local Mode
~~~~~~~~~~

Expand Down
21 changes: 17 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ class EstimatorBase(with_metaclass(ABCMeta, object)):
def __init__(self, role, train_instance_count, train_instance_type,
train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60, input_mode='File',
output_path=None, output_kms_key=None, base_job_name=None, sagemaker_session=None, tags=None,
subnets=None, security_group_ids=None, model_uri=None, model_channel_name='model'):
subnets=None, security_group_ids=None, model_uri=None, model_channel_name='model',
metric_definitions=None):
"""Initialize an ``EstimatorBase`` instance.

Args:
Expand Down Expand Up @@ -97,6 +98,10 @@ def __init__(self, role, train_instance_count, train_instance_type,

More information: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
model_channel_name (str): Name of the channel where 'model_uri' will be downloaded (default: 'model').
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
the regular expression used to extract the metric from the logs. This should be defined only
for jobs that don't use an Amazon algorithm.
"""
self.role = role
self.train_instance_count = train_instance_count
Expand All @@ -106,6 +111,7 @@ def __init__(self, role, train_instance_count, train_instance_type,
self.train_max_run = train_max_run
self.input_mode = input_mode
self.tags = tags
self.metric_definitions = metric_definitions
self.model_uri = model_uri
self.model_channel_name = model_channel_name

Expand Down Expand Up @@ -324,6 +330,9 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
init_params['hyperparameters'] = job_details['HyperParameters']
init_params['image'] = job_details['AlgorithmSpecification']['TrainingImage']

if 'MetricDefinitions' in job_details['AlgorithmSpecification']:
init_params['metric_definitions'] = job_details['AlgorithmSpecification']['MetricDefinitions']

subnets, security_group_ids = vpc_utils.from_dict(job_details.get(vpc_utils.VPC_CONFIG_KEY))
if subnets:
init_params['subnets'] = subnets
Expand Down Expand Up @@ -441,7 +450,7 @@ def start_new(cls, estimator, inputs):
job_name=estimator._current_job_name, output_config=config['output_config'],
resource_config=config['resource_config'], vpc_config=config['vpc_config'],
hyperparameters=hyperparameters, stop_condition=config['stop_condition'],
tags=estimator.tags)
tags=estimator.tags, metric_definitions=estimator.metric_definitions)

return cls(estimator.sagemaker_session, estimator._current_job_name)

Expand All @@ -466,7 +475,7 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,
train_volume_size=30, train_volume_kms_key=None, train_max_run=24 * 60 * 60,
input_mode='File', output_path=None, output_kms_key=None, base_job_name=None,
sagemaker_session=None, hyperparameters=None, tags=None, subnets=None, security_group_ids=None,
model_uri=None, model_channel_name='model'):
model_uri=None, model_channel_name='model', metric_definitions=None):
"""Initialize an ``Estimator`` instance.

Args:
Expand Down Expand Up @@ -517,14 +526,18 @@ def __init__(self, image_name, role, train_instance_count, train_instance_type,

More information: https://docs.aws.amazon.com/sagemaker/latest/dg/cdf-training.html#td-deserialization
model_channel_name (str): Name of the channel where 'model_uri' will be downloaded (default: 'model').
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
the regular expression used to extract the metric from the logs. This should be defined only
for jobs that don't use an Amazon algorithm.
"""
self.image_name = image_name
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
super(Estimator, self).__init__(role, train_instance_count, train_instance_type,
train_volume_size, train_volume_kms_key, train_max_run, input_mode,
output_path, output_kms_key, base_job_name, sagemaker_session,
tags, subnets, security_group_ids, model_uri=model_uri,
model_channel_name=model_channel_name)
model_channel_name=model_channel_name, metric_definitions=metric_definitions)

def train_image(self):
"""
Expand Down
10 changes: 8 additions & 2 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def default_bucket(self):
return self._default_bucket

def train(self, image, input_mode, input_config, role, job_name, output_config,
resource_config, vpc_config, hyperparameters, stop_condition, tags):
resource_config, vpc_config, hyperparameters, stop_condition, tags, metric_definitions):
"""Create an Amazon SageMaker training job.

Args:
Expand Down Expand Up @@ -243,6 +243,9 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
service like ``MaxRuntimeInSeconds``.
tags (list[dict]): List of tags for labeling a training job. For more, see
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
the regular expression used to extract the metric from the logs.

Returns:
str: ARN of the training job, if it is created.
Expand All @@ -263,6 +266,9 @@ def train(self, image, input_mode, input_config, role, job_name, output_config,
if input_config is not None:
train_request['InputDataConfig'] = input_config

if metric_definitions is not None:
train_request['AlgorithmSpecification']['MetricDefinitions'] = metric_definitions

if hyperparameters and len(hyperparameters) > 0:
train_request['HyperParameters'] = hyperparameters

Expand Down Expand Up @@ -306,7 +312,7 @@ def tune(self, job_name, strategy, objective_type, objective_metric_name,
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s) used to evaluate the
training jobs. Each dictionary contains two keys: 'Name' for the name of the metric, and 'Regex' for
the regular expression used to extract the metric from the logs. This should be defined only for
hyperparameter tuning jobs that don't use an Amazon algorithm.
jobs that don't use an Amazon algorithm.
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training jobs and APIs
that create Amazon SageMaker endpoints use this role to access training data and model artifacts.
You must grant sufficient permissions to this role.
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def _create_train_job(version):
'MaxRuntimeInSeconds': 24 * 60 * 60
},
'tags': None,
'vpc_config': None
'vpc_config': None,
'metric_definitions': None
}


Expand Down
9 changes: 6 additions & 3 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def test_framework_all_init_args(sagemaker_session):
sagemaker_session=sagemaker_session, train_volume_size=123, train_volume_kms_key='volumekms',
train_max_run=456, input_mode='inputmode', output_path='outputpath', output_kms_key='outputkms',
base_job_name='basejobname', tags=[{'foo': 'bar'}], subnets=['123', '456'],
security_group_ids=['789', '012'])
security_group_ids=['789', '012'],
metric_definitions=[{'Name': 'validation-rmse', 'Regex': 'validation-rmse=(\\d+)'}])
_TrainingJob.start_new(f, 's3://mydata')
sagemaker_session.train.assert_called_once()
_, args = sagemaker_session.train.call_args
Expand All @@ -158,7 +159,8 @@ def test_framework_all_init_args(sagemaker_session):
'stop_condition': {'MaxRuntimeInSeconds': 456},
'role': sagemaker_session.expand_role(), 'job_name': None,
'resource_config': {'VolumeSizeInGB': 123, 'InstanceCount': 3, 'VolumeKmsKeyId': 'volumekms',
'InstanceType': 'ml.m4.xlarge'}}
'InstanceType': 'ml.m4.xlarge'},
'metric_definitions': [{'Name': 'validation-rmse', 'Regex': 'validation-rmse=(\\d+)'}]}


def test_sagemaker_s3_uri_invalid(sagemaker_session):
Expand Down Expand Up @@ -711,7 +713,8 @@ def test_unsupported_type_in_dict():
},
'stop_condition': {'MaxRuntimeInSeconds': 86400},
'tags': None,
'vpc_config': None
'vpc_config': None,
'metric_definitions': None
}

INPUT_CONFIG = [{
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ def _create_train_job(version):
'MaxRuntimeInSeconds': 24 * 60 * 60
},
'tags': None,
'vpc_config': None
'vpc_config': None,
'metric_definitions': None
}


Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ def _create_train_job(version):
'MaxRuntimeInSeconds': 24 * 60 * 60
},
'tags': None,
'vpc_config': None
'vpc_config': None,
'metric_definitions': None
}


Expand Down
8 changes: 6 additions & 2 deletions tests/unit/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ def test_s3_input_all_arguments():
JOB_NAME = 'jobname'
TAGS = [{'Name': 'some-tag', 'Value': 'value-for-tag'}]
VPC_CONFIG = {'Subnets': ['foo'], 'SecurityGroupIds': ['bar']}
METRIC_DEFINITONS = [{'Name': 'validation-rmse', 'Regex': 'validation-rmse=(\\d+)'}]

DEFAULT_EXPECTED_TRAIN_JOB_ARGS = {
'OutputDataConfig': {
Expand Down Expand Up @@ -268,7 +269,8 @@ def test_train_pack_to_request(sagemaker_session):

sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
hyperparameters=None, stop_condition=stop_cond, tags=None, vpc_config=VPC_CONFIG)
hyperparameters=None, stop_condition=stop_cond, tags=None, vpc_config=VPC_CONFIG,
metric_definitions=None)

assert sagemaker_session.sagemaker_client.method_calls[0] == (
'create_training_job', (), DEFAULT_EXPECTED_TRAIN_JOB_ARGS)
Expand Down Expand Up @@ -439,13 +441,15 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):

sagemaker_session.train(image=IMAGE, input_mode='File', input_config=in_config, role=EXPANDED_ROLE,
job_name=JOB_NAME, output_config=out_config, resource_config=resource_config,
vpc_config=VPC_CONFIG, hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS)
vpc_config=VPC_CONFIG, hyperparameters=hyperparameters, stop_condition=stop_cond, tags=TAGS,
metric_definitions=METRIC_DEFINITONS)

_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]

assert actual_train_args['VpcConfig'] == VPC_CONFIG
assert actual_train_args['HyperParameters'] == hyperparameters
assert actual_train_args['Tags'] == TAGS
assert actual_train_args['AlgorithmSpecification']['MetricDefinitions'] == METRIC_DEFINITONS


def test_transform_pack_to_request(sagemaker_session):
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_tf_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,8 @@ def _create_train_job(tf_version, script_mode=False, repo_name=IMAGE_REPO_NAME,
'MaxRuntimeInSeconds': 24 * 60 * 60
},
'tags': None,
'vpc_config': None
'vpc_config': None,
'metric_definitions': None
}


Expand Down
3 changes: 2 additions & 1 deletion tests/unit/test_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,8 @@ def test_deploy_default(tuner):
returned_training_job_description = {
'AlgorithmSpecification': {
'TrainingInputMode': 'File',
'TrainingImage': IMAGE_NAME
'TrainingImage': IMAGE_NAME,
'MetricDefinitions': METRIC_DEFINTIONS,
},
'HyperParameters': {
'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
Expand Down