Skip to content

Commit 389fadf

Browse files
committed
chore: address PR comments
1 parent cea49d7 commit 389fadf

File tree

4 files changed

+22
-16
lines changed

4 files changed

+22
-16
lines changed

src/sagemaker/instance_types.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,8 @@ def retrieve_default(
5757
object, used for SageMaker interactions. If not
5858
specified, one is created using the default AWS configuration
5959
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
60-
training_instance_type (str): The training instance type from which to deploy an endpoint.
60+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
61+
instance type used for the training job that produced the fine-tuned weights.
6162
Optionally supply this to get an inference instance type conditioned
6263
on the training instance, to ensure compatibility between the training artifact and the inference
6364
instance. (Default: None).
@@ -117,7 +118,8 @@ def retrieve(
117118
object, used for SageMaker interactions. If not
118119
specified, one is created using the default AWS configuration
119120
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
120-
training_instance_type (str): The training instance type from which to deploy an endpoint.
121+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
122+
instance type used for the training job that produced the fine-tuned weights.
121123
Optionally supply this to get an inference instance type conditioned
122124
on the training instance, to ensure compatibility between the training artifact and the inference
123125
instance. (Default: None).

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ def _retrieve_default_instance_type(
6161
object, used for SageMaker interactions. If not
6262
specified, one is created using the default AWS configuration
6363
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
64-
training_instance_type (str): The training instance type from which to deploy an endpoint.
64+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
65+
instance type used for the training job that produced the fine-tuned weights.
6566
Optionally supply this to get an inference instance type conditioned
6667
on the training instance, to ensure compatibility between the training artifact and the inference
6768
instance. (Default: None).
@@ -89,7 +90,7 @@ def _retrieve_default_instance_type(
8990
if scope == JumpStartScriptScope.INFERENCE:
9091
instance_specific_default_instance_type = (
9192
(
92-
model_specs.training_instance_type_variants.get_instance_specific_default_inference_instance_type( # pylint: disable=C0301 # noqa: E501
93+
model_specs.training_instance_type_variants.get_training_instance_specific_default_inference_instance_type( # pylint: disable=C0301 # noqa: E501
9394
training_instance_type
9495
)
9596
)
@@ -146,7 +147,8 @@ def _retrieve_instance_types(
146147
object, used for SageMaker interactions. If not
147148
specified, one is created using the default AWS configuration
148149
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
149-
training_instance_type (str): The training instance type from which to deploy an endpoint.
150+
training_instance_type (str): In the case of a model fine-tuned on SageMaker, the training
151+
instance type used for the training job that produced the fine-tuned weights.
150152
Optionally supply this to get an inference instance type conditioned
151153
on the training instance, to ensure compatibility between the training artifact and the inference
152154
instance. (Default: None).
@@ -174,7 +176,7 @@ def _retrieve_instance_types(
174176
if scope == JumpStartScriptScope.INFERENCE:
175177
default_instance_types = model_specs.supported_inference_instance_types or []
176178
instance_specific_instance_types = (
177-
model_specs.training_instance_type_variants.get_instance_specific_supported_inference_instance_types( # pylint: disable=C0301 # noqa: E501
179+
model_specs.training_instance_type_variants.get_training_instance_specific_supported_inference_instance_types( # pylint: disable=C0301 # noqa: E501
178180
training_instance_type
179181
)
180182
if training_instance_type is not None

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic
581581

582582
return instance_family_environment_variables
583583

584-
def get_instance_specific_default_inference_instance_type(
584+
def get_training_instance_specific_default_inference_instance_type(
585585
self, instance_type: str
586586
) -> Optional[str]:
587587
"""Returns instance specific default inference instance type.
@@ -594,7 +594,7 @@ def get_instance_specific_default_inference_instance_type(
594594
instance_type, "default_inference_instance_type"
595595
)
596596

597-
def get_instance_specific_supported_inference_instance_types(
597+
def get_training_instance_specific_supported_inference_instance_types(
598598
self, instance_type: str
599599
) -> List[str]:
600600
"""Returns instance specific supported inference instance types.

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -676,42 +676,44 @@ def test_jumpstart_hyperparameter_instance_variants():
676676

677677

678678
def test_jumpstart_inference_instance_type_variants():
679-
assert INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types(
679+
assert INSTANCE_TYPE_VARIANT.get_training_instance_specific_supported_inference_instance_types(
680680
"ml.p2.xlarge"
681681
) == ["ml.p2.xlarge", "ml.p3.xlarge"]
682682
assert (
683-
INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type("ml.p2.2xlarge")
683+
INSTANCE_TYPE_VARIANT.get_training_instance_specific_default_inference_instance_type(
684+
"ml.p2.2xlarge"
685+
)
684686
== "ml.p2.xlarge"
685687
)
686688

687-
assert INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types(
689+
assert INSTANCE_TYPE_VARIANT.get_training_instance_specific_supported_inference_instance_types(
688690
"ml.p2.12xlarge"
689691
) == ["ml.p2.xlarge", "ml.p3.xlarge", "ml.p5.xlarge"]
690692
assert (
691-
INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type(
693+
INSTANCE_TYPE_VARIANT.get_training_instance_specific_default_inference_instance_type(
692694
"ml.p2.12xlarge"
693695
)
694696
== "ml.p5.xlarge"
695697
)
696698

697699
assert (
698-
INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types(
700+
INSTANCE_TYPE_VARIANT.get_training_instance_specific_supported_inference_instance_types(
699701
"ml.sdfsad.12xlarge"
700702
)
701703
== []
702704
)
703705
assert (
704-
INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type(
706+
INSTANCE_TYPE_VARIANT.get_training_instance_specific_default_inference_instance_type(
705707
"ml.adfas.12xlarge"
706708
)
707709
is None
708710
)
709711

710-
assert INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types(
712+
assert INSTANCE_TYPE_VARIANT.get_training_instance_specific_supported_inference_instance_types(
711713
"ml.trn1.12xlarge"
712714
) == ["ml.inf1.2xlarge", "ml.inf1.xlarge"]
713715
assert (
714-
INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type(
716+
INSTANCE_TYPE_VARIANT.get_training_instance_specific_default_inference_instance_type(
715717
"ml.trn1.12xlarge"
716718
)
717719
== "ml.inf1.xlarge"

0 commit comments

Comments
 (0)