Skip to content

Commit cea49d7

Browse files
committed
feat: inference instance type conditioned on training instance type
1 parent 0202bad commit cea49d7

File tree

10 files changed

+760
-5
lines changed

10 files changed

+760
-5
lines changed

src/sagemaker/instance_types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def retrieve_default(
3333
tolerate_vulnerable_model: bool = False,
3434
tolerate_deprecated_model: bool = False,
3535
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
36+
training_instance_type: Optional[str] = None,
3637
) -> str:
3738
"""Retrieves the default instance type for the model matching the given arguments.
3839
@@ -56,6 +57,10 @@ def retrieve_default(
5657
object, used for SageMaker interactions. If not
5758
specified, one is created using the default AWS configuration
5859
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
60+
training_instance_type (str): The training instance type from which to deploy an endpoint.
61+
Optionally supply this to get a inference instance type conditioned
62+
on the training instance, to ensure compatability of training artifact to inference
63+
instance. (Default: None).
5964
Returns:
6065
str: The default instance type to use for the model.
6166
@@ -78,6 +83,7 @@ def retrieve_default(
7883
tolerate_vulnerable_model,
7984
tolerate_deprecated_model,
8085
sagemaker_session=sagemaker_session,
86+
training_instance_type=training_instance_type,
8187
)
8288

8389

@@ -89,6 +95,7 @@ def retrieve(
8995
tolerate_vulnerable_model: bool = False,
9096
tolerate_deprecated_model: bool = False,
9197
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
98+
training_instance_type: Optional[str] = None,
9299
) -> List[str]:
93100
"""Retrieves the supported training instance types for the model matching the given arguments.
94101
@@ -110,6 +117,11 @@ def retrieve(
110117
object, used for SageMaker interactions. If not
111118
specified, one is created using the default AWS configuration
112119
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
120+
training_instance_type (str): The training instance type from which to deploy an endpoint.
121+
Optionally supply this to get a inference instance type conditioned
122+
on the training instance, to ensure compatability of training artifact to inference
123+
instance. (Default: None).
124+
113125
Returns:
114126
list: The supported instance types to use for the model.
115127
@@ -132,4 +144,5 @@ def retrieve(
132144
tolerate_vulnerable_model,
133145
tolerate_deprecated_model,
134146
sagemaker_session=sagemaker_session,
147+
training_instance_type=training_instance_type,
135148
)

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def _retrieve_default_instance_type(
3737
tolerate_vulnerable_model: bool = False,
3838
tolerate_deprecated_model: bool = False,
3939
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
40+
training_instance_type: Optional[str] = None,
4041
) -> str:
4142
"""Retrieves the default instance type for the model.
4243
@@ -60,6 +61,10 @@ def _retrieve_default_instance_type(
6061
object, used for SageMaker interactions. If not
6162
specified, one is created using the default AWS configuration
6263
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64+
training_instance_type (str): The training instance type from which to deploy an endpoint.
65+
Optionally supply this to get a inference instance type conditioned
66+
on the training instance, to ensure compatability of training artifact to inference
67+
instance. (Default: None).
6368
Returns:
6469
str: the default instance type to use for the model or None.
6570
@@ -82,7 +87,21 @@ def _retrieve_default_instance_type(
8287
)
8388

8489
if scope == JumpStartScriptScope.INFERENCE:
85-
default_instance_type = model_specs.default_inference_instance_type
90+
instance_specific_default_instance_type = (
91+
(
92+
model_specs.training_instance_type_variants.get_instance_specific_default_inference_instance_type( # pylint: disable=C0301 # noqa: E501
93+
training_instance_type
94+
)
95+
)
96+
if training_instance_type is not None
97+
and getattr(model_specs, "training_instance_type_variants", None) is not None
98+
else None
99+
)
100+
default_instance_type = (
101+
instance_specific_default_instance_type
102+
if instance_specific_default_instance_type is not None
103+
else model_specs.default_inference_instance_type
104+
)
86105
elif scope == JumpStartScriptScope.TRAINING:
87106
default_instance_type = model_specs.default_training_instance_type
88107
else:
@@ -103,6 +122,7 @@ def _retrieve_instance_types(
103122
tolerate_vulnerable_model: bool = False,
104123
tolerate_deprecated_model: bool = False,
105124
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
125+
training_instance_type: Optional[str] = None,
106126
) -> List[str]:
107127
"""Retrieves the supported instance types for the model.
108128
@@ -126,6 +146,10 @@ def _retrieve_instance_types(
126146
object, used for SageMaker interactions. If not
127147
specified, one is created using the default AWS configuration
128148
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
149+
training_instance_type (str): The training instance type from which to deploy an endpoint.
150+
Optionally supply this to get a inference instance type conditioned
151+
on the training instance, to ensure compatability of training artifact to inference
152+
instance. (Default: None).
129153
Returns:
130154
list: the supported instance types to use for the model or None.
131155
@@ -148,8 +172,25 @@ def _retrieve_instance_types(
148172
)
149173

150174
if scope == JumpStartScriptScope.INFERENCE:
151-
instance_types = model_specs.supported_inference_instance_types
175+
default_instance_types = model_specs.supported_inference_instance_types or []
176+
instance_specific_instance_types = (
177+
model_specs.training_instance_type_variants.get_instance_specific_supported_inference_instance_types( # pylint: disable=C0301 # noqa: E501
178+
training_instance_type
179+
)
180+
if training_instance_type is not None
181+
and hasattr(model_specs, "training_instance_type_variants")
182+
else []
183+
)
184+
# pylint: enable
185+
instance_types = (
186+
instance_specific_instance_types
187+
if len(instance_specific_instance_types) > 0
188+
else default_instance_types
189+
)
190+
152191
elif scope == JumpStartScriptScope.TRAINING:
192+
if training_instance_type is not None:
193+
raise ValueError("Cannot use `training_instance_type` argument " "with training scope.")
153194
instance_types = model_specs.supported_training_instance_types
154195
else:
155196
raise NotImplementedError(

src/sagemaker/jumpstart/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,6 @@ def deploy(
988988
use_compiled_model (bool): Flag to select whether to use compiled
989989
(optimized) model. (Default: False).
990990
"""
991-
992991
self.orig_predictor_cls = predictor_cls
993992

994993
sagemaker_session = sagemaker_session or self.sagemaker_session
@@ -1039,6 +1038,7 @@ def deploy(
10391038
dependencies=dependencies,
10401039
git_config=git_config,
10411040
use_compiled_model=use_compiled_model,
1041+
training_instance_type=self.instance_type,
10421042
)
10431043

10441044
predictor = super(JumpStartEstimator, self).deploy(

src/sagemaker/jumpstart/factory/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,7 @@ def get_deploy_kwargs(
277277
tolerate_vulnerable_model: Optional[bool] = None,
278278
use_compiled_model: Optional[bool] = None,
279279
model_name: Optional[str] = None,
280+
training_instance_type: Optional[str] = None,
280281
) -> JumpStartEstimatorDeployKwargs:
281282
"""Returns kwargs required to call `deploy` on `sagemaker.estimator.Estimator` object."""
282283

@@ -310,7 +311,7 @@ def get_deploy_kwargs(
310311
model_id=model_id,
311312
model_from_estimator=True,
312313
model_version=model_version,
313-
instance_type=model_deploy_kwargs.instance_type, # prevent excess logging
314+
instance_type=model_deploy_kwargs.instance_type if training_instance_type is None else None,
314315
region=region,
315316
image_uri=image_uri,
316317
source_dir=source_dir,
@@ -330,6 +331,7 @@ def get_deploy_kwargs(
330331
git_config=git_config,
331332
tolerate_vulnerable_model=tolerate_vulnerable_model,
332333
tolerate_deprecated_model=tolerate_deprecated_model,
334+
training_instance_type=training_instance_type,
333335
)
334336

335337
estimator_deploy_kwargs: JumpStartEstimatorDeployKwargs = JumpStartEstimatorDeployKwargs(

src/sagemaker/jumpstart/factory/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,7 @@ def _add_instance_type_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartM
180180
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
181181
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
182182
sagemaker_session=kwargs.sagemaker_session,
183+
training_instance_type=kwargs.training_instance_type,
183184
)
184185

185186
if orig_instance_type is None:
@@ -619,6 +620,7 @@ def get_init_kwargs(
619620
dependencies: Optional[List[str]] = None,
620621
git_config: Optional[Dict[str, str]] = None,
621622
model_package_arn: Optional[str] = None,
623+
training_instance_type: Optional[str] = None,
622624
) -> JumpStartModelInitKwargs:
623625
"""Returns kwargs required to instantiate `sagemaker.estimator.Model` object."""
624626

@@ -647,6 +649,7 @@ def get_init_kwargs(
647649
tolerate_deprecated_model=tolerate_deprecated_model,
648650
tolerate_vulnerable_model=tolerate_vulnerable_model,
649651
model_package_arn=model_package_arn,
652+
training_instance_type=training_instance_type,
650653
)
651654

652655
model_init_kwargs = _add_model_version_to_kwargs(kwargs=model_init_kwargs)

src/sagemaker/jumpstart/types.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,56 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic
581581

582582
return instance_family_environment_variables
583583

584+
def get_instance_specific_default_inference_instance_type(
585+
self, instance_type: str
586+
) -> Optional[str]:
587+
"""Returns instance specific default inference instance type.
588+
589+
Returns None if a model, instance type tuple does not have instance
590+
specific inference instance types.
591+
"""
592+
593+
return self._get_instance_specific_property(
594+
instance_type, "default_inference_instance_type"
595+
)
596+
597+
def get_instance_specific_supported_inference_instance_types(
598+
self, instance_type: str
599+
) -> List[str]:
600+
"""Returns instance specific supported inference instance types.
601+
602+
Returns empty list if a model, instance type tuple does not have instance
603+
specific inference instance types.
604+
"""
605+
606+
if self.variants is None:
607+
return []
608+
609+
instance_specific_inference_instance_types: List[str] = (
610+
self.variants.get(instance_type, {})
611+
.get("properties", {})
612+
.get("supported_inference_instance_types", [])
613+
)
614+
615+
instance_type_family = get_instance_type_family(instance_type)
616+
617+
instance_family_inference_instance_types: List[str] = (
618+
self.variants.get(instance_type_family, {})
619+
.get("properties", {})
620+
.get("supported_inference_instance_types", [])
621+
if instance_type_family not in {"", None}
622+
else []
623+
)
624+
625+
return sorted(
626+
list(
627+
set(
628+
instance_specific_inference_instance_types
629+
+ instance_family_inference_instance_types
630+
)
631+
)
632+
)
633+
584634
def get_image_uri(self, instance_type: str, region: str) -> Optional[str]:
585635
"""Returns image uri from instance type and region.
586636
@@ -971,6 +1021,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
9711021
"dependencies",
9721022
"git_config",
9731023
"model_package_arn",
1024+
"training_instance_type",
9741025
]
9751026

9761027
SERIALIZATION_EXCLUSION_SET = {
@@ -981,6 +1032,7 @@ class JumpStartModelInitKwargs(JumpStartKwargs):
9811032
"tolerate_deprecated_model",
9821033
"region",
9831034
"model_package_arn",
1035+
"training_instance_type",
9841036
}
9851037

9861038
def __init__(
@@ -1009,6 +1061,7 @@ def __init__(
10091061
tolerate_vulnerable_model: Optional[bool] = None,
10101062
tolerate_deprecated_model: Optional[bool] = None,
10111063
model_package_arn: Optional[str] = None,
1064+
training_instance_type: Optional[str] = None,
10121065
) -> None:
10131066
"""Instantiates JumpStartModelInitKwargs object."""
10141067

@@ -1036,6 +1089,7 @@ def __init__(
10361089
self.tolerate_deprecated_model = tolerate_deprecated_model
10371090
self.tolerate_vulnerable_model = tolerate_vulnerable_model
10381091
self.model_package_arn = model_package_arn
1092+
self.training_instance_type = training_instance_type
10391093

10401094

10411095
class JumpStartModelDeployKwargs(JumpStartKwargs):
@@ -1065,6 +1119,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
10651119
"tolerate_vulnerable_model",
10661120
"tolerate_deprecated_model",
10671121
"sagemaker_session",
1122+
"training_instance_type",
10681123
]
10691124

10701125
SERIALIZATION_EXCLUSION_SET = {
@@ -1074,6 +1129,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
10741129
"tolerate_deprecated_model",
10751130
"tolerate_vulnerable_model",
10761131
"sagemaker_session",
1132+
"training_instance_type",
10771133
}
10781134

10791135
def __init__(
@@ -1101,6 +1157,7 @@ def __init__(
11011157
tolerate_deprecated_model: Optional[bool] = None,
11021158
tolerate_vulnerable_model: Optional[bool] = None,
11031159
sagemaker_session: Optional[Session] = None,
1160+
training_instance_type: Optional[str] = None,
11041161
) -> None:
11051162
"""Instantiates JumpStartModelDeployKwargs object."""
11061163

@@ -1127,6 +1184,7 @@ def __init__(
11271184
self.tolerate_vulnerable_model = tolerate_vulnerable_model
11281185
self.tolerate_deprecated_model = tolerate_deprecated_model
11291186
self.sagemaker_session = sagemaker_session
1187+
self.training_instance_type = training_instance_type
11301188

11311189

11321190
class JumpStartEstimatorInitKwargs(JumpStartKwargs):

0 commit comments

Comments
 (0)