Skip to content

Commit 47c86ae

Browse files
author
Anton Repushko
committed
feature: support for environment variables in the HPO
1 parent 25c49d4 commit 47c86ae

File tree

4 files changed

+30
-0
lines changed

4 files changed

+30
-0
lines changed

src/sagemaker/session.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2150,6 +2150,7 @@ def tune( # noqa: C901
21502150
checkpoint_s3_uri=None,
21512151
checkpoint_local_path=None,
21522152
random_seed=None,
2153+
environment=None,
21532154
):
21542155
"""Create an Amazon SageMaker hyperparameter tuning job.
21552156
@@ -2233,6 +2234,8 @@ def tune( # noqa: C901
22332234
random_seed (int): An initial value used to initialize a pseudo-random number generator.
22342235
Setting a random seed will make the hyperparameter tuning search strategies to
22352236
produce more consistent configurations for the same tuning job. (default: ``None``).
2237+
environment (dict[str, str]) : Environment variables to be set for
2238+
use during training jobs (default: ``None``)
22362239
"""
22372240

22382241
tune_request = {
@@ -2265,6 +2268,7 @@ def tune( # noqa: C901
22652268
use_spot_instances=use_spot_instances,
22662269
checkpoint_s3_uri=checkpoint_s3_uri,
22672270
checkpoint_local_path=checkpoint_local_path,
2271+
environment=environment,
22682272
),
22692273
}
22702274

@@ -2508,6 +2512,7 @@ def _map_training_config(
25082512
checkpoint_s3_uri=None,
25092513
checkpoint_local_path=None,
25102514
max_retry_attempts=None,
2515+
environment=None,
25112516
):
25122517
"""Construct a dictionary of training job configuration from the arguments.
25132518
@@ -2562,6 +2567,8 @@ def _map_training_config(
25622567
parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can
25632568
be one of three types: Continuous, Integer, or Categorical.
25642569
max_retry_attempts (int): The number of times to retry the job.
2570+
environment (dict[str, str]) : Environment variables to be set for
2571+
use during training jobs (default: ``None``)
25652572
25662573
Returns:
25672574
A dictionary of training job configuration. For format details, please refer to
@@ -2624,6 +2631,9 @@ def _map_training_config(
26242631

26252632
if max_retry_attempts is not None:
26262633
training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts}
2634+
2635+
if environment is not None:
2636+
training_job_definition["Environment"] = environment
26272637
return training_job_definition
26282638

26292639
def stop_tuning_job(self, name):

src/sagemaker/tuner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1892,6 +1892,9 @@ def _prepare_training_config(
18921892
if estimator.max_retry_attempts is not None:
18931893
training_config["max_retry_attempts"] = estimator.max_retry_attempts
18941894

1895+
if estimator.environment is not None:
1896+
training_config["environment"] = estimator.environment
1897+
18951898
return training_config
18961899

18971900
def stop(self):

tests/unit/test_session.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,7 @@ def test_train_pack_to_request(sagemaker_session):
911911
"OutputDataConfig": SAMPLE_OUTPUT,
912912
"ResourceConfig": RESOURCE_CONFIG,
913913
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
914+
"Environment": ENV_INPUT,
914915
},
915916
}
916917

@@ -937,6 +938,7 @@ def test_train_pack_to_request(sagemaker_session):
937938
"OutputDataConfig": SAMPLE_OUTPUT,
938939
"ResourceConfig": RESOURCE_CONFIG,
939940
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
941+
"Environment": ENV_INPUT,
940942
},
941943
{
942944
"DefinitionName": "estimator_2",
@@ -953,6 +955,7 @@ def test_train_pack_to_request(sagemaker_session):
953955
"OutputDataConfig": SAMPLE_OUTPUT,
954956
"ResourceConfig": RESOURCE_CONFIG,
955957
"StoppingCondition": SAMPLE_STOPPING_CONDITION,
958+
"Environment": ENV_INPUT,
956959
},
957960
],
958961
}
@@ -1009,6 +1012,7 @@ def assert_create_tuning_job_request(**kwrags):
10091012
warm_start_config=WarmStartConfig(
10101013
warm_start_type=WarmStartTypes(warm_start_type), parents=parents
10111014
).to_input_req(),
1015+
environment=ENV_INPUT,
10121016
)
10131017

10141018

@@ -1094,6 +1098,7 @@ def assert_create_tuning_job_request(**kwrags):
10941098
"output_config": SAMPLE_OUTPUT,
10951099
"resource_config": RESOURCE_CONFIG,
10961100
"stop_condition": SAMPLE_STOPPING_CONDITION,
1101+
"environment": ENV_INPUT,
10971102
},
10981103
tags=None,
10991104
warm_start_config=None,
@@ -1135,6 +1140,7 @@ def assert_create_tuning_job_request(**kwrags):
11351140
"objective_type": "Maximize",
11361141
"objective_metric_name": "val-score",
11371142
"parameter_ranges": SAMPLE_PARAM_RANGES,
1143+
"environment": ENV_INPUT,
11381144
},
11391145
{
11401146
"static_hyperparameters": STATIC_HPs_2,
@@ -1150,6 +1156,7 @@ def assert_create_tuning_job_request(**kwrags):
11501156
"objective_type": "Maximize",
11511157
"objective_metric_name": "value-score",
11521158
"parameter_ranges": SAMPLE_PARAM_RANGES_2,
1159+
"environment": ENV_INPUT,
11531160
},
11541161
],
11551162
tags=None,
@@ -1190,6 +1197,7 @@ def assert_create_tuning_job_request(**kwrags):
11901197
stop_condition=SAMPLE_STOPPING_CONDITION,
11911198
tags=None,
11921199
warm_start_config=None,
1200+
environment=ENV_INPUT,
11931201
)
11941202

11951203

@@ -1231,6 +1239,7 @@ def assert_create_tuning_job_request(**kwrags):
12311239
tags=None,
12321240
warm_start_config=None,
12331241
strategy_config=SAMPLE_HYPERBAND_STRATEGY_CONFIG,
1242+
environment=ENV_INPUT,
12341243
)
12351244

12361245

tests/unit/tuner_test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@
6969
ESTIMATOR_NAME = "estimator_name"
7070
ESTIMATOR_NAME_TWO = "estimator_name_two"
7171

72+
ENV_INPUT = {"env_key1": "env_val1", "env_key2": "env_val2", "env_key3": "env_val3"}
73+
7274
SAGEMAKER_SESSION = Mock()
7375

7476
ESTIMATOR = Estimator(
@@ -78,13 +80,15 @@
7880
INSTANCE_TYPE,
7981
output_path="s3://bucket/prefix",
8082
sagemaker_session=SAGEMAKER_SESSION,
83+
environment=ENV_INPUT,
8184
)
8285
ESTIMATOR_TWO = PCA(
8386
ROLE,
8487
INSTANCE_COUNT,
8588
INSTANCE_TYPE,
8689
NUM_COMPONENTS,
8790
sagemaker_session=SAGEMAKER_SESSION,
91+
environment=ENV_INPUT,
8892
)
8993

9094
WARM_START_CONFIG = WarmStartConfig(
@@ -148,6 +152,7 @@
148152
],
149153
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
150154
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
155+
"Environment": ENV_INPUT,
151156
},
152157
"TrainingJobCounters": {
153158
"ClientError": 0,
@@ -212,6 +217,7 @@
212217
],
213218
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
214219
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
220+
"Environment": ENV_INPUT,
215221
},
216222
{
217223
"DefinitionName": ESTIMATOR_NAME_TWO,
@@ -252,6 +258,7 @@
252258
],
253259
"StoppingCondition": {"MaxRuntimeInSeconds": 86400},
254260
"OutputDataConfig": {"S3OutputPath": BUCKET_NAME},
261+
"Environment": ENV_INPUT,
255262
},
256263
],
257264
"TrainingJobCounters": {
@@ -291,6 +298,7 @@
291298
"OutputDataConfig": {"KmsKeyId": "", "S3OutputPath": "s3://place/output/neo"},
292299
"TrainingJobOutput": {"S3TrainingJobOutput": "s3://here/output.tar.gz"},
293300
"ModelArtifacts": {"S3ModelArtifacts": MODEL_DATA},
301+
"Environment": ENV_INPUT,
294302
}
295303

296304
ENDPOINT_DESC = {"EndpointConfigName": "test-endpoint"}

0 commit comments

Comments
 (0)