Skip to content

feature: Add environment variable support for SageMaker training job #2218

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
enable_network_isolation=False,
profiler_config=None,
disable_profiler=False,
environment=None,
**kwargs,
):
"""Initialize an ``EstimatorBase`` instance.
Expand Down Expand Up @@ -266,6 +267,8 @@ def __init__(
``disable_profiler`` parameter to ``True``.
disable_profiler (bool): Specifies whether Debugger monitoring and profiling
will be disabled (default: ``False``).
environment (dict[str, str]): Environment variables to be set for
use during training job (default: ``None``)

"""
instance_count = renamed_kwargs(
Expand Down Expand Up @@ -352,6 +355,8 @@ def __init__(
self.profiler_config = profiler_config
self.disable_profiler = disable_profiler

self.environment = environment

if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
self.disable_profiler = True

Expand Down Expand Up @@ -1471,6 +1476,7 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
train_args["tags"] = estimator.tags
train_args["metric_definitions"] = estimator.metric_definitions
train_args["experiment_config"] = experiment_config
train_args["environment"] = estimator.environment

if isinstance(inputs, TrainingInput):
if "InputMode" in inputs.config:
Expand Down Expand Up @@ -1659,6 +1665,7 @@ def __init__(
enable_sagemaker_metrics=None,
profiler_config=None,
disable_profiler=False,
environment=None,
**kwargs,
):
"""Initialize an ``Estimator`` instance.
Expand Down Expand Up @@ -1807,6 +1814,8 @@ def __init__(
``disable_profiler`` parameter to ``True``.
disable_profiler (bool): Specifies whether Debugger monitoring and profiling
will be disabled (default: ``False``).
environment (dict[str, str]): Environment variables to be set for
use during training job (default: ``None``)
"""
self.image_uri = image_uri
self.hyperparam_dict = hyperparameters.copy() if hyperparameters else {}
Expand Down Expand Up @@ -1840,6 +1849,7 @@ def __init__(
enable_network_isolation=enable_network_isolation,
profiler_config=profiler_config,
disable_profiler=disable_profiler,
environment=environment,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably also add a doc string for environment in this Estimator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added doc for environment

**kwargs,
)

Expand Down
13 changes: 12 additions & 1 deletion src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@ def train( # noqa: C901
enable_sagemaker_metrics=None,
profiler_rule_configs=None,
profiler_config=None,
environment=None,
):
"""Create an Amazon SageMaker training job.

Expand Down Expand Up @@ -522,9 +523,12 @@ def train( # noqa: C901
Series. For more information see:
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
(default: ``None``).
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
profiler_rule_configs (list[dict]): A list of profiler rule
configurations.
profiler_config (dict): Configuration for how profiling information is emitted
with SageMaker Profiler. (default: ``None``).
environment (dict[str, str]): Environment variables to be set for
use during training job (default: ``None``)

Returns:
str: ARN of the training job, if it is created.
Expand Down Expand Up @@ -556,6 +560,7 @@ def train( # noqa: C901
enable_sagemaker_metrics=enable_sagemaker_metrics,
profiler_rule_configs=profiler_rule_configs,
profiler_config=profiler_config,
environment=environment,
)
LOGGER.info("Creating training-job with name: %s", job_name)
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
Expand Down Expand Up @@ -588,6 +593,7 @@ def _get_train_request( # noqa: C901
enable_sagemaker_metrics=None,
profiler_rule_configs=None,
profiler_config=None,
environment=None,
):
"""Constructs a request compatible for creating an Amazon SageMaker training job.

Expand Down Expand Up @@ -657,6 +663,8 @@ def _get_train_request( # noqa: C901
profiler_rule_configs (list[dict]): A list of profiler rule configurations.
profiler_config(dict): Configuration for how profiling information is emitted with
SageMaker Profiler. (default: ``None``).
environment (dict[str, str]): Environment variables to be set for
use during training job (default: ``None``)

Returns:
Dict: a training request dict
Expand Down Expand Up @@ -699,6 +707,9 @@ def _get_train_request( # noqa: C901
if hyperparameters and len(hyperparameters) > 0:
train_request["HyperParameters"] = hyperparameters

if environment is not None:
train_request["Environment"] = environment

if tags is not None:
train_request["Tags"] = tags

Expand Down
8 changes: 8 additions & 0 deletions tests/integ/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
PARAMETER_SERVER_DISTRIBUTION = {"parameter_server": {"enabled": True}}
MPI_DISTRIBUTION = {"mpi": {"enabled": True}}
TAGS = [{"Key": "some-key", "Value": "some-value"}]
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}


def test_mnist_with_checkpoint_config(
Expand All @@ -59,6 +60,7 @@ def test_mnist_with_checkpoint_config(
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
environment=ENV_INPUT,
)
inputs = estimator.sagemaker_session.upload_data(
path=os.path.join(MNIST_RESOURCE_PATH, "data"), key_prefix="scriptmode/mnist"
Expand All @@ -82,7 +84,13 @@ def test_mnist_with_checkpoint_config(
actual_training_checkpoint_config = sagemaker_session.sagemaker_client.describe_training_job(
TrainingJobName=training_job_name
)["CheckpointConfig"]
actual_training_environment_variable_config = (
sagemaker_session.sagemaker_client.describe_training_job(TrainingJobName=training_job_name)[
"Environment"
]
)
assert actual_training_checkpoint_config == expected_training_checkpoint_config
assert actual_training_environment_variable_config == ENV_INPUT


def test_server_side_encryption(sagemaker_session, tf_full_version, tf_full_py_version):
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/huggingface/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def _create_train_job(version, base_framework_version):
"tags": None,
"vpc_config": None,
"metric_definitions": None,
"environment": None,
"experiment_config": None,
"debugger_hook_config": {
"CollectionConfigurations": [],
Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/tensorflow/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def _create_train_job(tf_version, horovod=False, ps=False, py_version="py2", smd
"tags": None,
"vpc_config": None,
"metric_definitions": None,
"environment": None,
"experiment_config": None,
"profiler_rule_configs": [
{
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/sagemaker/tensorflow/test_estimator_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

REGION = "us-west-2"

ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}


@pytest.fixture()
def sagemaker_session():
Expand Down Expand Up @@ -68,6 +70,26 @@ def test_framework_name(sagemaker_session):
assert tf._framework_name == "tensorflow"


def test_tf_add_environment_variables(sagemaker_session):
    """The ``environment`` dict given at construction is kept on the estimator."""
    build_kwargs = {
        "framework_version": "1.15.2",
        "py_version": "py3",
        "environment": ENV_INPUT,
    }
    estimator = _build_tf(sagemaker_session, **build_kwargs)

    # The estimator should expose exactly the mapping it was given.
    assert estimator.environment == ENV_INPUT


def test_tf_miss_environment_variables(sagemaker_session):
    """Passing ``environment=None`` leaves the estimator without environment variables."""
    build_kwargs = {
        "framework_version": "1.15.2",
        "py_version": "py3",
        "environment": None,
    }
    estimator = _build_tf(sagemaker_session, **build_kwargs)

    # Falsy check, as in the original: no environment variables are set.
    assert not estimator.environment


def test_enable_sm_metrics(sagemaker_session):
tf = _build_tf(
sagemaker_session,
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_chainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def _create_train_job(version, py_version):
"tags": None,
"vpc_config": None,
"metric_definitions": None,
"environment": None,
"experiment_config": None,
"debugger_hook_config": {
"CollectionConfigurations": [],
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/test_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
CODECOMMIT_REPO_SSH = "ssh://git-codecommit.us-west-2.amazonaws.com/v1/repos/test-repo/"
CODECOMMIT_BRANCH = "master"
REPO_DIR = "/tmp/repo_dir"
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}

DESCRIBE_TRAINING_JOB_RESULT = {"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA}}

Expand Down Expand Up @@ -241,6 +242,7 @@ def test_framework_all_init_args(sagemaker_session):
checkpoint_local_path="file://local/checkpoint",
enable_sagemaker_metrics=True,
enable_network_isolation=True,
environment=ENV_INPUT,
)
_TrainingJob.start_new(f, "s3://mydata", None)
sagemaker_session.train.assert_called_once()
Expand Down Expand Up @@ -275,6 +277,7 @@ def test_framework_all_init_args(sagemaker_session):
},
"metric_definitions": [{"Name": "validation-rmse", "Regex": "validation-rmse=(\\d+)"}],
"encrypt_inter_container_traffic": True,
"environment": {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"},
"experiment_config": None,
"checkpoint_s3_uri": "s3://bucket/checkpoint",
"checkpoint_local_path": "file://local/checkpoint",
Expand Down Expand Up @@ -1085,6 +1088,7 @@ def test_framework_with_spot_and_checkpoints(sagemaker_session):
"use_spot_instances": True,
"checkpoint_s3_uri": "s3://mybucket/checkpoints/",
"checkpoint_local_path": "/tmp/checkpoints",
"environment": None,
"experiment_config": None,
}

Expand Down Expand Up @@ -2389,6 +2393,7 @@ def test_unsupported_type_in_dict():
"tags": None,
"vpc_config": None,
"metric_definitions": None,
"environment": None,
"experiment_config": None,
}

Expand Down Expand Up @@ -2678,6 +2683,24 @@ def test_generic_to_fit_with_sagemaker_metrics_missing(sagemaker_session):
assert "enable_sagemaker_metrics" not in args


def test_add_environment_variables_to_train_args(sagemaker_session):
    """``fit`` passes the estimator's ``environment`` to ``sagemaker_session.train``."""
    optional_kwargs = {
        "output_path": OUTPUT_PATH,
        "sagemaker_session": sagemaker_session,
        "environment": ENV_INPUT,
    }
    estimator = Estimator(IMAGE_URI, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, **optional_kwargs)

    estimator.fit()

    train_mock = sagemaker_session.train
    train_mock.assert_called_once()
    # Index 1 of ``call_args`` holds the keyword arguments of the recorded call.
    train_kwargs = train_mock.call_args[1]
    assert train_kwargs["environment"] == ENV_INPUT


def test_generic_to_fit_with_sagemaker_metrics_enabled(sagemaker_session):
e = Estimator(
IMAGE_URI,
Expand Down
35 changes: 35 additions & 0 deletions tests/unit/test_mxnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@

MODEL_PKG_RESPONSE = {"ModelPackageArn": "arn:model-pkg-arn"}

ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}


@pytest.fixture()
def sagemaker_session():
Expand Down Expand Up @@ -144,6 +146,7 @@ def _get_train_args(job_name):
"tags": None,
"vpc_config": None,
"metric_definitions": None,
"environment": None,
"experiment_config": None,
"debugger_hook_config": {
"CollectionConfigurations": [],
Expand Down Expand Up @@ -959,6 +962,38 @@ def test_create_model_with_custom_hosting_image(sagemaker_session):
assert model.image_uri == custom_hosting_image


def test_mx_add_environment_variables(
    sagemaker_session, mxnet_training_version, mxnet_training_py_version
):
    """The ``environment`` dict given to ``MXNet`` is kept on the estimator."""
    estimator_kwargs = {
        "entry_point": SCRIPT_PATH,
        "framework_version": mxnet_training_version,
        "py_version": mxnet_training_py_version,
        "role": ROLE,
        "sagemaker_session": sagemaker_session,
        "instance_count": INSTANCE_COUNT,
        "instance_type": INSTANCE_TYPE,
        "environment": ENV_INPUT,
    }
    estimator = MXNet(**estimator_kwargs)

    assert estimator.environment == ENV_INPUT


def test_mx_missing_environment_variables(
    sagemaker_session, mxnet_training_version, mxnet_training_py_version
):
    """Passing ``environment=None`` to ``MXNet`` leaves the estimator without variables."""
    estimator_kwargs = {
        "entry_point": SCRIPT_PATH,
        "framework_version": mxnet_training_version,
        "py_version": mxnet_training_py_version,
        "role": ROLE,
        "sagemaker_session": sagemaker_session,
        "instance_count": INSTANCE_COUNT,
        "instance_type": INSTANCE_TYPE,
        "environment": None,
    }
    estimator = MXNet(**estimator_kwargs)

    # Falsy check, as in the original: no environment variables are set.
    assert not estimator.environment


def test_mx_enable_sm_metrics(sagemaker_session, mxnet_training_version, mxnet_training_py_version):
mx = MXNet(
entry_point=SCRIPT_PATH,
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/test_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@

ENDPOINT_CONFIG_DESC = {"ProductionVariants": [{"ModelName": "model-1"}, {"ModelName": "model-2"}]}

ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}

LIST_TAGS_RESULT = {"Tags": [{"Key": "TagtestKey", "Value": "TagtestValue"}]}

EXPERIMENT_CONFIG = {
Expand Down Expand Up @@ -146,6 +148,7 @@ def _create_train_job(version, py_version):
"tags": None,
"vpc_config": None,
"metric_definitions": None,
"environment": None,
"experiment_config": None,
"debugger_hook_config": {
"CollectionConfigurations": [],
Expand Down Expand Up @@ -637,6 +640,30 @@ def test_pt_disable_sm_metrics(
assert not pytorch.enable_sagemaker_metrics


def test_pt_add_environment_variables(
    sagemaker_session, pytorch_training_version, pytorch_training_py_version
):
    """The ``environment`` dict given to the PyTorch estimator is stored unchanged.

    Asserts equality with ``ENV_INPUT`` rather than mere truthiness, so the test
    fails if the estimator stores a different or partial mapping. This matches
    the TensorFlow and MXNet counterparts of this test.
    """
    pytorch = _pytorch_estimator(
        sagemaker_session,
        framework_version=pytorch_training_version,
        py_version=pytorch_training_py_version,
        environment=ENV_INPUT,
    )
    # A bare truthiness check would accept any non-empty value; pin the exact dict.
    assert pytorch.environment == ENV_INPUT


def test_pt_miss_environment_variables(
    sagemaker_session, pytorch_training_version, pytorch_training_py_version
):
    """Passing ``environment=None`` leaves the PyTorch estimator without variables."""
    estimator_kwargs = {
        "framework_version": pytorch_training_version,
        "py_version": pytorch_training_py_version,
        "environment": None,
    }
    estimator = _pytorch_estimator(sagemaker_session, **estimator_kwargs)

    # Falsy check, as in the original: no environment variables are set.
    assert not estimator.environment


def test_pt_default_sm_metrics(
sagemaker_session, pytorch_training_version, pytorch_training_py_version
):
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_rl.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def _create_train_job(toolkit, toolkit_version, framework):
{"Name": "reward-training", "Regex": "^Training>.*Total reward=(.*?),"},
{"Name": "reward-testing", "Regex": "^Testing>.*Total reward=(.*?),"},
],
"environment": None,
"experiment_config": None,
"debugger_hook_config": {
"CollectionConfigurations": [],
Expand Down
Loading