Skip to content

Commit cee70dc

Browse files
repushkoAnton Repushko
andauthored
feature: Support for environment variables in the HPO (#3614)
Co-authored-by: Anton Repushko <[email protected]>
1 parent 75d1f2c commit cee70dc

File tree

4 files changed

+30
-0
lines changed

4 files changed

+30
-0
lines changed

src/sagemaker/session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2200,6 +2200,7 @@ def tune( # noqa: C901
22002200
checkpoint_s3_uri=None,
22012201
checkpoint_local_path=None,
22022202
random_seed=None,
2203+
environment=None,
22032204
):
22042205
"""Create an Amazon SageMaker hyperparameter tuning job.
22052206
@@ -2283,6 +2284,8 @@ def tune( # noqa: C901
22832284
random_seed (int): An initial value used to initialize a pseudo-random number generator.
22842285
Setting a random seed will make the hyperparameter tuning search strategies to
22852286
produce more consistent configurations for the same tuning job. (default: ``None``).
2287+
environment (dict[str, str]) : Environment variables to be set for
2288+
use during training jobs (default: ``None``)
22862289
"""
22872290

22882291
tune_request = {
@@ -2315,6 +2318,7 @@ def tune( # noqa: C901
23152318
use_spot_instances=use_spot_instances,
23162319
checkpoint_s3_uri=checkpoint_s3_uri,
23172320
checkpoint_local_path=checkpoint_local_path,
2321+
environment=environment,
23182322
),
23192323
}
23202324

@@ -2558,6 +2562,7 @@ def _map_training_config(
25582562
checkpoint_s3_uri=None,
25592563
checkpoint_local_path=None,
25602564
max_retry_attempts=None,
2565+
environment=None,
25612566
):
25622567
"""Construct a dictionary of training job configuration from the arguments.
25632568
@@ -2612,6 +2617,8 @@ def _map_training_config(
26122617
parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can
26132618
be one of three types: Continuous, Integer, or Categorical.
26142619
max_retry_attempts (int): The number of times to retry the job.
2620+
environment (dict[str, str]) : Environment variables to be set for
2621+
use during training jobs (default: ``None``)
26152622
26162623
Returns:
26172624
A dictionary of training job configuration. For format details, please refer to
@@ -2674,6 +2681,9 @@ def _map_training_config(
26742681

26752682
if max_retry_attempts is not None:
26762683
training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts}
2684+
2685+
if environment is not None:
2686+
training_job_definition["Environment"] = environment
26772687
return training_job_definition
26782688

26792689
def stop_tuning_job(self, name):

src/sagemaker/tuner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1892,6 +1892,9 @@ def _prepare_training_config(
18921892
if estimator.max_retry_attempts is not None:
18931893
training_config["max_retry_attempts"] = estimator.max_retry_attempts
18941894

1895+
if estimator.environment is not None:
1896+
training_config["environment"] = estimator.environment
1897+
18951898
return training_config
18961899

18971900
def stop(self):

tests/unit/test_session.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -928,6 +928,7 @@ def test_train_pack_to_request(sagemaker_session):
928928
"OutputDataConfig": SAMPLE_OUTPUT,
929929
"ResourceConfig": RESOURCE_CONFIG,
930930
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
931+
"Environment": ENV_INPUT,
931932
},
932933
}
933934

@@ -957,6 +958,7 @@ def test_train_pack_to_request(sagemaker_session):
957958
"OutputDataConfig": SAMPLE_OUTPUT,
958959
"ResourceConfig": RESOURCE_CONFIG,
959960
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
961+
"Environment": ENV_INPUT,
960962
},
961963
{
962964
"DefinitionName": "estimator_2",
@@ -973,6 +975,7 @@ def test_train_pack_to_request(sagemaker_session):
973975
"OutputDataConfig": SAMPLE_OUTPUT,
974976
"ResourceConfig": RESOURCE_CONFIG,
975977
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
978+
"Environment": ENV_INPUT,
976979
},
977980
],
978981
}
@@ -1032,6 +1035,7 @@ def assert_create_tuning_job_request(**kwrags):
10321035
warm_start_config=WarmStartConfig(
10331036
warm_start_type=WarmStartTypes(warm_start_type), parents=parents
10341037
).to_input_req(),
1038+
environment=ENV_INPUT,
10351039
)
10361040

10371041

@@ -1122,6 +1126,7 @@ def assert_create_tuning_job_request(**kwrags):
11221126
"output_config": SAMPLE_OUTPUT,
11231127
"resource_config": RESOURCE_CONFIG,
11241128
"stop_condition": SAMPLE_STOPPING_CONDITION,
1129+
"environment": ENV_INPUT,
11251130
},
11261131
tags=None,
11271132
warm_start_config=None,
@@ -1163,6 +1168,7 @@ def assert_create_tuning_job_request(**kwrags):
11631168
"objective_type": "Maximize",
11641169
"objective_metric_name": "val-score",
11651170
"parameter_ranges": SAMPLE_PARAM_RANGES,
1171+
"environment": ENV_INPUT,
11661172
},
11671173
{
11681174
"static_hyperparameters": STATIC_HPs_2,
@@ -1178,6 +1184,7 @@ def assert_create_tuning_job_request(**kwrags):
11781184
"objective_type": "Maximize",
11791185
"objective_metric_name": "value-score",
11801186
"parameter_ranges": SAMPLE_PARAM_RANGES_2,
1187+
"environment": ENV_INPUT,
11811188
},
11821189
],
11831190
tags=None,
@@ -1218,6 +1225,7 @@ def assert_create_tuning_job_request(**kwrags):
12181225
stop_condition=SAMPLE_STOPPING_CONDITION,
12191226
tags=None,
12201227
warm_start_config=None,
1228+
environment=ENV_INPUT,
12211229
)
12221230

12231231

@@ -1259,6 +1267,7 @@ def assert_create_tuning_job_request(**kwrags):
12591267
tags=None,
12601268
warm_start_config=None,
12611269
strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG,
1270+
environment=ENV_INPUT,
12621271
)
12631272

12641273

tests/unit/tuner_test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@
6969
ESTIMATOR_NAME = "estimator_name"
7070
ESTIMATOR_NAME_TWO = "estimator_name_two"
7171

72+
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
73+
7274
SAGEMAKER_SESSION = Mock()
7375

7476
ESTIMATOR = Estimator(
@@ -78,13 +80,15 @@
7880
INSTANCE_TYPE,
7981
output_path="s3://bucket/prefix",
8082
sagemaker_session=SAGEMAKER_SESSION,
83+
environment=ENV_INPUT,
8184
)
8285
ESTIMATOR_TWO = PCA(
8386
ROLE,
8487
INSTANCE_COUNT,
8588
INSTANCE_TYPE,
8689
NUM_COMPONENTS,
8790
sagemaker_session=SAGEMAKER_SESSION,
91+
environment=ENV_INPUT,
8892
)
8993

9094
WARM_START_CONFIG = WarmStartConfig(
@@ -148,6 +152,7 @@
148152
],
149153
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
150154
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
155+
"Environment": ENV_INPUT,
151156
},
152157
"TrainingJobCounters": {
153158
"ClientError": 0,
@@ -212,6 +217,7 @@
212217
],
213218
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
214219
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
220+
"Environment": ENV_INPUT,
215221
},
216222
{
217223
"DefinitionName": ESTIMATOR_NAME_TWO,
@@ -252,6 +258,7 @@
252258
],
253259
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
254260
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
261+
"Environment": ENV_INPUT,
255262
},
256263
],
257264
"TrainingJobCounters": {
@@ -291,6 +298,7 @@
291298
"OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
292299
"TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
293300
"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA},
301+
"Environment": ENV_INPUT,
294302
}
295303

296304
ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}

0 commit comments

Comments
 (0)