Skip to content

Commit c3a0b33

Browse files
committed
chore: improve function name to make more generic
1 parent 8a696cc commit c3a0b33

File tree

3 files changed

+12
-14
lines changed

3 files changed

+12
-14
lines changed

src/sagemaker/jumpstart/artifacts/instance_types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _retrieve_default_instance_type(
9090
if scope == JumpStartScriptScope.INFERENCE:
9191
instance_specific_default_instance_type = (
9292
(
93-
model_specs.training_instance_type_variants.get_training_instance_specific_default_inference_instance_type( # pylint: disable=C0301 # noqa: E501
93+
model_specs.training_instance_type_variants.get_instance_specific_default_inference_instance_type( # pylint: disable=C0301 # noqa: E501
9494
training_instance_type
9595
)
9696
)
@@ -176,7 +176,7 @@ def _retrieve_instance_types(
176176
if scope == JumpStartScriptScope.INFERENCE:
177177
default_instance_types = model_specs.supported_inference_instance_types or []
178178
instance_specific_instance_types = (
179-
model_specs.training_instance_type_variants.get_training_instance_specific_supported_inference_instance_types( # pylint: disable=C0301 # noqa: E501
179+
model_specs.training_instance_type_variants.get_instance_specific_supported_inference_instance_types( # pylint: disable=C0301 # noqa: E501
180180
training_instance_type
181181
)
182182
if training_instance_type is not None

src/sagemaker/jumpstart/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -581,7 +581,7 @@ def get_instance_specific_environment_variables(self, instance_type: str) -> Dic
581581

582582
return instance_family_environment_variables
583583

584-
def get_training_instance_specific_default_inference_instance_type(
584+
def get_instance_specific_default_inference_instance_type(
585585
self, instance_type: str
586586
) -> Optional[str]:
587587
"""Returns instance specific default inference instance type.
@@ -594,7 +594,7 @@ def get_training_instance_specific_default_inference_instance_type(
594594
instance_type, "default_inference_instance_type"
595595
)
596596

597-
def get_training_instance_specific_supported_inference_instance_types(
597+
def get_instance_specific_supported_inference_instance_types(
598598
self, instance_type: str
599599
) -> List[str]:
600600
"""Returns instance specific supported inference instance types.

tests/unit/sagemaker/jumpstart/test_types.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -676,44 +676,42 @@ def test_jumpstart_hyperparameter_instance_variants():
676676

677677

678678
def test_jumpstart_inference_instance_type_variants():
679-
assert INSTANCE_TYPE_VARIANT.get_training_instance_specific_supported_inference_instance_types(
679+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types(
680680
"ml.p2.xlarge"
681681
) == ["ml.p2.xlarge", "ml.p3.xlarge"]
682682
assert (
683-
INSTANCE_TYPE_VARIANT.get_training_instance_specific_default_inference_instance_type(
684-
"ml.p2.2xlarge"
685-
)
683+
INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type("ml.p2.2xlarge")
686684
== "ml.p2.xlarge"
687685
)
688686

689-
assert INSTANCE_TYPE_VARIANT.get_training_instance_specific_supported_inference_instance_types(
687+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types(
690688
"ml.p2.12xlarge"
691689
) == ["ml.p2.xlarge", "ml.p3.xlarge", "ml.p5.xlarge"]
692690
assert (
693-
INSTANCE_TYPE_VARIANT.get_training_instance_specific_default_inference_instance_type(
691+
INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type(
694692
"ml.p2.12xlarge"
695693
)
696694
== "ml.p5.xlarge"
697695
)
698696

699697
assert (
700-
INSTANCE_TYPE_VARIANT.get_training_instance_specific_supported_inference_instance_types(
698+
INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types(
701699
"ml.sdfsad.12xlarge"
702700
)
703701
== []
704702
)
705703
assert (
706-
INSTANCE_TYPE_VARIANT.get_training_instance_specific_default_inference_instance_type(
704+
INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type(
707705
"ml.adfas.12xlarge"
708706
)
709707
is None
710708
)
711709

712-
assert INSTANCE_TYPE_VARIANT.get_training_instance_specific_supported_inference_instance_types(
710+
assert INSTANCE_TYPE_VARIANT.get_instance_specific_supported_inference_instance_types(
713711
"ml.trn1.12xlarge"
714712
) == ["ml.inf1.2xlarge", "ml.inf1.xlarge"]
715713
assert (
716-
INSTANCE_TYPE_VARIANT.get_training_instance_specific_default_inference_instance_type(
714+
INSTANCE_TYPE_VARIANT.get_instance_specific_default_inference_instance_type(
717715
"ml.trn1.12xlarge"
718716
)
719717
== "ml.inf1.xlarge"

0 commit comments

Comments
 (0)