Skip to content

Commit 9c0a4c3

Browse files
author
Chia-Eng
committed
Address PR comments
1 parent 8d06bcd commit 9c0a4c3

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,7 @@ def __init__(
264264
``disable_profiler`` parameter to ``True``.
265265
disable_profiler (bool): Specifies whether Debugger monitoring and profiling
266266
will be disabled (default: ``False``).
267-
environment (dict[str, str]) : A string to string map contains environment
268-
variables to set in the Docker container.
267+
environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``)
269268
270269
"""
271270
instance_count = renamed_kwargs(
@@ -1808,8 +1807,7 @@ def __init__(
18081807
``disable_profiler`` parameter to ``True``.
18091808
disable_profiler (bool): Specifies whether Debugger monitoring and profiling
18101809
will be disabled (default: ``False``).
1811-
environment (dict[str, str]) : A string to string map contains environment
1812-
variables to set in the Docker container.
1810+
environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``)
18131811
"""
18141812
self.image_uri = image_uri
18151813
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}

src/sagemaker/session.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,7 @@ def train( # noqa: C901
526526
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
527527
profiler_config (dict): Configuration for how profiling information is emitted
528528
with SageMaker Profiler. (default: ``None``).
529-
environment (dict[str, str]) : A string to string map contains environment
530-
variables to set in the Docker container.
529+
environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``)
531530
532531
Returns:
533532
str: ARN of the training job, if it is created.
@@ -662,8 +661,7 @@ def _get_train_request( # noqa: C901
662661
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
663662
profiler_config(dict): Configuration for how profiling information is emitted with
664663
SageMaker Profiler. (default: ``None``).
665-
environment (dict[str, str]) : A string to string map contains environment
666-
variables to set in the Docker container.
664+
environment (dict[str, str]) : Environment variables to be set for use during training job (default: ``None``)
667665
668666
Returns:
669667
Dict: a training request dict

tests/integ/test_tf.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
3737
MPI_DISTRIBUTION = {"mpi": {"enabled": True}}
3838
TAGS = [{"Key": "some-key", "Value": "some-value"}]
39+
ENV_INPUT= {'env_key1': 'env_val1', 'env_key2': 'env_val2', 'env_key3': 'env_val3'}
3940

4041

4142
def test_mnist_with_checkpoint_config(
@@ -59,6 +60,7 @@ def test_mnist_with_checkpoint_config(
5960
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
6061
checkpoint_s3_uri=checkpoint_s3_uri,
6162
checkpoint_local_path=checkpoint_local_path,
63+
environment=ENV_INPUT,
6264
)
6365
inputs = estimator.sagemaker_session.upload_data(
6466
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
@@ -83,6 +85,7 @@ def test_mnist_with_checkpoint_config(
8385
TrainingJobName=training_job_name
8486
)["CheckpointConfig"]
8587
assert actual_training_checkpoint_config == expected_training_checkpoint_config
88+
assert actual_training_checkpoint_config['Environment'] == ENV_INPUT
8689

8790

8891
def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_version):

0 commit comments

Comments
 (0)