@@ -676,44 +676,42 @@ def test_jumpstart_hyperparameter_instance_variants():
676
676
677
677
678
678
def test_jumpstart_inference_instance_type_variants ():
679
- assert INSTANCE_TYPE_VARIANT .get_training_instance_specific_supported_inference_instance_types (
679
+ assert INSTANCE_TYPE_VARIANT .get_instance_specific_supported_inference_instance_types (
680
680
"ml.p2.xlarge"
681
681
) == ["ml.p2.xlarge" , "ml.p3.xlarge" ]
682
682
assert (
683
- INSTANCE_TYPE_VARIANT .get_training_instance_specific_default_inference_instance_type (
684
- "ml.p2.2xlarge"
685
- )
683
+ INSTANCE_TYPE_VARIANT .get_instance_specific_default_inference_instance_type ("ml.p2.2xlarge" )
686
684
== "ml.p2.xlarge"
687
685
)
688
686
689
- assert INSTANCE_TYPE_VARIANT .get_training_instance_specific_supported_inference_instance_types (
687
+ assert INSTANCE_TYPE_VARIANT .get_instance_specific_supported_inference_instance_types (
690
688
"ml.p2.12xlarge"
691
689
) == ["ml.p2.xlarge" , "ml.p3.xlarge" , "ml.p5.xlarge" ]
692
690
assert (
693
- INSTANCE_TYPE_VARIANT .get_training_instance_specific_default_inference_instance_type (
691
+ INSTANCE_TYPE_VARIANT .get_instance_specific_default_inference_instance_type (
694
692
"ml.p2.12xlarge"
695
693
)
696
694
== "ml.p5.xlarge"
697
695
)
698
696
699
697
assert (
700
- INSTANCE_TYPE_VARIANT .get_training_instance_specific_supported_inference_instance_types (
698
+ INSTANCE_TYPE_VARIANT .get_instance_specific_supported_inference_instance_types (
701
699
"ml.sdfsad.12xlarge"
702
700
)
703
701
== []
704
702
)
705
703
assert (
706
- INSTANCE_TYPE_VARIANT .get_training_instance_specific_default_inference_instance_type (
704
+ INSTANCE_TYPE_VARIANT .get_instance_specific_default_inference_instance_type (
707
705
"ml.adfas.12xlarge"
708
706
)
709
707
is None
710
708
)
711
709
712
- assert INSTANCE_TYPE_VARIANT .get_training_instance_specific_supported_inference_instance_types (
710
+ assert INSTANCE_TYPE_VARIANT .get_instance_specific_supported_inference_instance_types (
713
711
"ml.trn1.12xlarge"
714
712
) == ["ml.inf1.2xlarge" , "ml.inf1.xlarge" ]
715
713
assert (
716
- INSTANCE_TYPE_VARIANT .get_training_instance_specific_default_inference_instance_type (
714
+ INSTANCE_TYPE_VARIANT .get_instance_specific_default_inference_instance_type (
717
715
"ml.trn1.12xlarge"
718
716
)
719
717
== "ml.inf1.xlarge"
0 commit comments