feature: Support for environment variables in the HPO #3614

Merged: 1 commit, Jan 27, 2023
10 changes: 10 additions & 0 deletions src/sagemaker/session.py
@@ -2150,6 +2150,7 @@ def tune( # noqa: C901
checkpoint_s3_uri=None,
checkpoint_local_path=None,
random_seed=None,
environment=None,
):
"""Create an Amazon SageMaker hyperparameter tuning job.

@@ -2233,6 +2234,8 @@ def tune( # noqa: C901
random_seed (int): An initial value used to initialize a pseudo-random number generator.
Setting a random seed will make the hyperparameter tuning search strategies to
produce more consistent configurations for the same tuning job. (default: ``None``).
environment (dict[str, str]): Environment variables to be set for
use during training jobs (default: ``None``).
"""

tune_request = {
@@ -2265,6 +2268,7 @@ def tune( # noqa: C901
use_spot_instances=use_spot_instances,
checkpoint_s3_uri=checkpoint_s3_uri,
checkpoint_local_path=checkpoint_local_path,
environment=environment,
),
}

@@ -2508,6 +2512,7 @@ def _map_training_config(
checkpoint_s3_uri=None,
checkpoint_local_path=None,
max_retry_attempts=None,
environment=None,
):
"""Construct a dictionary of training job configuration from the arguments.

@@ -2562,6 +2567,8 @@
parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can
be one of three types: Continuous, Integer, or Categorical.
max_retry_attempts (int): The number of times to retry the job.
environment (dict[str, str]): Environment variables to be set for
use during training jobs (default: ``None``).

Returns:
A dictionary of training job configuration. For format details, please refer to
@@ -2624,6 +2631,9 @@

if max_retry_attempts is not None:
training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts}

if environment is not None:
training_job_definition["Environment"] = environment
return training_job_definition

def stop_tuning_job(self, name):
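For reference, a minimal sketch of the training job definition that _map_training_config emits once environment is supplied. This is not part of the diff; all values are placeholders and required fields such as AlgorithmSpecification are omitted. The only new key is the trailing "Environment", which the Session passes straight through to the CreateHyperParameterTuningJob API:

# Illustrative shape only: placeholder values, other required fields omitted.
training_job_definition = {
    "StaticHyperParameters": {"num_round": "20"},
    "RoleArn": "<execution-role-arn>",
    "OutputDataConfig": {"S3OutputPath": "s3://<bucket>/<prefix>"},
    "ResourceConfig": {
        "InstanceCount": 1,
        "InstanceType": "ml.m5.xlarge",
        "VolumeSizeInGB": 30,
    },
    "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
    "RetryStrategy": {"MaximumRetryAttempts": 3},
    # New with this change: added only when `environment` is not None.
    "Environment": {"env_key1": "env_val1"},
}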
3 changes: 3 additions & 0 deletions src/sagemaker/tuner.py
@@ -1892,6 +1892,9 @@ def _prepare_training_config(
if estimator.max_retry_attempts is not None:
training_config["max_retry_attempts"] = estimator.max_retry_attempts

if estimator.environment is not None:
training_config["environment"] = estimator.environment

return training_config

def stop(self):
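End to end, this means an environment map set on an estimator is now propagated to every training job launched by a hyperparameter tuning job. A minimal usage sketch, not taken from this PR; the image URI, role, S3 paths, metric regex, and environment values are placeholders:

from sagemaker.estimator import Estimator
from sagemaker.tuner import ContinuousParameter, HyperparameterTuner

estimator = Estimator(
    image_uri="<training-image-uri>",          # placeholder
    role="<execution-role-arn>",               # placeholder
    instance_count=1,
    instance_type="ml.m5.xlarge",
    output_path="s3://<bucket>/<prefix>",      # placeholder
    environment={"MY_ENV_VAR": "some-value"},  # picked up by _prepare_training_config
)

tuner = HyperparameterTuner(
    estimator,
    objective_metric_name="validation:accuracy",
    hyperparameter_ranges={"learning_rate": ContinuousParameter(0.01, 0.2)},
    metric_definitions=[
        {"Name": "validation:accuracy", "Regex": "accuracy=([0-9\\.]+)"}
    ],
    max_jobs=4,
    max_parallel_jobs=2,
)

# Each training job started by the tuning job now runs with MY_ENV_VAR set.
tuner.fit({"training": "s3://<bucket>/<prefix>/train"})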
9 changes: 9 additions & 0 deletions tests/unit/test_session.py
@@ -911,6 +911,7 @@ def test_train_pack_to_request(sagemaker_session):
"OutputDataConfig": SAMPLE_OUTPUT,
"ResourceConfig": RESOURCE_CONFIG,
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
"Environment": ENV_INPUT,
},
}

@@ -937,6 +938,7 @@ def test_train_pack_to_request(sagemaker_session):
"OutputDataConfig": SAMPLE_OUTPUT,
"ResourceConfig": RESOURCE_CONFIG,
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
"Environment": ENV_INPUT,
},
{
"DefinitionName": "estimator_2",
Expand All @@ -953,6 +955,7 @@ def test_train_pack_to_request(sagemaker_session):
"OutputDataConfig": SAMPLE_OUTPUT,
"ResourceConfig": RESOURCE_CONFIG,
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
"Environment": ENV_INPUT,
},
],
}
@@ -1009,6 +1012,7 @@ def assert_create_tuning_job_request(**kwrags):
warm_start_config=WarmStartConfig(
warm_start_type=WarmStartTypes(warm_start_type), parents=parents
).to_input_req(),
environment=ENV_INPUT,
)


@@ -1094,6 +1098,7 @@ def assert_create_tuning_job_request(**kwrags):
"output_config": SAMPLE_OUTPUT,
"resource_config": RESOURCE_CONFIG,
"stop_condition": SAMPLE_STOPPING_CONDITION,
"environment": ENV_INPUT,
},
tags=None,
warm_start_config=None,
@@ -1135,6 +1140,7 @@ def assert_create_tuning_job_request(**kwrags):
"objective_type": "Maximize",
"objective_metric_name": "val-score",
"parameter_ranges": SAMPLE_PARAM_RANGES,
"environment": ENV_INPUT,
},
{
"static_hyperparameters": STATIC_HPs_2,
@@ -1150,6 +1156,7 @@ def assert_create_tuning_job_request(**kwrags):
"objective_type": "Maximize",
"objective_metric_name": "value-score",
"parameter_ranges": SAMPLE_PARAM_RANGES_2,
"environment": ENV_INPUT,
},
],
tags=None,
@@ -1190,6 +1197,7 @@ def assert_create_tuning_job_request(**kwrags):
stop_condition=SAMPLE_STOPPING_CONDITION,
tags=None,
warm_start_config=None,
environment=ENV_INPUT,
)


@@ -1231,6 +1239,7 @@ def assert_create_tuning_job_request(**kwrags):
tags=None,
warm_start_config=None,
strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG,
environment=ENV_INPUT,
)


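The fixtures in the next file supply the ENV_INPUT map used in the assertions above. As a condensed illustration (a hypothetical helper, not code from this PR), the updated tests effectively check that every training job definition in the captured CreateHyperParameterTuningJob request carries that map:

ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}


def assert_environment_propagated(request):
    # The request holds either a single "TrainingJobDefinition" or a list of
    # "TrainingJobDefinitions"; each one should include the environment map.
    definitions = request.get("TrainingJobDefinitions") or [request["TrainingJobDefinition"]]
    for definition in definitions:
        assert definition["Environment"] == ENV_INPUT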
8 changes: 8 additions & 0 deletions tests/unit/tuner_test_utils.py
@@ -69,6 +69,8 @@
ESTIMATOR_NAME = "estimator_name"
ESTIMATOR_NAME_TWO = "estimator_name_two"

ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}

SAGEMAKER_SESSION = Mock()

ESTIMATOR = Estimator(
@@ -78,13 +80,15 @@
INSTANCE_TYPE,
output_path="s3://bucket/prefix",
sagemaker_session=SAGEMAKER_SESSION,
environment=ENV_INPUT,
)
ESTIMATOR_TWO = PCA(
ROLE,
INSTANCE_COUNT,
INSTANCE_TYPE,
NUM_COMPONENTS,
sagemaker_session=SAGEMAKER_SESSION,
environment=ENV_INPUT,
)

WARM_START_CONFIG = WarmStartConfig(
@@ -148,6 +152,7 @@
],
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
"Environment": ENV_INPUT,
},
"TrainingJobCounters": {
"ClientError": 0,
@@ -212,6 +217,7 @@
],
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
"Environment": ENV_INPUT,
},
{
"DefinitionName": ESTIMATOR_NAME_TWO,
@@ -252,6 +258,7 @@
],
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
"Environment": ENV_INPUT,
},
],
"TrainingJobCounters": {
@@ -291,6 +298,7 @@
"OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
"TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA},
"Environment": ENV_INPUT,
}

ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}