Skip to content

Commit f59d0c2

Browse files
chiaengc, Chia-Eng, and icywang86rui
authored and committed
feature: Add environment variable support for SageMaker training job (aws#2218)
Co-authored-by: Chia-Eng <[email protected]> Co-authored-by: icywang86rui <[email protected]>
1 parent 11ee6c4 commit f59d0c2

File tree

14 files changed

+147
-1
lines changed

14 files changed

+147
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def __init__(
123123
enable_network_isolation=False,
124124
profiler_config=None,
125125
disable_profiler=False,
126+
environment=None,
126127
**kwargs,
127128
):
128129
"""Initialize an ``EstimatorBase`` instance.
@@ -266,6 +267,8 @@ def __init__(
266267
``disable_profiler`` parameter to ``True``.
267268
disable_profiler (bool): Specifies whether Debugger monitoring and profiling
268269
will be disabled (default: ``False``).
270+
environment (dict[str, str]): Environment variables to be set for
271+
use during training job (default: ``None``)
269272
270273
"""
271274
instance_count = renamed_kwargs(
@@ -352,6 +355,8 @@ def __init__(
352355
self.profiler_config = profiler_config
353356
self.disable_profiler = disable_profiler
354357

358+
self.environment = environment
359+
355360
if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
356361
self.disable_profiler = True
357362

@@ -1471,6 +1476,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
14711476
train_args["tags"] = estimator.tags
14721477
train_args["metric_definitions"] = estimator.metric_definitions
14731478
train_args["experiment_config"] = experiment_config
1479+
train_args["environment"] = estimator.environment
14741480

14751481
if isinstance(inputs, TrainingInput):
14761482
if "InputMode" in inputs.config:
@@ -1659,6 +1665,7 @@ def __init__(
16591665
enable_sagemaker_metrics=None,
16601666
profiler_config=None,
16611667
disable_profiler=False,
1668+
environment=None,
16621669
**kwargs,
16631670
):
16641671
"""Initialize an ``Estimator`` instance.
@@ -1807,6 +1814,8 @@ def __init__(
18071814
``disable_profiler`` parameter to ``True``.
18081815
disable_profiler (bool): Specifies whether Debugger monitoring and profiling
18091816
will be disabled (default: ``False``).
1817+
environment (dict[str, str]): Environment variables to be set for
1818+
use during training job (default: ``None``)
18101819
"""
18111820
self.image_uri = image_uri
18121821
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
@@ -1840,6 +1849,7 @@ def __init__(
18401849
enable_network_isolation=enable_network_isolation,
18411850
profiler_config=profiler_config,
18421851
disable_profiler=disable_profiler,
1852+
environment=environment,
18431853
**kwargs,
18441854
)
18451855

src/sagemaker/session.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ def train( # noqa: C901
456456
enable_sagemaker_metrics=None,
457457
profiler_rule_configs=None,
458458
profiler_config=None,
459+
environment=None,
459460
):
460461
"""Create an Amazon SageMaker training job.
461462
@@ -522,9 +523,12 @@ def train( # noqa: C901
522523
Series. For more information see:
523524
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
524525
(default: ``None``).
525-
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
526+
profiler_rule_configs (list[dict]): A list of profiler rule
527+
configurations.
526528
profiler_config (dict): Configuration for how profiling information is emitted
527529
with SageMaker Profiler. (default: ``None``).
530+
environment (dict[str, str]): Environment variables to be set for
531+
use during training job (default: ``None``)
528532
529533
Returns:
530534
str: ARN of the training job, if it is created.
@@ -556,6 +560,7 @@ def train( # noqa: C901
556560
enable_sagemaker_metrics=enable_sagemaker_metrics,
557561
profiler_rule_configs=profiler_rule_configs,
558562
profiler_config=profiler_config,
563+
environment=environment,
559564
)
560565
LOGGER.info("Creating training-job with name: %s", job_name)
561566
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
@@ -588,6 +593,7 @@ def _get_train_request( # noqa: C901
588593
enable_sagemaker_metrics=None,
589594
profiler_rule_configs=None,
590595
profiler_config=None,
596+
environment=None,
591597
):
592598
"""Constructs a request compatible for creating an Amazon SageMaker training job.
593599
@@ -657,6 +663,8 @@ def _get_train_request( # noqa: C901
657663
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
658664
profiler_config (dict): Configuration for how profiling information is emitted with
659665
SageMaker Profiler. (default: ``None``).
666+
environment (dict[str, str]): Environment variables to be set for
667+
use during training job (default: ``None``)
660668
661669
Returns:
662670
Dict: a training request dict
@@ -699,6 +707,9 @@ def _get_train_request( # noqa: C901
699707
if hyperparameters and len(hyperparameters) > 0:
700708
train_request["HyperParameters"] = hyperparameters
701709

710+
if environment is not None:
711+
train_request["Environment"] = environment
712+
702713
if tags is not None:
703714
train_request["Tags"] = tags
704715

tests/integ/test_tf.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
3737
MPI_DISTRIBUTION = {"mpi": {"enabled": True}}
3838
TAGS = [{"Key": "some-key", "Value": "some-value"}]
39+
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
3940

4041

4142
def test_mnist_with_checkpoint_config(
@@ -59,6 +60,7 @@ def test_mnist_with_checkpoint_config(
5960
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
6061
checkpoint_s3_uri=checkpoint_s3_uri,
6162
checkpoint_local_path=checkpoint_local_path,
63+
environment=ENV_INPUT,
6264
)
6365
inputs = estimator.sagemaker_session.upload_data(
6466
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
@@ -82,7 +84,13 @@ def test_mnist_with_checkpoint_config(
8284
actual_training_checkpoint_config = sagemaker_session.sagemaker_client.describe_training_job(
8385
TrainingJobName=training_job_name
8486
)["CheckpointConfig"]
87+
actual_training_environment_variable_config = (
88+
sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)[
89+
"Environment"
90+
]
91+
)
8592
assert actual_training_checkpoint_config == expected_training_checkpoint_config
93+
assert actual_training_environment_variable_config == ENV_INPUT
8694

8795

8896
def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_version):

tests/unit/sagemaker/huggingface/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,7 @@ def _create_train_job(version, base_framework_version):
149149
"tags": None,
150150
"vpc_config": None,
151151
"metric_definitions": None,
152+
"environment": None,
152153
"experiment_config": None,
153154
"debugger_hook_config": {
154155
"CollectionConfigurations": [],

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd
130130
"tags": None,
131131
"vpc_config": None,
132132
"metric_definitions": None,
133+
"environment": None,
133134
"experiment_config": None,
134135
"profiler_rule_configs": [
135136
{

tests/unit/sagemaker/tensorflow/test_estimator_init.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020

2121
REGION = "us-west-2"
2222

23+
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
24+
2325

2426
@pytest.fixture()
2527
def sagemaker_session():
@@ -68,6 +70,26 @@ def test_framework_name(sagemaker_session):
6870
assert tf._framework_name == "tensorflow"
6971

7072

73+
def test_tf_add_environment_variables(sagemaker_session):
74+
tf = _build_tf(
75+
sagemaker_session,
76+
framework_version="1.15.2",
77+
py_version="py3",
78+
environment=ENV_INPUT,
79+
)
80+
assert tf.environment == ENV_INPUT
81+
82+
83+
def test_tf_miss_environment_variables(sagemaker_session):
84+
tf = _build_tf(
85+
sagemaker_session,
86+
framework_version="1.15.2",
87+
py_version="py3",
88+
environment=None,
89+
)
90+
assert not tf.environment
91+
92+
7193
def test_enable_sm_metrics(sagemaker_session):
7294
tf = _build_tf(
7395
sagemaker_session,

tests/unit/test_chainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _create_train_job(version, py_version):
143143
"tags": None,
144144
"vpc_config": None,
145145
"metric_definitions": None,
146+
"environment": None,
146147
"experiment_config": None,
147148
"debugger_hook_config": {
148149
"CollectionConfigurations": [],

tests/unit/test_estimator.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
CODECOMMIT_REPO_SSH = "ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/"
7272
CODECOMMIT_BRANCH = "master"
7373
REPO_DIR = "/tmp/repo_dir"
74+
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
7475

7576
DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}}
7677

@@ -241,6 +242,7 @@ def test_framework_all_init_args(sagemaker_session):
241242
checkpoint_local_path="file://local/checkpoint",
242243
enable_sagemaker_metrics=True,
243244
enable_network_isolation=True,
245+
environment=ENV_INPUT,
244246
)
245247
_TrainingJob.start_new(f, "s3://mydata", None)
246248
sagemaker_session.train.assert_called_once()
@@ -275,6 +277,7 @@ def test_framework_all_init_args(sagemaker_session):
275277
},
276278
"metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
277279
"encrypt_inter_container_traffic": True,
280+
"environment": {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"},
278281
"experiment_config": None,
279282
"checkpoint_s3_uri": "s3://bucket/checkpoint",
280283
"checkpoint_local_path": "file://local/checkpoint",
@@ -1085,6 +1088,7 @@ def test_framework_with_spot_and_checkpoints(sagemaker_session):
10851088
"use_spot_instances": True,
10861089
"checkpoint_s3_uri": "s3://mybucket/checkpoints/",
10871090
"checkpoint_local_path": "/tmp/checkpoints",
1091+
"environment": None,
10881092
"experiment_config": None,
10891093
}
10901094

@@ -2389,6 +2393,7 @@ def test_unsupported_type_in_dict():
23892393
"tags": None,
23902394
"vpc_config": None,
23912395
"metric_definitions": None,
2396+
"environment": None,
23922397
"experiment_config": None,
23932398
}
23942399

@@ -2678,6 +2683,24 @@ def test_generic_to_fit_with_sagemaker_metrics_missing(sagemaker_session):
26782683
assert "enable_sagemaker_metrics" not in args
26792684

26802685

2686+
def test_add_environment_variables_to_train_args(sagemaker_session):
2687+
e = Estimator(
2688+
IMAGE_URI,
2689+
ROLE,
2690+
INSTANCE_COUNT,
2691+
INSTANCE_TYPE,
2692+
output_path=OUTPUT_PATH,
2693+
sagemaker_session=sagemaker_session,
2694+
environment=ENV_INPUT,
2695+
)
2696+
2697+
e.fit()
2698+
2699+
sagemaker_session.train.assert_called_once()
2700+
args = sagemaker_session.train.call_args[1]
2701+
assert args["environment"] == ENV_INPUT
2702+
2703+
26812704
def test_generic_to_fit_with_sagemaker_metrics_enabled(sagemaker_session):
26822705
e = Estimator(
26832706
IMAGE_URI,

tests/unit/test_mxnet.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565

6666
MODEL_PKG_RESPONSE = {"ModelPackageArn": "arn:model-pkg-arn"}
6767

68+
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
69+
6870

6971
@pytest.fixture()
7072
def sagemaker_session():
@@ -144,6 +146,7 @@ def _get_train_args(job_name):
144146
"tags": None,
145147
"vpc_config": None,
146148
"metric_definitions": None,
149+
"environment": None,
147150
"experiment_config": None,
148151
"debugger_hook_config": {
149152
"CollectionConfigurations": [],
@@ -959,6 +962,38 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
959962
assert model.image_uri == custom_hosting_image
960963

961964

965+
def test_mx_add_environment_variables(
966+
sagemaker_session, mxnet_training_version, mxnet_training_py_version
967+
):
968+
mx = MXNet(
969+
entry_point=SCRIPT_PATH,
970+
framework_version=mxnet_training_version,
971+
py_version=mxnet_training_py_version,
972+
role=ROLE,
973+
sagemaker_session=sagemaker_session,
974+
instance_count=INSTANCE_COUNT,
975+
instance_type=INSTANCE_TYPE,
976+
environment=ENV_INPUT,
977+
)
978+
assert mx.environment == ENV_INPUT
979+
980+
981+
def test_mx_missing_environment_variables(
982+
sagemaker_session, mxnet_training_version, mxnet_training_py_version
983+
):
984+
mx = MXNet(
985+
entry_point=SCRIPT_PATH,
986+
framework_version=mxnet_training_version,
987+
py_version=mxnet_training_py_version,
988+
role=ROLE,
989+
sagemaker_session=sagemaker_session,
990+
instance_count=INSTANCE_COUNT,
991+
instance_type=INSTANCE_TYPE,
992+
environment=None,
993+
)
994+
assert not mx.environment
995+
996+
962997
def test_mx_enable_sm_metrics(sagemaker_session, mxnet_training_version, mxnet_training_py_version):
963998
mx = MXNet(
964999
entry_point=SCRIPT_PATH,

tests/unit/test_pytorch.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646

4747
ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}
4848

49+
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
50+
4951
LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}
5052

5153
EXPERIMENT_CONFIG = {
@@ -146,6 +148,7 @@ def _create_train_job(version, py_version):
146148
"tags": None,
147149
"vpc_config": None,
148150
"metric_definitions": None,
151+
"environment": None,
149152
"experiment_config": None,
150153
"debugger_hook_config": {
151154
"CollectionConfigurations": [],
@@ -637,6 +640,30 @@ def test_pt_disable_sm_metrics(
637640
assert not pytorch.enable_sagemaker_metrics
638641

639642

643+
def test_pt_add_environment_variables(
644+
sagemaker_session, pytorch_training_version, pytorch_training_py_version
645+
):
646+
pytorch = _pytorch_estimator(
647+
sagemaker_session,
648+
framework_version=pytorch_training_version,
649+
py_version=pytorch_training_py_version,
650+
environment=ENV_INPUT,
651+
)
652+
assert pytorch.environment
653+
654+
655+
def test_pt_miss_environment_variables(
656+
sagemaker_session, pytorch_training_version, pytorch_training_py_version
657+
):
658+
pytorch = _pytorch_estimator(
659+
sagemaker_session,
660+
framework_version=pytorch_training_version,
661+
py_version=pytorch_training_py_version,
662+
environment=None,
663+
)
664+
assert not pytorch.environment
665+
666+
640667
def test_pt_default_sm_metrics(
641668
sagemaker_session, pytorch_training_version, pytorch_training_py_version
642669
):

tests/unit/test_rl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def _create_train_job(toolkit, toolkit_version, framework):
146146
{"Name": "reward-training", "Regex": "^Training>.*Total reward=(.*?),"},
147147
{"Name": "reward-testing", "Regex": "^Testing>.*Total reward=(.*?),"},
148148
],
149+
"environment": None,
149150
"experiment_config": None,
150151
"debugger_hook_config": {
151152
"CollectionConfigurations": [],

0 commit comments

Comments
 (0)