Skip to content

Commit 39752f3

Browse files
author
Chia-Eng
committed
feature: Add environment variable support for SageMaker training job
1 parent 15c5c6a commit 39752f3

File tree

12 files changed

+136
-0
lines changed

12 files changed

+136
-0
lines changed

src/sagemaker/estimator.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def __init__(
120120
enable_network_isolation=False,
121121
profiler_config=None,
122122
disable_profiler=False,
123+
environment = None,
123124
**kwargs,
124125
):
125126
"""Initialize an ``EstimatorBase`` instance.
@@ -263,6 +264,8 @@ def __init__(
263264
``disable_profiler`` parameter to ``True``.
264265
disable_profiler (bool): Specifies whether Debugger monitoring and profiling
265266
will be disabled (default: ``False``).
267+
environment (dict[str, str]): A string-to-string map of environment
268+
variables to set in the Docker container (default: ``None``).
266269
267270
"""
268271
instance_count = renamed_kwargs(
@@ -349,6 +352,8 @@ def __init__(
349352
self.profiler_config = profiler_config
350353
self.disable_profiler = disable_profiler
351354

355+
self.environment = environment
356+
352357
if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
353358
self.disable_profiler = True
354359

@@ -1465,6 +1470,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
14651470
train_args["tags"] = estimator.tags
14661471
train_args["metric_definitions"] = estimator.metric_definitions
14671472
train_args["experiment_config"] = experiment_config
1473+
train_args["environment"] = estimator.environment
14681474

14691475
if isinstance(inputs, TrainingInput):
14701476
if "InputMode" in inputs.config:
@@ -1653,6 +1659,7 @@ def __init__(
16531659
enable_sagemaker_metrics=None,
16541660
profiler_config=None,
16551661
disable_profiler=False,
1662+
environment = None,
16561663
**kwargs,
16571664
):
16581665
"""Initialize an ``Estimator`` instance.
@@ -1834,6 +1841,7 @@ def __init__(
18341841
enable_network_isolation=enable_network_isolation,
18351842
profiler_config=profiler_config,
18361843
disable_profiler=disable_profiler,
1844+
environment=environment,
18371845
**kwargs,
18381846
)
18391847

src/sagemaker/session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ def train( # noqa: C901
456456
enable_sagemaker_metrics=None,
457457
profiler_rule_configs=None,
458458
profiler_config=None,
459+
environment=None
459460
):
460461
"""Create an Amazon SageMaker training job.
461462
@@ -525,6 +526,8 @@ def train( # noqa: C901
525526
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
526527
profiler_config (dict): Configuration for how profiling information is emitted
527528
with SageMaker Profiler. (default: ``None``).
529+
environment (dict[str, str]): A string-to-string map of environment
530+
variables to set in the Docker container (default: ``None``).
528531
529532
Returns:
530533
str: ARN of the training job, if it is created.
@@ -556,6 +559,7 @@ def train( # noqa: C901
556559
enable_sagemaker_metrics=enable_sagemaker_metrics,
557560
profiler_rule_configs=profiler_rule_configs,
558561
profiler_config=profiler_config,
562+
environment=environment
559563
)
560564
LOGGER.info("Creating training-job with name: %s", job_name)
561565
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
@@ -588,6 +592,7 @@ def _get_train_request( # noqa: C901
588592
enable_sagemaker_metrics=None,
589593
profiler_rule_configs=None,
590594
profiler_config=None,
595+
environment=None,
591596
):
592597
"""Constructs a request compatible for creating an Amazon SageMaker training job.
593598
@@ -657,6 +662,8 @@ def _get_train_request( # noqa: C901
657662
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
658663
profiler_config(dict): Configuration for how profiling information is emitted with
659664
SageMaker Profiler. (default: ``None``).
665+
environment (dict[str, str]): A string-to-string map of environment
666+
variables to set in the Docker container (default: ``None``).
660667
661668
Returns:
662669
Dict: a training request dict
@@ -699,6 +706,9 @@ def _get_train_request( # noqa: C901
699706
if hyperparameters and len(hyperparameters) > 0:
700707
train_request["HyperParameters"] = hyperparameters
701708

709+
if environment is not None:
710+
train_request["Environment"] = environment
711+
702712
if tags is not None:
703713
train_request["Tags"] = tags
704714

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd
130130
"tags": None,
131131
"vpc_config": None,
132132
"metric_definitions": None,
133+
'environment': None,
133134
"experiment_config": None,
134135
"profiler_rule_configs": [
135136
{

tests/unit/sagemaker/tensorflow/test_estimator_init.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
REGION = "us-west-2"
2222

23+
# Sample environment-variable map used as estimator input in these tests.
ENV_INPUT = {
    "env_key1": "env_val1",
    "env_key2": "env_val2",
    "env_key3": "env_val3",
}
2324

2425
@pytest.fixture()
2526
def sagemaker_session():
@@ -67,6 +68,23 @@ def test_framework_name(sagemaker_session):
6768
tf = _build_tf(sagemaker_session, framework_version="1.15.2", py_version="py3")
6869
assert tf._framework_name == "tensorflow"
6970

71+
def test_tf_add_environment_variables(sagemaker_session):
    """An estimator built with ``environment=ENV_INPUT`` exposes that same mapping."""
    build_kwargs = {
        "framework_version": "1.15.2",
        "py_version": "py3",
        "environment": ENV_INPUT,
    }
    estimator = _build_tf(sagemaker_session, **build_kwargs)
    assert estimator.environment == ENV_INPUT
79+
80+
def test_tf_miss_environment_variables(sagemaker_session):
    """An estimator built with ``environment=None`` has a falsy ``environment``."""
    build_kwargs = {
        "framework_version": "1.15.2",
        "py_version": "py3",
        "environment": None,
    }
    estimator = _build_tf(sagemaker_session, **build_kwargs)
    assert not estimator.environment
7088

7189
def test_enable_sm_metrics(sagemaker_session):
7290
tf = _build_tf(

tests/unit/test_chainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _create_train_job(version, py_version):
143143
"tags": None,
144144
"vpc_config": None,
145145
"metric_definitions": None,
146+
'environment': None,
146147
"experiment_config": None,
147148
"debugger_hook_config": {
148149
"CollectionConfigurations": [],

tests/unit/test_estimator.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
CODECOMMIT_REPO_SSH = "ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/"
7272
CODECOMMIT_BRANCH = "master"
7373
REPO_DIR = "/tmp/repo_dir"
74+
# Sample environment-variable map used as estimator input in these tests.
ENV_INPUT = {
    "env_key1": "env_val1",
    "env_key2": "env_val2",
    "env_key3": "env_val3",
}
7475

7576
DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}}
7677

@@ -241,6 +242,7 @@ def test_framework_all_init_args(sagemaker_session):
241242
checkpoint_local_path="file://local/checkpoint",
242243
enable_sagemaker_metrics=True,
243244
enable_network_isolation=True,
245+
environment=ENV_INPUT,
244246
)
245247
_TrainingJob.start_new(f, "s3://mydata", None)
246248
sagemaker_session.train.assert_called_once()
@@ -275,6 +277,7 @@ def test_framework_all_init_args(sagemaker_session):
275277
},
276278
"metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
277279
"encrypt_inter_container_traffic": True,
280+
"environment": {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"},
278281
"experiment_config": None,
279282
"checkpoint_s3_uri": "s3://bucket/checkpoint",
280283
"checkpoint_local_path": "file://local/checkpoint",
@@ -1085,6 +1088,7 @@ def test_framework_with_spot_and_checkpoints(sagemaker_session):
10851088
"use_spot_instances": True,
10861089
"checkpoint_s3_uri": "s3://mybucket/checkpoints/",
10871090
"checkpoint_local_path": "/tmp/checkpoints",
1091+
"environment": None,
10881092
"experiment_config": None,
10891093
}
10901094

@@ -2389,6 +2393,7 @@ def test_unsupported_type_in_dict():
23892393
"tags": None,
23902394
"vpc_config": None,
23912395
"metric_definitions": None,
2396+
"environment": None,
23922397
"experiment_config": None,
23932398
}
23942399

@@ -2677,6 +2682,39 @@ def test_generic_to_fit_with_sagemaker_metrics_missing(sagemaker_session):
26772682
args = sagemaker_session.train.call_args[1]
26782683
assert "enable_sagemaker_metrics" not in args
26792684

2685+
def test_add_environment_variables_to_train_args(sagemaker_session):
    """Check that ``fit`` passes the estimator's ``environment`` map to ``train``.

    The estimator is created with ``environment=ENV_INPUT`` and the keyword
    arguments of the single ``sagemaker_session.train`` call are inspected.
    """
    e = Estimator(
        IMAGE_URI,
        ROLE,
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        output_path=OUTPUT_PATH,
        sagemaker_session=sagemaker_session,
        environment=ENV_INPUT,
    )

    e.fit()

    sagemaker_session.train.assert_called_once()
    args = sagemaker_session.train.call_args[1]
    # Compare against the exact input: a bare truthiness check would also pass
    # if a different or partially dropped mapping were forwarded.
    assert args["environment"] == ENV_INPUT
2701+
2702+
def test_no_environment_variables_in_train_args(sagemaker_session):
    """With ``environment=None`` the ``train`` call carries a falsy ``environment``."""
    estimator = Estimator(
        IMAGE_URI,
        ROLE,
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        output_path=OUTPUT_PATH,
        sagemaker_session=sagemaker_session,
        environment=None,
    )
    estimator.fit()

    train_mock = sagemaker_session.train
    train_mock.assert_called_once()
    # call_args[1] holds the keyword arguments of the recorded call.
    train_kwargs = train_mock.call_args[1]
    assert not train_kwargs["environment"]
26802718

26812719
def test_generic_to_fit_with_sagemaker_metrics_enabled(sagemaker_session):
26822720
e = Estimator(

tests/unit/test_mxnet.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,8 @@
6565

6666
MODEL_PKG_RESPONSE = {"ModelPackageArn": "arn:model-pkg-arn"}
6767

68+
# Sample environment-variable map used as estimator input in these tests.
ENV_INPUT = {
    "env_key1": "env_val1",
    "env_key2": "env_val2",
    "env_key3": "env_val3",
}
69+
6870

6971
@pytest.fixture()
7072
def sagemaker_session():
@@ -144,6 +146,7 @@ def _get_train_args(job_name):
144146
"tags": None,
145147
"vpc_config": None,
146148
"metric_definitions": None,
149+
"environment": None,
147150
"experiment_config": None,
148151
"debugger_hook_config": {
149152
"CollectionConfigurations": [],
@@ -958,6 +961,31 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
958961

959962
assert model.image_uri == custom_hosting_image
960963

964+
def test_mx_add_environment_variables(
    sagemaker_session, mxnet_training_version, mxnet_training_py_version
):
    """An MXNet estimator built with ``environment=ENV_INPUT`` exposes that mapping."""
    estimator_kwargs = dict(
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_training_version,
        py_version=mxnet_training_py_version,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        environment=ENV_INPUT,
    )
    estimator = MXNet(**estimator_kwargs)
    assert estimator.environment == ENV_INPUT
976+
977+
def test_mx_missing_environment_variables(
    sagemaker_session, mxnet_training_version, mxnet_training_py_version
):
    """An MXNet estimator built with ``environment=None`` has a falsy ``environment``."""
    estimator_kwargs = dict(
        entry_point=SCRIPT_PATH,
        framework_version=mxnet_training_version,
        py_version=mxnet_training_py_version,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        environment=None,
    )
    estimator = MXNet(**estimator_kwargs)
    assert not estimator.environment
961989

962990
def test_mx_enable_sm_metrics(sagemaker_session, mxnet_training_version, mxnet_training_py_version):
963991
mx = MXNet(

tests/unit/test_pytorch.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646

4747
ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}
4848

49+
# Sample environment-variable map used as estimator input in these tests.
ENV_INPUT = {
    "env_key1": "env_val1",
    "env_key2": "env_val2",
    "env_key3": "env_val3",
}
50+
4951
LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}
5052

5153
EXPERIMENT_CONFIG = {
@@ -146,6 +148,7 @@ def _create_train_job(version, py_version):
146148
"tags": None,
147149
"vpc_config": None,
148150
"metric_definitions": None,
151+
"environment": None,
149152
"experiment_config": None,
150153
"debugger_hook_config": {
151154
"CollectionConfigurations": [],
@@ -636,6 +639,28 @@ def test_pt_disable_sm_metrics(
636639
)
637640
assert not pytorch.enable_sagemaker_metrics
638641

642+
def test_pt_add_environment_variables(
    sagemaker_session, pytorch_training_version, pytorch_training_py_version
):
    """Check that a PyTorch estimator keeps the ``environment`` map it was given."""
    pytorch = _pytorch_estimator(
        sagemaker_session,
        framework_version=pytorch_training_version,
        py_version=pytorch_training_py_version,
        environment=ENV_INPUT,
    )
    # Assert the exact mapping, matching the TensorFlow and MXNet counterparts;
    # a truthiness check alone would accept any non-empty value.
    assert pytorch.environment == ENV_INPUT
652+
653+
def test_pt_miss_environment_variables(
    sagemaker_session, pytorch_training_version, pytorch_training_py_version
):
    """A PyTorch estimator built with ``environment=None`` has a falsy ``environment``."""
    estimator_kwargs = {
        "framework_version": pytorch_training_version,
        "py_version": pytorch_training_py_version,
        "environment": None,
    }
    estimator = _pytorch_estimator(sagemaker_session, **estimator_kwargs)
    assert not estimator.environment
663+
639664

640665
def test_pt_default_sm_metrics(
641666
sagemaker_session, pytorch_training_version, pytorch_training_py_version

tests/unit/test_rl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,7 @@ def _create_train_job(toolkit, toolkit_version, framework):
146146
{"Name": "reward-training", "Regex": "^Training>.*Total reward=(.*?),"},
147147
{"Name": "reward-testing", "Regex": "^Testing>.*Total reward=(.*?),"},
148148
],
149+
"environment": None,
149150
"experiment_config": None,
150151
"debugger_hook_config": {
151152
"CollectionConfigurations": [],

tests/unit/test_session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636

3737
SAMPLE_PARAM_RANGES = [{"Name": "mini_batch_size", "MinValue": "10", "MaxValue": "100"}]
3838

39+
# Sample environment-variable map used as input in these tests.
ENV_INPUT = {
    "env_key1": "env_val1",
    "env_key2": "env_val2",
    "env_key3": "env_val3",
}
40+
3941
REGION = "us-west-2"
4042
STS_ENDPOINT = "sts.us-west-2.amazonaws.com"
4143

@@ -1226,6 +1228,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
12261228
checkpoint_s3_uri="s3://mybucket/checkpoints/",
12271229
checkpoint_local_path="/tmp/checkpoints",
12281230
enable_sagemaker_metrics=True,
1231+
environment=ENV_INPUT,
12291232
)
12301233

12311234
_, _, actual_train_args = sagemaker_session.sagemaker_client.method_calls[0]
@@ -1239,6 +1242,7 @@ def test_train_pack_to_request_with_optional_params(sagemaker_session):
12391242
assert actual_train_args["EnableManagedSpotTraining"] is True
12401243
assert actual_train_args["CheckpointConfig"]["S3Uri"] == "s3://mybucket/checkpoints/"
12411244
assert actual_train_args["CheckpointConfig"]["LocalPath"] == "/tmp/checkpoints"
1245+
assert actual_train_args["Environment"] == ENV_INPUT
12421246

12431247

12441248
def test_transform_pack_to_request(sagemaker_session):

tests/unit/test_sklearn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ def _create_train_job(version):
132132
"metric_definitions": None,
133133
"tags": None,
134134
"vpc_config": None,
135+
"environment": None,
135136
"experiment_config": None,
136137
"debugger_hook_config": {
137138
"CollectionConfigurations": [],

tests/unit/test_xgboost.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def _create_train_job(version, instance_count=1, instance_type="ml.c4.4xlarge"):
145145
"metric_definitions": None,
146146
"tags": None,
147147
"vpc_config": None,
148+
"environment": None,
148149
"experiment_config": None,
149150
"debugger_hook_config": {
150151
"CollectionConfigurations": [],

0 commit comments

Comments
 (0)