Skip to content

Commit c74eeb9

Browse files
apacker authored and laurenyu committed
Remove MetricDefinition lookup via tuning job in TrainingJobAnalytics (#485)
1 parent 5b18ce4 commit c74eeb9

File tree

3 files changed: 20 additions and 35 deletions.

CHANGELOG.rst

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,8 @@ CHANGELOG
2121
* enhancement: Frameworks: update warning for not setting framework_version as we aren't planning a breaking change anymore
2222
* enhancement: Session: remove hardcoded 'training' from job status error message
2323
* bug-fix: Updated Cloudwatch namespace for metrics in TrainingJobsAnalytics
24-
24+
* bug-fix: Changes to use correct s3 bucket and time range for dataframes in TrainingJobAnalytics.
25+
* enhancement: Remove MetricDefinition lookup via tuning job in TrainingJobAnalytics
2526

2627
1.14.1
2728
======

src/sagemaker/analytics.py

Lines changed: 4 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from six import with_metaclass
2121

2222
from sagemaker.session import Session
23-
from sagemaker.utils import DeferredError, extract_name_from_job_arn
23+
from sagemaker.utils import DeferredError
2424

2525
try:
2626
import pandas as pd
@@ -310,18 +310,9 @@ def _add_single_metric(self, timestamp, metric_name, value):
310310
def _metric_names_for_training_job(self):
311311
"""Helper method to discover the metrics defined for a training job.
312312
"""
313-
# First look up the tuning job
314313
training_description = self._sage_client.describe_training_job(TrainingJobName=self._training_job_name)
315-
tuning_job_arn = training_description.get('TuningJobArn', None)
316-
if not tuning_job_arn:
317-
raise ValueError(
318-
"No metrics available. Training Job Analytics only available through Hyperparameter Tuning Jobs"
319-
)
320-
tuning_job_name = extract_name_from_job_arn(tuning_job_arn)
321-
tuning_job_description = self._sage_client.describe_hyper_parameter_tuning_job(
322-
HyperParameterTuningJobName=tuning_job_name
323-
)
324-
training_job_definition = tuning_job_description['TrainingJobDefinition']
325-
metric_definitions = training_job_definition['AlgorithmSpecification']['MetricDefinitions']
314+
315+
metric_definitions = training_description['AlgorithmSpecification']['MetricDefinitions']
326316
metric_names = [md['Name'] for md in metric_definitions]
317+
327318
return metric_names

tests/unit/test_estimator.py

Lines changed: 14 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -824,28 +824,21 @@ def test_generic_training_job_analytics(sagemaker_session):
824824
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value={
825825
'TuningJobArn': 'arn:aws:sagemaker:us-west-2:968277160000:hyper-parameter-tuning-job/mock-tuner',
826826
'TrainingStartTime': 1530562991.299,
827-
})
828-
sagemaker_session.sagemaker_client.describe_hyper_parameter_tuning_job = Mock(
829-
name='describe_hyper_parameter_tuning_job',
830-
return_value={
831-
'TrainingJobDefinition': {
832-
"AlgorithmSpecification": {
833-
"TrainingImage": "some-image-url",
834-
"TrainingInputMode": "File",
835-
"MetricDefinitions": [
836-
{
837-
"Name": "train:loss",
838-
"Regex": "train_loss=([0-9]+\\.[0-9]+)"
839-
},
840-
{
841-
"Name": "validation:loss",
842-
"Regex": "valid_loss=([0-9]+\\.[0-9]+)"
843-
}
844-
]
827+
"AlgorithmSpecification": {
828+
"TrainingImage": "some-image-url",
829+
"TrainingInputMode": "File",
830+
"MetricDefinitions": [
831+
{
832+
"Name": "train:loss",
833+
"Regex": "train_loss=([0-9]+\\.[0-9]+)"
834+
},
835+
{
836+
"Name": "validation:loss",
837+
"Regex": "valid_loss=([0-9]+\\.[0-9]+)"
845838
}
846-
}
847-
}
848-
)
839+
]
840+
},
841+
})
849842

850843
e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path=OUTPUT_PATH,
851844
sagemaker_session=sagemaker_session)

0 commit comments

Comments
 (0)