
Commit edf9cba

Support parameterizing tuning job parameter ranges

1 parent: dad08c4

6 files changed (+130, −14 lines)

src/sagemaker/parameter.py
Lines changed: 10 additions & 4 deletions

@@ -12,7 +12,9 @@
 # language governing permissions and limitations under the License.
 """Placeholder docstring"""
 from __future__ import absolute_import
+
 import json
+from sagemaker.workflow.parameters import Parameter as PipelineParameter


 class ParameterRange(object):
@@ -68,8 +70,12 @@ def as_tuning_range(self, name):
         """
         return {
             "Name": name,
-            "MinValue": str(self.min_value),
-            "MaxValue": str(self.max_value),
+            "MinValue": str(self.min_value)
+            if not isinstance(self.min_value, PipelineParameter)
+            else self.min_value,
+            "MaxValue": str(self.max_value)
+            if not isinstance(self.max_value, PipelineParameter)
+            else self.max_value,
             "ScalingType": self.scaling_type,
         }

@@ -103,9 +109,9 @@ def __init__(self, values):  # pylint: disable=super-init-not-called
             This input will be converted into a list of strings.
         """
         if isinstance(values, list):
-            self.values = [str(v) for v in values]
+            self.values = [str(v) if not isinstance(v, PipelineParameter) else v for v in values]
         else:
-            self.values = [str(values)]
+            self.values = [str(values) if not isinstance(values, PipelineParameter) else values]

     def as_tuning_range(self, name):
         """Represent the parameter range as a dictionary.

src/sagemaker/session.py
Lines changed: 4 additions & 0 deletions

@@ -2210,6 +2210,7 @@ def _map_training_config(
         use_spot_instances=False,
         checkpoint_s3_uri=None,
         checkpoint_local_path=None,
+        max_retry_attempts=None,
     ):
         """Construct a dictionary of training job configuration from the arguments.

@@ -2263,6 +2264,7 @@ def _map_training_config(
             objective_metric_name (str): Name of the metric for evaluating training jobs.
             parameter_ranges (dict): Dictionary of parameter ranges. These parameter ranges can
                 be one of three types: Continuous, Integer, or Categorical.
+            max_retry_attempts (int): The number of times to retry the job.

         Returns:
             A dictionary of training job configuration. For format details, please refer to
@@ -2319,6 +2321,8 @@ def _map_training_config(
         if parameter_ranges is not None:
             training_job_definition["HyperParameterRanges"] = parameter_ranges

+        if max_retry_attempts is not None:
+            training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts}
         return training_job_definition

     def stop_tuning_job(self, name):
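
The RetryStrategy field added here follows the shape of the service's training-job request. A small sketch of the conditional mapping in isolation (apply_retry_strategy is an illustrative helper, not an SDK function; training_job_definition is a plain request dict like the one built above):

def apply_retry_strategy(training_job_definition, max_retry_attempts=None):
    # Attach RetryStrategy only when the caller opted in, mirroring the hunk above.
    if max_retry_attempts is not None:
        training_job_definition["RetryStrategy"] = {"MaximumRetryAttempts": max_retry_attempts}
    return training_job_definition

config = apply_retry_strategy({"StaticHyperParameters": {}}, max_retry_attempts=3)
assert config["RetryStrategy"] == {"MaximumRetryAttempts": 3}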

src/sagemaker/tuner.py
Lines changed: 7 additions & 1 deletion

@@ -1507,7 +1507,10 @@ def _get_tuner_args(cls, tuner, inputs):

         if tuner.estimator is not None:
             tuner_args["training_config"] = cls._prepare_training_config(
-                inputs, tuner.estimator, tuner.static_hyperparameters, tuner.metric_definitions
+                inputs=inputs,
+                estimator=tuner.estimator,
+                static_hyperparameters=tuner.static_hyperparameters,
+                metric_definitions=tuner.metric_definitions,
             )

         if tuner.estimator_dict is not None:
@@ -1580,6 +1583,9 @@ def _prepare_training_config(
         if parameter_ranges is not None:
             training_config["parameter_ranges"] = parameter_ranges

+        if estimator.max_retry_attempts is not None:
+            training_config["max_retry_attempts"] = estimator.max_retry_attempts
+
         return training_config

     def stop(self):
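
Together with the session.py change, this gives the retry setting a two-hop path: estimator attribute → training_config key → RetryStrategy request field. A standalone sketch of the first hop (_StubEstimator is a stand-in carrying only the attribute the hunk above reads):

class _StubEstimator:
    # Only the attribute that _prepare_training_config reads.
    max_retry_attempts = 3

def prepare_training_config(estimator):
    training_config = {}
    if estimator.max_retry_attempts is not None:
        training_config["max_retry_attempts"] = estimator.max_retry_attempts
    return training_config

assert prepare_training_config(_StubEstimator()) == {"max_retry_attempts": 3}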

src/sagemaker/workflow/pipeline.py
Lines changed: 1 addition & 0 deletions

@@ -304,6 +304,7 @@ def _interpolate(obj: Union[RequestType, Any], callback_output_to_step_map: Dict
     """
     if isinstance(obj, (Expression, Parameter, Properties)):
         return obj.expr
+
     if isinstance(obj, CallbackOutput):
         step_name = callback_output_to_step_map[obj.output_name]
         return obj.expr(step_name)
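
_interpolate is the piece that eventually resolves the pipeline parameters which parameter.py now lets through: any Parameter in the request tree is replaced by its expression form. A quick illustration, assuming the ParameterString behavior in sagemaker.workflow.parameters:

from sagemaker.workflow.parameters import ParameterString

max_batch_size = ParameterString(name="MaxBatchSize", default_value="128")
# In the serialized pipeline definition the object becomes a Get expression:
assert max_batch_size.expr == {"Get": "Parameters.MaxBatchSize"}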

tests/integ/test_workflow.py
Lines changed: 95 additions & 5 deletions

@@ -885,7 +885,7 @@ def test_conditional_pytorch_training_model_registration(
         pass


-def test_tuning(
+def test_tuning_single_algo(
     sagemaker_session,
     role,
     cpu_instance_type,
@@ -908,14 +908,17 @@ def test_tuning(
         role=role,
         framework_version="1.5.0",
         py_version="py3",
-        instance_count=1,
-        instance_type="ml.m5.xlarge",
+        instance_count=instance_count,
+        instance_type=instance_type,
         sagemaker_session=sagemaker_session,
         enable_sagemaker_metrics=True,
+        max_retry_attempts=3,
     )

+    min_batch_size = ParameterString(name="MinBatchSize", default_value="64")
+    max_batch_size = ParameterString(name="MaxBatchSize", default_value="128")
     hyperparameter_ranges = {
-        "batch-size": IntegerParameter(64, 128),
+        "batch-size": IntegerParameter(min_batch_size, max_batch_size),
     }

     tuner = HyperparameterTuner(
@@ -971,7 +974,7 @@ def test_tuning(

     pipeline = Pipeline(
         name=pipeline_name,
-        parameters=[instance_count, instance_type],
+        parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
         steps=[step_tune, step_best_model, step_second_best_model],
         sagemaker_session=sagemaker_session,
     )
@@ -995,6 +998,93 @@ def test_tuning(
         pass


+def test_tuning_multi_algos(
+    sagemaker_session,
+    role,
+    cpu_instance_type,
+    pipeline_name,
+    region_name,
+):
+    base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
+    entry_point = os.path.join(base_dir, "mnist.py")
+    input_path = sagemaker_session.upload_data(
+        path=os.path.join(base_dir, "training"),
+        key_prefix="integ-test-data/pytorch_mnist/training",
+    )
+
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
+    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")
+
+    pytorch_estimator = PyTorch(
+        entry_point=entry_point,
+        role=role,
+        framework_version="1.5.0",
+        py_version="py3",
+        instance_count=instance_count,
+        instance_type=instance_type,
+        sagemaker_session=sagemaker_session,
+        enable_sagemaker_metrics=True,
+        max_retry_attempts=3,
+    )
+
+    min_batch_size = ParameterString(name="MinBatchSize", default_value="64")
+    max_batch_size = ParameterString(name="MaxBatchSize", default_value="128")
+
+    tuner = HyperparameterTuner.create(
+        estimator_dict={
+            "estimator-1": pytorch_estimator,
+            "estimator-2": pytorch_estimator,
+        },
+        objective_metric_name_dict={
+            "estimator-1": "test:acc",
+            "estimator-2": "test:acc",
+        },
+        hyperparameter_ranges_dict={
+            "estimator-1": {"batch-size": IntegerParameter(min_batch_size, max_batch_size)},
+            "estimator-2": {"batch-size": IntegerParameter(min_batch_size, max_batch_size)},
+        },
+        metric_definitions_dict={
+            "estimator-1": [{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
+            "estimator-2": [{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
+        },
+    )
+    inputs = {
+        "estimator-1": TrainingInput(s3_data=input_path),
+        "estimator-2": TrainingInput(s3_data=input_path),
+    }
+
+    step_tune = TuningStep(
+        name="my-tuning-step",
+        tuner=tuner,
+        inputs=inputs,
+    )
+
+    pipeline = Pipeline(
+        name=pipeline_name,
+        parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
+        steps=[step_tune],
+        sagemaker_session=sagemaker_session,
+    )
+
+    try:
+        response = pipeline.create(role)
+        create_arn = response["PipelineArn"]
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}", create_arn
+        )
+
+        execution = pipeline.start(parameters={})
+        assert re.match(
+            fr"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
+            execution.arn,
+        )
+    finally:
+        try:
+            pipeline.delete()
+        except Exception:
+            pass
+
+
 def test_mxnet_model_registration(
     sagemaker_session,
     role,
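
Because the batch-size bounds are now pipeline parameters, one pipeline definition can be tuned with different ranges per execution. A hedged usage sketch, continuing from the pipeline object built in test_tuning_single_algo above (override values illustrative):

# Override the defaults ("64"/"128") for a single execution.
execution = pipeline.start(
    parameters={
        "MinBatchSize": "32",
        "MaxBatchSize": "256",
    }
)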

tests/unit/sagemaker/workflow/test_steps.py
Lines changed: 13 additions & 4 deletions

@@ -715,14 +715,16 @@ def test_multi_algo_tuning_step(sagemaker_session):
     data_source_uri_parameter = ParameterString(
         name="DataSourceS3Uri", default_value=f"s3://{BUCKET}/train_manifest"
     )
+    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
     estimator = Estimator(
         image_uri=IMAGE_URI,
         role=ROLE,
-        instance_count=1,
+        instance_count=instance_count,
         instance_type="ml.c5.4xlarge",
         profiler_config=ProfilerConfig(system_monitor_interval_millis=500),
         rules=[],
         sagemaker_session=sagemaker_session,
+        max_retry_attempts=10,
     )

     estimator.set_hyperparameters(
@@ -738,8 +740,9 @@ def test_multi_algo_tuning_step(sagemaker_session):
         augmentation_type="crop",
     )

+    initial_lr_param = ParameterString(name="InitialLR", default_value="0.0001")
     hyperparameter_ranges = {
-        "learning_rate": ContinuousParameter(0.0001, 0.05),
+        "learning_rate": ContinuousParameter(initial_lr_param, 0.05),
         "momentum": ContinuousParameter(0.0, 0.99),
         "weight_decay": ContinuousParameter(0.0, 0.99),
     }
@@ -824,7 +827,7 @@ def test_multi_algo_tuning_step(sagemaker_session):
                 "ContinuousParameterRanges": [
                     {
                         "Name": "learning_rate",
-                        "MinValue": "0.0001",
+                        "MinValue": initial_lr_param,
                         "MaxValue": "0.05",
                         "ScalingType": "Auto",
                     },
@@ -844,6 +847,9 @@ def test_multi_algo_tuning_step(sagemaker_session):
                 "CategoricalParameterRanges": [],
                 "IntegerParameterRanges": [],
             },
+            "RetryStrategy": {
+                "MaximumRetryAttempts": 10,
+            },
         },
         {
             "StaticHyperParameters": {
@@ -888,7 +894,7 @@ def test_multi_algo_tuning_step(sagemaker_session):
                 "ContinuousParameterRanges": [
                     {
                         "Name": "learning_rate",
-                        "MinValue": "0.0001",
+                        "MinValue": initial_lr_param,
                         "MaxValue": "0.05",
                         "ScalingType": "Auto",
                     },
@@ -908,6 +914,9 @@ def test_multi_algo_tuning_step(sagemaker_session):
                 "CategoricalParameterRanges": [],
                 "IntegerParameterRanges": [],
             },
+            "RetryStrategy": {
+                "MaximumRetryAttempts": 10,
+            },
         },
     ],
 },
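
The updated assertions encode the key invariant: before pipeline interpolation, the expected request holds the Parameter object itself under "MinValue", not its string default. A compact restatement (the expected-range shape is taken from the test above):

from sagemaker.workflow.parameters import ParameterString

initial_lr_param = ParameterString(name="InitialLR", default_value="0.0001")
expected_range = {
    "Name": "learning_rate",
    "MinValue": initial_lr_param,  # the Parameter object, not "0.0001"
    "MaxValue": "0.05",
    "ScalingType": "Auto",
}
assert expected_range["MinValue"] is initial_lr_param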
