Skip to content

Commit 3f29553

Browse files
committed
feat: jumpstart instance specific hyperparameters
1 parent 7f6f3f9 commit 3f29553

File tree

8 files changed

+610
-11
lines changed

8 files changed

+610
-11
lines changed

src/sagemaker/hyperparameters.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def retrieve_default(
3131
region: Optional[str] = None,
3232
model_id: Optional[str] = None,
3333
model_version: Optional[str] = None,
34+
instance_type: Optional[str] = None,
3435
include_container_hyperparameters: bool = False,
3536
tolerate_vulnerable_model: bool = False,
3637
tolerate_deprecated_model: bool = False,
@@ -75,12 +76,13 @@ def retrieve_default(
7576
)
7677

7778
return artifacts._retrieve_default_hyperparameters(
78-
model_id,
79-
model_version,
80-
region,
81-
include_container_hyperparameters,
82-
tolerate_vulnerable_model,
83-
tolerate_deprecated_model,
79+
model_id=model_id,
80+
model_version=model_version,
81+
instance_type=instance_type,
82+
region=region,
83+
include_container_hyperparameters=include_container_hyperparameters,
84+
tolerate_vulnerable_model=tolerate_vulnerable_model,
85+
tolerate_deprecated_model=tolerate_deprecated_model,
8486
sagemaker_session=sagemaker_session,
8587
)
8688

src/sagemaker/jumpstart/artifacts/hyperparameters.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def _retrieve_default_hyperparameters(
3535
tolerate_vulnerable_model: bool = False,
3636
tolerate_deprecated_model: bool = False,
3737
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
38+
instance_type: Optional[str] = None,
3839
):
3940
"""Retrieves the training hyperparameters for the model matching the given arguments.
4041
@@ -86,4 +87,19 @@ def _retrieve_default_hyperparameters(
8687
include_container_hyperparameters and hyperparameter.scope == VariableScope.CONTAINER
8788
) or hyperparameter.scope == VariableScope.ALGORITHM:
8889
default_hyperparameters[hyperparameter.name] = str(hyperparameter.default)
90+
91+
instance_specific_hyperparameters = (
92+
model_specs.training_instance_type_variants.get_instance_specific_hyperparameters(
93+
instance_type
94+
)
95+
if instance_type
96+
and getattr(model_specs, "training_instance_type_variants", None) is not None
97+
else []
98+
)
99+
100+
for instance_specific_hyperparameter in instance_specific_hyperparameters:
101+
default_hyperparameters[instance_specific_hyperparameter.name] = str(
102+
instance_specific_hyperparameter.default
103+
)
104+
89105
return default_hyperparameters

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ def _add_hyperparameters_to_kwargs(
597597
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
598598
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
599599
sagemaker_session=kwargs.sagemaker_session,
600+
instance_type=kwargs.instance_type,
600601
)
601602

602603
for key, value in default_hyperparameters.items():

src/sagemaker/jumpstart/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,7 +314,7 @@ def _is_valid_model_id_hook():
314314

315315
super(JumpStartModel, self).__init__(**model_init_kwargs.to_kwargs_dict())
316316

317-
def retrieve_all_examples(self) -> Optional[List[JumpStartSerializablePayload]]:
317+
def retrieve_all_example_payloads(self) -> Optional[List[JumpStartSerializablePayload]]:
318318
"""Returns all example payloads associated with the model.
319319
320320
Raises:

src/sagemaker/jumpstart/types.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,50 @@ def to_json(self) -> Dict[str, Any]:
402402
json_obj = {att: getattr(self, att) for att in self.__slots__ if hasattr(self, att)}
403403
return json_obj
404404

405+
def get_instance_specific_hyperparameters(
406+
self, instance_type: str
407+
) -> List[JumpStartHyperparameter]:
408+
"""Returns instance specific hyperparameters.
409+
410+
Returns empty list if a model, instance type tuple does not have specific
411+
hyperparameters.
412+
"""
413+
414+
if self.variants is None:
415+
return []
416+
417+
instance_specific_hyperparameters: List[JumpStartHyperparameter] = [
418+
JumpStartHyperparameter(json)
419+
for json in self.variants.get(instance_type, {})
420+
.get("properties", {})
421+
.get("hyperparameters", [])
422+
]
423+
424+
instance_type_family = get_instance_type_family(instance_type)
425+
426+
instance_family_hyperparameters: List[JumpStartHyperparameter] = [
427+
JumpStartHyperparameter(json)
428+
for json in (
429+
self.variants.get(instance_type_family, {})
430+
.get("properties", {})
431+
.get("hyperparameters", [])
432+
if instance_type_family not in {"", None}
433+
else []
434+
)
435+
]
436+
437+
instance_specific_hyperparameter_names = {
438+
hyperparameter.name for hyperparameter in instance_specific_hyperparameters
439+
}
440+
441+
hyperparams_to_return = deepcopy(instance_specific_hyperparameters)
442+
443+
for hyperparameter in instance_family_hyperparameters:
444+
if hyperparameter.name not in instance_specific_hyperparameter_names:
445+
hyperparams_to_return.append(hyperparameter)
446+
447+
return hyperparams_to_return
448+
405449
def get_instance_specific_environment_variables(self, instance_type: str) -> Dict[str, str]:
406450
"""Returns instance specific environment variables.
407451

tests/unit/sagemaker/hyperparameters/jumpstart/test_default.py

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from sagemaker import hyperparameters
2121

22-
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec
22+
from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec
2323

2424

2525
mock_client = boto3.client("s3")
@@ -116,3 +116,74 @@ def test_jumpstart_default_hyperparameters(patched_get_model_specs):
116116
hyperparameters.retrieve_default(
117117
model_id=model_id,
118118
)
119+
120+
121+
@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
122+
def test_jumpstart_sdk_hyperparameters_instance_type_overrides(patched_get_model_specs):
123+
124+
patched_get_model_specs.side_effect = get_special_model_spec
125+
126+
model_id = "variant-model"
127+
region = "us-west-2"
128+
129+
# assert that we can add hyperparameters to default
130+
vars = hyperparameters.retrieve_default(
131+
region=region,
132+
model_id=model_id,
133+
model_version="*",
134+
sagemaker_session=mock_session,
135+
instance_type="ml.p2.48xlarge",
136+
)
137+
assert vars == {
138+
"adam-learning-rate": "0.05",
139+
"batch-size": "4",
140+
"epochs": "3",
141+
"num_bag_sets": "5",
142+
"num_stack_levels": "6",
143+
"refit_full": "False",
144+
"sagemaker_container_log_level": "20",
145+
"sagemaker_program": "transfer_learning.py",
146+
"sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz",
147+
"save_space": "False",
148+
"set_best_to_refit_full": "False",
149+
"verbosity": "2",
150+
}
151+
152+
# assert that we can override default environment variables (instance family + instance type
153+
# specific)
154+
vars = hyperparameters.retrieve_default(
155+
region=region,
156+
model_id=model_id,
157+
model_version="*",
158+
sagemaker_session=mock_session,
159+
instance_type="ml.p2.12xlarge",
160+
)
161+
assert vars == {
162+
"adam-learning-rate": "0.05",
163+
"batch-size": "1",
164+
"epochs": "3",
165+
"num_bag_sets": "1",
166+
"num_stack_levels": "0",
167+
"refit_full": "False",
168+
"eval_metric": "auto",
169+
"num_bag_folds": "0",
170+
"presets": "medium_quality",
171+
"auto_stack": "False",
172+
"sagemaker_container_log_level": "20",
173+
"sagemaker_program": "transfer_learning.py",
174+
"sagemaker_submit_directory": "/opt/ml/input/data/code/sourcedir.tar.gz",
175+
"save_space": "False",
176+
"set_best_to_refit_full": "False",
177+
"verbosity": "2",
178+
}
179+
180+
# assert that we can return default hyperparameters for unrecognized instance
181+
vars = hyperparameters.retrieve_default(
182+
region=region,
183+
model_id=model_id,
184+
model_version="*",
185+
sagemaker_session=mock_session,
186+
instance_type="ml.p9999.48xlarge",
187+
)
188+
189+
assert vars == {"epochs": "3", "adam-learning-rate": "0.05", "batch-size": "4"}

tests/unit/sagemaker/jumpstart/constants.py

Lines changed: 135 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,141 @@
214214
"framework_version": "1.5.0",
215215
"py_version": "py3",
216216
},
217-
"training_instance_type_variants": None,
217+
"training_instance_type_variants": {
218+
"variants": {
219+
"ml.p2.12xlarge": {
220+
"properties": {
221+
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"},
222+
"hyperparameters": [
223+
{
224+
"name": "eval_metric",
225+
"type": "text",
226+
"default": "auto",
227+
"scope": "algorithm",
228+
},
229+
{
230+
"name": "presets",
231+
"type": "text",
232+
"default": "medium_quality",
233+
"options": [
234+
"best_quality",
235+
"high_quality",
236+
"good_quality",
237+
"medium_quality",
238+
"optimize_for_deployment",
239+
"interpretable",
240+
],
241+
"scope": "algorithm",
242+
},
243+
{
244+
"name": "auto_stack",
245+
"type": "text",
246+
"default": "False",
247+
"options": ["True", "False"],
248+
"scope": "algorithm",
249+
},
250+
{
251+
"name": "num_bag_folds",
252+
"type": "text",
253+
"default": "0",
254+
"options": ["0", "2", "3", "4", "5", "6", "7", "8", "9", "10"],
255+
"scope": "algorithm",
256+
},
257+
{
258+
"name": "num_bag_sets",
259+
"type": "int",
260+
"default": 1,
261+
"min": 1,
262+
"scope": "algorithm",
263+
},
264+
{
265+
"name": "batch-size",
266+
"type": "int",
267+
"default": 1,
268+
"min": 1,
269+
"scope": "algorithm",
270+
},
271+
{
272+
"name": "num_stack_levels",
273+
"type": "int",
274+
"default": 0,
275+
"min": 0,
276+
"max": 3,
277+
"scope": "algorithm",
278+
},
279+
],
280+
}
281+
},
282+
"p2": {
283+
"properties": {
284+
"hyperparameters": [
285+
{
286+
"name": "num_bag_sets",
287+
"type": "int",
288+
"default": 5,
289+
"min": 5,
290+
"scope": "algorithm",
291+
},
292+
{
293+
"name": "num_stack_levels",
294+
"type": "int",
295+
"default": 6,
296+
"min": 7,
297+
"max": 3,
298+
"scope": "algorithm",
299+
},
300+
{
301+
"name": "refit_full",
302+
"type": "text",
303+
"default": "False",
304+
"options": ["True", "False"],
305+
"scope": "algorithm",
306+
},
307+
{
308+
"name": "set_best_to_refit_full",
309+
"type": "text",
310+
"default": "False",
311+
"options": ["True", "False"],
312+
"scope": "algorithm",
313+
},
314+
{
315+
"name": "save_space",
316+
"type": "text",
317+
"default": "False",
318+
"options": ["True", "False"],
319+
"scope": "algorithm",
320+
},
321+
{
322+
"name": "verbosity",
323+
"type": "int",
324+
"default": 2,
325+
"min": 0,
326+
"max": 4,
327+
"scope": "algorithm",
328+
},
329+
{
330+
"name": "sagemaker_submit_directory",
331+
"type": "text",
332+
"default": "/opt/ml/input/data/code/sourcedir.tar.gz",
333+
"scope": "container",
334+
},
335+
{
336+
"name": "sagemaker_program",
337+
"type": "text",
338+
"default": "transfer_learning.py",
339+
"scope": "container",
340+
},
341+
{
342+
"name": "sagemaker_container_log_level",
343+
"type": "text",
344+
"default": "20",
345+
"scope": "container",
346+
},
347+
]
348+
}
349+
},
350+
}
351+
},
218352
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
219353
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
220354
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",

0 commit comments

Comments
 (0)