Skip to content

Commit 65133a1

Browse files
ishaaqknakad
authored andcommitted
feature: add enable_sagemaker_metrics flag (#254)
Add the enable_sagemaker_metrics flag for Estimators. This corresponds to the new EnableSageMakerMetricsTimeSeries parameter. By default, if the flag is unset then we don't pass on the EnableSageMakerMetricsTimeSeries parameter unless you're using the following framework versions: MXNet >= 1.6 TensorFlow >=1.15 Pytorch >=1.3
1 parent ac94c28 commit 65133a1

File tree

10 files changed

+205
-0
lines changed

10 files changed

+205
-0
lines changed

src/sagemaker/estimator.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __init__(
9898
rules=None,
9999
debugger_hook_config=None,
100100
tensorboard_output_config=None,
101+
enable_sagemaker_metrics=None,
101102
):
102103
"""Initialize an ``EstimatorBase`` instance.
103104
@@ -195,6 +196,10 @@ def __init__(
195196
started. If the path is unset then SageMaker assumes the
196197
checkpoints will be provided under `/opt/ml/checkpoints/`.
197198
(default: ``None``).
199+
enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
200+
Series. For more information see:
201+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
202+
(default: ``None``).
198203
"""
199204
self.role = role
200205
self.train_instance_count = train_instance_count
@@ -250,6 +255,8 @@ def __init__(
250255
self.debugger_rule_configs = None
251256
self.collection_configs = None
252257

258+
self.enable_sagemaker_metrics = enable_sagemaker_metrics
259+
253260
@abstractmethod
254261
def train_image(self):
255262
"""Return the Docker image to use for training.
@@ -958,6 +965,9 @@ def start_new(cls, estimator, inputs):
958965

959966
cls._add_spot_checkpoint_args(local_mode, estimator, train_args)
960967

968+
if estimator.enable_sagemaker_metrics is not None:
969+
train_args["enable_sagemaker_metrics"] = estimator.enable_sagemaker_metrics
970+
961971
estimator.sagemaker_session.train(**train_args)
962972

963973
return cls(estimator.sagemaker_session, estimator._current_job_name)
@@ -1060,6 +1070,7 @@ def __init__(
10601070
rules=None,
10611071
debugger_hook_config=None,
10621072
tensorboard_output_config=None,
1073+
enable_sagemaker_metrics=None,
10631074
):
10641075
"""Initialize an ``Estimator`` instance.
10651076
@@ -1171,6 +1182,10 @@ def __init__(
11711182
user entry script for training. The user entry script, files in
11721183
source_dir (if specified), and dependencies will be uploaded in
11731184
a tar to S3. Also known as internet-free mode (default: ``False``).
1185+
enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
1186+
Series. For more information see:
1187+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
1188+
(default: ``None``).
11741189
"""
11751190
self.image_name = image_name
11761191
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
@@ -1201,6 +1216,7 @@ def __init__(
12011216
rules=rules,
12021217
debugger_hook_config=debugger_hook_config,
12031218
tensorboard_output_config=tensorboard_output_config,
1219+
enable_sagemaker_metrics=enable_sagemaker_metrics,
12041220
)
12051221

12061222
def enable_network_isolation(self):
@@ -1354,6 +1370,7 @@ def __init__(
13541370
git_config=None,
13551371
checkpoint_s3_uri=None,
13561372
checkpoint_local_path=None,
1373+
enable_sagemaker_metrics=None,
13571374
**kwargs
13581375
):
13591376
"""Base class initializer. Subclasses which override ``__init__`` should
@@ -1500,6 +1517,10 @@ def __init__(
15001517
started. If the path is unset then SageMaker assumes the
15011518
checkpoints will be provided under `/opt/ml/checkpoints/`.
15021519
(default: ``None``).
1520+
enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
1521+
Series. For more information see:
1522+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
1523+
(default: ``None``).
15031524
**kwargs: Additional kwargs passed to the ``EstimatorBase``
15041525
constructor.
15051526
"""
@@ -1530,6 +1551,7 @@ def __init__(
15301551
self._hyperparameters = hyperparameters or {}
15311552
self.checkpoint_s3_uri = checkpoint_s3_uri
15321553
self.checkpoint_local_path = checkpoint_local_path
1554+
self.enable_sagemaker_metrics = enable_sagemaker_metrics
15331555

15341556
def enable_network_isolation(self):
15351557
"""Return True if this Estimator can use network isolation to run.

src/sagemaker/mxnet/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
framework_version_from_tag,
2222
empty_framework_version_warning,
2323
python_deprecation_warning,
24+
is_version_equal_or_higher,
2425
)
2526
from sagemaker.mxnet.defaults import MXNET_VERSION
2627
from sagemaker.mxnet.model import MXNetModel
@@ -103,6 +104,11 @@ def __init__(
103104
logger.warning(empty_framework_version_warning(MXNET_VERSION, self.LATEST_VERSION))
104105
self.framework_version = framework_version or MXNET_VERSION
105106

107+
if "enable_sagemaker_metrics" not in kwargs:
108+
# enable sagemaker metrics for MXNet v1.6 or greater:
109+
if is_version_equal_or_higher([1, 6], self.framework_version):
110+
kwargs["enable_sagemaker_metrics"] = True
111+
106112
super(MXNet, self).__init__(
107113
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
108114
)

src/sagemaker/pytorch/estimator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
framework_version_from_tag,
2222
empty_framework_version_warning,
2323
python_deprecation_warning,
24+
is_version_equal_or_higher,
2425
)
2526
from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION
2627
from sagemaker.pytorch.model import PyTorchModel
@@ -98,6 +99,11 @@ def __init__(
9899
logger.warning(empty_framework_version_warning(PYTORCH_VERSION, PYTORCH_VERSION))
99100
self.framework_version = framework_version or PYTORCH_VERSION
100101

102+
if "enable_sagemaker_metrics" not in kwargs:
103+
# enable sagemaker metrics for PT v1.3 or greater:
104+
if is_version_equal_or_higher([1, 3], self.framework_version):
105+
kwargs["enable_sagemaker_metrics"] = True
106+
101107
super(PyTorch, self).__init__(
102108
entry_point, source_dir, hyperparameters, image_name=image_name, **kwargs
103109
)

src/sagemaker/session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,7 @@ def train( # noqa: C901
370370
debugger_rule_configs=None,
371371
debugger_hook_config=None,
372372
tensorboard_output_config=None,
373+
enable_sagemaker_metrics=None,
373374
):
374375
"""Create an Amazon SageMaker training job.
375376
@@ -432,6 +433,10 @@ def train( # noqa: C901
432433
started. If the path is unset then SageMaker assumes the
433434
checkpoints will be provided under `/opt/ml/checkpoints/`.
434435
(default: ``None``).
436+
enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
437+
Series. For more information see:
438+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
439+
(default: ``None``).
435440
436441
Returns:
437442
str: ARN of the training job, if it is created.
@@ -467,6 +472,11 @@ def train( # noqa: C901
467472
if metric_definitions is not None:
468473
train_request["AlgorithmSpecification"]["MetricDefinitions"] = metric_definitions
469474

475+
if enable_sagemaker_metrics is not None:
476+
train_request["AlgorithmSpecification"][
477+
"EnableSageMakerMetricsTimeSeries"
478+
] = enable_sagemaker_metrics
479+
470480
if hyperparameters and len(hyperparameters) > 0:
471481
train_request["HyperParameters"] = hyperparameters
472482

src/sagemaker/tensorflow/estimator.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,11 @@ def __init__(
286286
if not py_version:
287287
py_version = "py3" if self._only_python_3_supported() else "py2"
288288

289+
if "enable_sagemaker_metrics" not in kwargs:
290+
# enable sagemaker metrics for TF v1.15 or greater:
291+
if fw.is_version_equal_or_higher([1, 15], self.framework_version):
292+
kwargs["enable_sagemaker_metrics"] = True
293+
289294
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
290295
self.checkpoint_path = checkpoint_path
291296

tests/unit/test_estimator.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,7 @@ def test_framework_all_init_args(sagemaker_session):
205205
encrypt_inter_container_traffic=True,
206206
checkpoint_s3_uri="s3://bucket/checkpoint",
207207
checkpoint_local_path="file://local/checkpoint",
208+
enable_sagemaker_metrics=True,
208209
)
209210
_TrainingJob.start_new(f, "s3://mydata")
210211
sagemaker_session.train.assert_called_once()
@@ -241,6 +242,7 @@ def test_framework_all_init_args(sagemaker_session):
241242
"encrypt_inter_container_traffic": True,
242243
"checkpoint_s3_uri": "s3://bucket/checkpoint",
243244
"checkpoint_local_path": "file://local/checkpoint",
245+
"enable_sagemaker_metrics": True,
244246
}
245247

246248

@@ -1835,6 +1837,59 @@ def test_generic_to_fit_with_network_isolation(sagemaker_session):
18351837
assert args["enable_network_isolation"]
18361838

18371839

1840+
def test_generic_to_fit_with_sagemaker_metrics_missing(sagemaker_session):
1841+
e = Estimator(
1842+
IMAGE_NAME,
1843+
ROLE,
1844+
INSTANCE_COUNT,
1845+
INSTANCE_TYPE,
1846+
output_path=OUTPUT_PATH,
1847+
sagemaker_session=sagemaker_session,
1848+
)
1849+
1850+
e.fit()
1851+
1852+
sagemaker_session.train.assert_called_once()
1853+
args = sagemaker_session.train.call_args[1]
1854+
assert "enable_sagemaker_metrics" not in args
1855+
1856+
1857+
def test_generic_to_fit_with_sagemaker_metrics_enabled(sagemaker_session):
1858+
e = Estimator(
1859+
IMAGE_NAME,
1860+
ROLE,
1861+
INSTANCE_COUNT,
1862+
INSTANCE_TYPE,
1863+
output_path=OUTPUT_PATH,
1864+
sagemaker_session=sagemaker_session,
1865+
enable_sagemaker_metrics=True,
1866+
)
1867+
1868+
e.fit()
1869+
1870+
sagemaker_session.train.assert_called_once()
1871+
args = sagemaker_session.train.call_args[1]
1872+
assert args["enable_sagemaker_metrics"]
1873+
1874+
1875+
def test_generic_to_fit_with_sagemaker_metrics_disabled(sagemaker_session):
1876+
e = Estimator(
1877+
IMAGE_NAME,
1878+
ROLE,
1879+
INSTANCE_COUNT,
1880+
INSTANCE_TYPE,
1881+
output_path=OUTPUT_PATH,
1882+
sagemaker_session=sagemaker_session,
1883+
enable_sagemaker_metrics=False,
1884+
)
1885+
1886+
e.fit()
1887+
1888+
sagemaker_session.train.assert_called_once()
1889+
args = sagemaker_session.train.call_args[1]
1890+
assert not args["enable_sagemaker_metrics"]
1891+
1892+
18381893
def test_generic_to_deploy(sagemaker_session):
18391894
e = Estimator(
18401895
IMAGE_NAME,

tests/unit/test_mxnet.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,3 +736,53 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
736736
model = mx.create_model(image_name=custom_hosting_image)
737737

738738
assert model.image == custom_hosting_image
739+
740+
741+
def test_mx_enable_sm_metrics(sagemaker_session):
742+
mx = MXNet(
743+
entry_point=SCRIPT_PATH,
744+
role=ROLE,
745+
sagemaker_session=sagemaker_session,
746+
train_instance_count=INSTANCE_COUNT,
747+
train_instance_type=INSTANCE_TYPE,
748+
enable_sagemaker_metrics=True,
749+
)
750+
assert mx.enable_sagemaker_metrics
751+
752+
753+
def test_mx_disable_sm_metrics(sagemaker_session):
754+
mx = MXNet(
755+
entry_point=SCRIPT_PATH,
756+
role=ROLE,
757+
sagemaker_session=sagemaker_session,
758+
train_instance_count=INSTANCE_COUNT,
759+
train_instance_type=INSTANCE_TYPE,
760+
enable_sagemaker_metrics=False,
761+
)
762+
assert not mx.enable_sagemaker_metrics
763+
764+
765+
def test_mx_disable_sm_metrics_if_pt_ver_is_less_than_1_6(sagemaker_session):
766+
for fw_version in ["1.1", "1.2", "1.3", "1.4", "1.5"]:
767+
mx = MXNet(
768+
entry_point=SCRIPT_PATH,
769+
role=ROLE,
770+
sagemaker_session=sagemaker_session,
771+
train_instance_count=INSTANCE_COUNT,
772+
train_instance_type=INSTANCE_TYPE,
773+
framework_version=fw_version,
774+
)
775+
assert mx.enable_sagemaker_metrics is None
776+
777+
778+
def test_mx_enable_sm_metrics_if_fw_ver_is_at_least_1_6(sagemaker_session):
779+
for fw_version in ["1.6", "1.7", "2.0", "2.1"]:
780+
mx = MXNet(
781+
entry_point=SCRIPT_PATH,
782+
role=ROLE,
783+
sagemaker_session=sagemaker_session,
784+
train_instance_count=INSTANCE_COUNT,
785+
train_instance_type=INSTANCE_TYPE,
786+
framework_version=fw_version,
787+
)
788+
assert mx.enable_sagemaker_metrics

tests/unit/test_pytorch.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -523,3 +523,25 @@ def test_empty_framework_version(warning, sagemaker_session):
523523

524524
assert estimator.framework_version == defaults.PYTORCH_VERSION
525525
warning.assert_called_with(defaults.PYTORCH_VERSION, defaults.PYTORCH_VERSION)
526+
527+
528+
def test_pt_enable_sm_metrics(sagemaker_session):
529+
pytorch = _pytorch_estimator(sagemaker_session, enable_sagemaker_metrics=True)
530+
assert pytorch.enable_sagemaker_metrics
531+
532+
533+
def test_pt_disable_sm_metrics(sagemaker_session):
534+
pytorch = _pytorch_estimator(sagemaker_session, enable_sagemaker_metrics=False)
535+
assert not pytorch.enable_sagemaker_metrics
536+
537+
538+
def test_pt_disable_sm_metrics_if_pt_ver_is_less_than_1_15(sagemaker_session):
539+
for fw_version in ["1.1", "1.2"]:
540+
pytorch = _pytorch_estimator(sagemaker_session, framework_version=fw_version)
541+
assert pytorch.enable_sagemaker_metrics is None
542+
543+
544+
def test_pt_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session):
545+
for fw_version in ["1.3", "1.4", "2.0", "2.1"]:
546+
pytorch = _pytorch_estimator(sagemaker_session, framework_version=fw_version)
547+
assert pytorch.enable_sagemaker_metrics

tests/unit/test_session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -665,6 +665,7 @@ def test_train_pack_to_request(sagemaker_session):
665665
tags=None,
666666
vpc_config=VPC_CONFIG,
667667
metric_definitions=None,
668+
enable_sagemaker_metrics=None,
668669
)
669670

670671
assert sagemaker_session.sagemaker_client.method_calls[0] == (
@@ -1172,6 +1173,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
11721173
train_use_spot_instances=True,
11731174
checkpoint_s3_uri="s3://mybucket/checkpoints/",
11741175
checkpoint_local_path="/tmp/checkpoints",
1176+
enable_sagemaker_metrics=True,
11751177
)
11761178

11771179
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
@@ -1180,6 +1182,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
11801182
assert actual_train_args["HyperParameters"] == hyperparameters
11811183
assert actual_train_args["Tags"] == TAGS
11821184
assert actual_train_args["AlgorithmSpecification"]["MetricDefinitions"] == METRIC_DEFINITONS
1185+
assert actual_train_args["AlgorithmSpecification"]["EnableSageMakerMetricsTimeSeries"] is True
11831186
assert actual_train_args["EnableInterContainerTrafficEncryption"] is True
11841187
assert actual_train_args["EnableManagedSpotTraining"] is True
11851188
assert actual_train_args["CheckpointConfig"]["S3Uri"] == "s3://mybucket/checkpoints/"

tests/unit/test_tf_estimator.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1200,3 +1200,29 @@ def test_tf_script_mode_attach(sagemaker_session, tf_version):
12001200
assert estimator.hyperparameters() is not None
12011201
assert estimator.source_dir == "s3://some/sourcedir.tar.gz"
12021202
assert estimator.entry_point == "iris-dnn-classifier.py"
1203+
1204+
1205+
@patch("sagemaker.utils.create_tar_file", MagicMock())
1206+
def test_tf_enable_sm_metrics(sagemaker_session):
1207+
tf = _build_tf(sagemaker_session, enable_sagemaker_metrics=True)
1208+
assert tf.enable_sagemaker_metrics
1209+
1210+
1211+
@patch("sagemaker.utils.create_tar_file", MagicMock())
1212+
def test_tf_disable_sm_metrics(sagemaker_session):
1213+
tf = _build_tf(sagemaker_session, enable_sagemaker_metrics=False)
1214+
assert not tf.enable_sagemaker_metrics
1215+
1216+
1217+
@patch("sagemaker.utils.create_tar_file", MagicMock())
1218+
def test_tf_disable_sm_metrics_if_fw_ver_is_less_than_1_15(sagemaker_session):
1219+
for fw_version in ["1.11", "1.12", "1.13", "1.14"]:
1220+
tf = _build_tf(sagemaker_session, framework_version=fw_version)
1221+
assert tf.enable_sagemaker_metrics is None
1222+
1223+
1224+
@patch("sagemaker.utils.create_tar_file", MagicMock())
1225+
def test_tf_enable_sm_metrics_if_fw_ver_is_at_least_1_15(sagemaker_session):
1226+
for fw_version in ["1.15", "1.16", "2.0", "2.1"]:
1227+
tf = _build_tf(sagemaker_session, framework_version=fw_version)
1228+
assert tf.enable_sagemaker_metrics

0 commit comments

Comments
 (0)