Skip to content

Commit 03f901f

Browse files
authored
feat: inference instance type conditioned on training instance type (#4230)
1 parent 866a2d9 commit 03f901f

File tree

10 files changed

+763
-5
lines changed

10 files changed

+763
-5
lines changed

src/sagemaker/instance_types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def retrieve_default(
3333
tolerate_vulnerable_model: bool = False,
3434
tolerate_deprecated_model: bool = False,
3535
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
36+
training_instance_type: Optional[str] = None,
3637
) -> str:
3738
"""Retrieves the default instance type for the model matching the given arguments.
3839
@@ -56,6 +57,11 @@ def retrieve_default(
5657
object, used for SageMaker interactions. If not
5758
specified, one is created using the default AWS configuration
5859
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
60+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
61+
instance type used for the training job that produced the fine-tuned weights.
62+
Optionally supply this to get a inference instance type conditioned
63+
on the training instance, to ensure compatability of training artifact to inference
64+
instance. (Default: None).
5965
Returns:
6066
str: The default instance type to use for the model.
6167
@@ -78,6 +84,7 @@ def retrieve_default(
7884
tolerate_vulnerable_model,
7985
tolerate_deprecated_model,
8086
sagemaker_session=sagemaker_session,
87+
training_instance_type=training_instance_type,
8188
)
8289

8390

@@ -89,6 +96,7 @@ def retrieve(
8996
tolerate_vulnerable_model: bool = False,
9097
tolerate_deprecated_model: bool = False,
9198
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
99+
training_instance_type: Optional[str] = None,
92100
) -> List[str]:
93101
"""Retrieves the supported training instance types for the model matching the given arguments.
94102
@@ -110,6 +118,12 @@ def retrieve(
110118
object, used for SageMaker interactions. If not
111119
specified, one is created using the default AWS configuration
112120
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
121+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
122+
instance type used for the training job that produced the fine-tuned weights.
123+
Optionally supply this to get a inference instance type conditioned
124+
on the training instance, to ensure compatability of training artifact to inference
125+
instance. (Default: None).
126+
113127
Returns:
114128
list: The supported instance types to use for the model.
115129
@@ -132,4 +146,5 @@ def retrieve(
132146
tolerate_vulnerable_model,
133147
tolerate_deprecated_model,
134148
sagemaker_session=sagemaker_session,
149+
training_instance_type=training_instance_type,
135150
)

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def _retrieve_default_instance_type(
3737
tolerate_vulnerable_model: bool = False,
3838
tolerate_deprecated_model: bool = False,
3939
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
40+
training_instance_type: Optional[str] = None,
4041
) -> str:
4142
"""Retrieves the default instance type for the model.
4243
@@ -60,6 +61,11 @@ def _retrieve_default_instance_type(
6061
object, used for SageMaker interactions. If not
6162
specified, one is created using the default AWS configuration
6263
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
65+
instance type used for the training job that produced the fine-tuned weights.
66+
Optionally supply this to get a inference instance type conditioned
67+
on the training instance, to ensure compatability of training artifact to inference
68+
instance. (Default: None).
6369
Returns:
6470
str: the default instance type to use for the model or None.
6571
@@ -82,7 +88,21 @@ def _retrieve_default_instance_type(
8288
)
8389

8490
if scope == JumpStartScriptScope.INFERENCE:
85-
default_instance_type = model_specs.default_inference_instance_type
91+
instance_specific_default_instance_type = (
92+
(
93+
model_specs.training_instance_type_variants.get_instance_specific_default_inference_instance_type( # pylint: disable=C0301 # noqa: E501
94+
training_instance_type
95+
)
96+
)
97+
if training_instance_type is not None
98+
and getattr(model_specs, "training_instance_type_variants", None) is not None
99+
else None
100+
)
101+
default_instance_type = (
102+
instance_specific_default_instance_type
103+
if instance_specific_default_instance_type is not None
104+
else model_specs.default_inference_instance_type
105+
)
86106
elif scope == JumpStartScriptScope.TRAINING:
87107
default_instance_type = model_specs.default_training_instance_type
88108
else:
@@ -103,6 +123,7 @@ def _retrieve_instance_types(
103123
tolerate_vulnerable_model: bool = False,
104124
tolerate_deprecated_model: bool = False,
105125
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
126+
training_instance_type: Optional[str] = None,
106127
) -> List[str]:
107128
"""Retrieves the supported instance types for the model.
108129
@@ -126,6 +147,11 @@ def _retrieve_instance_types(
126147
object, used for SageMaker interactions. If not
127148
specified, one is created using the default AWS configuration
128149
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
150+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
151+
instance type used for the training job that produced the fine-tuned weights.
152+
Optionally supply this to get a inference instance type conditioned
153+
on the training instance, to ensure compatability of training artifact to inference
154+
instance. (Default: None).
129155
Returns:
130156
list: the supported instance types to use for the model or None.
131157
@@ -148,8 +174,24 @@ def _retrieve_instance_types(
148174
)
149175

150176
if scope == JumpStartScriptScope.INFERENCE:
151-
instance_types = model_specs.supported_inference_instance_types
177+
default_instance_types = model_specs.supported_inference_instance_types or []
178+
instance_specific_instance_types = (
179+
model_specs.training_instance_type_variants.get_instance_specific_supported_inference_instance_types( # pylint: disable=C0301 # noqa: E501
180+
training_instance_type
181+
)
182+
if training_instance_type is not None
183+
and getattr(model_specs, "training_instance_type_variants", None) is not None
184+
else []
185+
)
186+
instance_types = (
187+
instance_specific_instance_types
188+
if len(instance_specific_instance_types) > 0
189+
else default_instance_types
190+
)
191+
152192
elif scope == JumpStartScriptScope.TRAINING:
193+
if training_instance_type is not None:
194+
raise ValueError("Cannot use `training_instance_type` argument " "with training scope.")
153195
instance_types = model_specs.supported_training_instance_types
154196
else:
155197
raise NotImplementedError(

src/sagemaker/jumpstart/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,6 @@ def deploy(
988988
use_compiled_model (bool): Flag to select whether to use compiled
989989
(optimized) model. (Default: False).
990990
"""
991-
992991
self.orig_predictor_cls = predictor_cls
993992

994993
sagemaker_session = sagemaker_session or self.sagemaker_session
@@ -1039,6 +1038,7 @@ def deploy(
10391038
dependencies=dependencies,
10401039
git_config=git_config,
10411040
use_compiled_model=use_compiled_model,
1041+
training_instance_type=self.instance_type,
10421042
)
10431043

10441044
predictor = super(JumpStartEstimator, self).deploy(

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,7 @@ def get_deploy_kwargs(
280280
tolerate_vulnerable_model: Optional[bool] = None,
281281
use_compiled_model: Optional[bool] = None,
282282
model_name: Optional[str] = None,
283+
training_instance_type: Optional[str] = None,
283284
) -> JumpStartEstimatorDeployKwargs:
284285
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object."""
285286

@@ -313,7 +314,7 @@ def get_deploy_kwargs(
313314
model_id=model_id,
314315
model_from_estimator=True,
315316
model_version=model_version,
316-
instance_type=model_deploy_kwargs.instance_type, # prevent excess logging
317+
instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None,
317318
region=region,
318319
image_uri=image_uri,
319320
source_dir=source_dir,
@@ -333,6 +334,7 @@ def get_deploy_kwargs(
333334
git_config=git_config,
334335
tolerate_vulnerable_model=tolerate_vulnerable_model,
335336
tolerate_deprecated_model=tolerate_deprecated_model,
337+
training_instance_type=training_instance_type,
336338
)
337339

338340
estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(

src/sagemaker/jumpstart/factory/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM
181181
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
182182
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
183183
sagemaker_session=kwargs.sagemaker_session,
184+
training_instance_type=kwargs.training_instance_type,
184185
)
185186

186187
if orig_instance_type is None:
@@ -643,6 +644,7 @@ def get_init_kwargs(
643644
dependencies: Optional[List[str]] = None,
644645
git_config: Optional[Dict[str, str]] = None,
645646
model_package_arn: Optional[str] = None,
647+
training_instance_type: Optional[str] = None,
646648
) -> JumpStartModelInitKwargs:
647649
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
648650

@@ -671,6 +673,7 @@ def get_init_kwargs(
671673
tolerate_deprecated_model=tolerate_deprecated_model,
672674
tolerate_vulnerable_model=tolerate_vulnerable_model,
673675
model_package_arn=model_package_arn,
676+
training_instance_type=training_instance_type,
674677
)
675678

676679
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)

src/sagemaker/jumpstart/types.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,56 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic
581581

582582
return instance_family_environment_variables
583583

584+
def get_instance_specific_default_inference_instance_type(
585+
self, instance_type: str
586+
) -> Optional[str]:
587+
"""Returns instance specific default inference instance type.
588+
589+
Returns None if a model, instance type tuple does not have instance
590+
specific inference instance types.
591+
"""
592+
593+
return self._get_instance_specific_property(
594+
instance_type, "default_inference_instance_type"
595+
)
596+
597+
def get_instance_specific_supported_inference_instance_types(
598+
self, instance_type: str
599+
) -> List[str]:
600+
"""Returns instance specific supported inference instance types.
601+
602+
Returns empty list if a model, instance type tuple does not have instance
603+
specific inference instance types.
604+
"""
605+
606+
if self.variants is None:
607+
return []
608+
609+
instance_specific_inference_instance_types: List[str] = (
610+
self.variants.get(instance_type, {})
611+
.get("properties", {})
612+
.get("supported_inference_instance_types", [])
613+
)
614+
615+
instance_type_family = get_instance_type_family(instance_type)
616+
617+
instance_family_inference_instance_types: List[str] = (
618+
self.variants.get(instance_type_family, {})
619+
.get("properties", {})
620+
.get("supported_inference_instance_types", [])
621+
if instance_type_family not in {"", None}
622+
else []
623+
)
624+
625+
return sorted(
626+
list(
627+
set(
628+
instance_specific_inference_instance_types
629+
+ instance_family_inference_instance_types
630+
)
631+
)
632+
)
633+
584634
def get_image_uri(self, instance_type: str, region: str) -> Optional[str]:
585635
"""Returns image uri from instance type and region.
586636
@@ -971,6 +1021,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
9711021
"dependencies",
9721022
"git_config",
9731023
"model_package_arn",
1024+
"training_instance_type",
9741025
]
9751026

9761027
SERIALIZATION_EXCLUSION_SET = {
@@ -981,6 +1032,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
9811032
"tolerate_deprecated_model",
9821033
"region",
9831034
"model_package_arn",
1035+
"training_instance_type",
9841036
}
9851037

9861038
def __init__(
@@ -1009,6 +1061,7 @@ def __init__(
10091061
tolerate_vulnerable_model: Optional[bool] = None,
10101062
tolerate_deprecated_model: Optional[bool] = None,
10111063
model_package_arn: Optional[str] = None,
1064+
training_instance_type: Optional[str] = None,
10121065
) -> None:
10131066
"""Instantiates JumpStartModelInitKwargs object."""
10141067

@@ -1036,6 +1089,7 @@ def __init__(
10361089
self.tolerate_deprecated_model = tolerate_deprecated_model
10371090
self.tolerate_vulnerable_model = tolerate_vulnerable_model
10381091
self.model_package_arn = model_package_arn
1092+
self.training_instance_type = training_instance_type
10391093

10401094

10411095
class JumpStartModelDeployKwargs(JumpStartKwargs):
@@ -1065,6 +1119,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
10651119
"tolerate_vulnerable_model",
10661120
"tolerate_deprecated_model",
10671121
"sagemaker_session",
1122+
"training_instance_type",
10681123
]
10691124

10701125
SERIALIZATION_EXCLUSION_SET = {
@@ -1074,6 +1129,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
10741129
"tolerate_deprecated_model",
10751130
"tolerate_vulnerable_model",
10761131
"sagemaker_session",
1132+
"training_instance_type",
10771133
}
10781134

10791135
def __init__(
@@ -1101,6 +1157,7 @@ def __init__(
11011157
tolerate_deprecated_model: Optional[bool] = None,
11021158
tolerate_vulnerable_model: Optional[bool] = None,
11031159
sagemaker_session: Optional[Session] = None,
1160+
training_instance_type: Optional[str] = None,
11041161
) -> None:
11051162
"""Instantiates JumpStartModelDeployKwargs object."""
11061163

@@ -1127,6 +1184,7 @@ def __init__(
11271184
self.tolerate_vulnerable_model = tolerate_vulnerable_model
11281185
self.tolerate_deprecated_model = tolerate_deprecated_model
11291186
self.sagemaker_session = sagemaker_session
1187+
self.training_instance_type = training_instance_type
11301188

11311189

11321190
class JumpStartEstimatorInitKwargs(JumpStartKwargs):

0 commit comments

Comments
 (0)