@@ -192,7 +192,7 @@ def retrieve(
192
192
config = _config_for_framework_and_scope (_framework , final_image_scope , accelerator_type )
193
193
194
194
original_version = version
195
- version = _validate_version_and_set_if_needed (version , config , framework )
195
+ version = _validate_version_and_set_if_needed (version , config , framework , image_scope )
196
196
version_config = config ["versions" ][_version_for_config (version , config )]
197
197
198
198
if framework == HUGGING_FACE_FRAMEWORK :
@@ -460,8 +460,24 @@ def _get_inference_tool(inference_tool, instance_type):
460
460
461
461
def _get_latest_versions (list_of_versions ):
462
462
"""Extract the latest version from the input list of available versions."""
463
+ print ("SORT" )
463
464
return sorted (list_of_versions , reverse = True )[0 ]
464
465
466
+ def _get_latest_version (framework , version , image_scope ):
467
+ """Get the latest version from the input framework"""
468
+ if version :
469
+ return version
470
+ try :
471
+ framework_config = config_for_framework (framework )
472
+ except FileNotFoundError :
473
+ raise ValueError ("Invalid framework {}" .format (framework ))
474
+
475
+ if not framework_config :
476
+ raise ValueError ("Invalid framework {}" .format (framework ))
477
+
478
+ if not version :
479
+ version = _fetch_latest_version_from_config (framework_config , image_scope )
480
+ return version
465
481
466
482
def _validate_accelerator_type (accelerator_type ):
467
483
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
@@ -472,32 +488,23 @@ def _validate_accelerator_type(accelerator_type):
472
488
)
473
489
474
490
475
- def _validate_version_and_set_if_needed (version , config , framework ):
491
+ def _validate_version_and_set_if_needed (version , config , framework , image_scope ):
476
492
"""Checks if the framework/algorithm version is one of the supported versions."""
493
+ if not config :
494
+ config = config_for_framework (framework )
477
495
available_versions = list (config ["versions" ].keys ())
478
496
aliased_versions = list (config .get ("version_aliases" , {}).keys ())
479
-
480
497
if len (available_versions ) == 1 and version not in aliased_versions :
481
- log_message = "Defaulting to the only supported framework/algorithm version: {}." .format (
482
- available_versions [0 ]
483
- )
484
- if version and version != available_versions [0 ]:
485
- logger .warning ("%s Ignoring framework/algorithm version: %s." , log_message , version )
486
- elif not version :
487
- logger .info (log_message )
488
-
489
498
return available_versions [0 ]
490
-
491
- if version is None and framework in [
499
+ if not version and framework in [
492
500
DATA_WRANGLER_FRAMEWORK ,
493
501
HUGGING_FACE_LLM_FRAMEWORK ,
494
502
HUGGING_FACE_TEI_GPU_FRAMEWORK ,
495
503
HUGGING_FACE_TEI_CPU_FRAMEWORK ,
496
504
HUGGING_FACE_LLM_NEURONX_FRAMEWORK ,
497
505
STABILITYAI_FRAMEWORK ,
498
506
]:
499
- version = _get_latest_versions (available_versions )
500
-
507
+ version = _get_latest_version (framework , version , image_scope )
501
508
_validate_arg (version , available_versions + aliased_versions , "{} version" .format (framework ))
502
509
return version
503
510
@@ -609,6 +616,7 @@ def _validate_py_version_and_set_if_needed(py_version, version_config, framework
609
616
610
617
def _validate_arg (arg , available_options , arg_name ):
611
618
"""Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
619
+ print ("VALIDATE" )
612
620
if arg not in available_options :
613
621
raise ValueError (
614
622
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
@@ -748,101 +756,6 @@ def get_base_python_image_uri(region, py_version="310") -> str:
748
756
return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo_and_tag )
749
757
750
758
751
def get_latest_container_image(
    framework: str,
    image_scope: Optional[str] = None,
    instance_type: Optional[str] = None,
    py_version: Optional[str] = None,
    region: str = "us-west-2",
    version: Optional[str] = None,
    accelerator_type=None,
    container_version=None,
    distribution=None,
    base_framework_version=None,
    training_compiler_config=None,
    model_id=None,
    model_version=None,
    hub_arn=None,
    sdk_version=None,
    inference_tool=None,
    serverless_inference_config=None,
    config_name=None,
) -> Tuple[str, str]:
    """Retrieves the latest container image URI and the version it was resolved for.

    When ``version`` is not given, it is resolved from the framework config via
    ``_fetch_latest_version_from_config``; every other argument is forwarded to
    ``retrieve`` unchanged.

    Args:
        framework (str): The name of the framework or algorithm.
        image_scope (str): The image type, i.e. what it is used for.
            Valid values: "training", "inference", "inference_graviton", "eia".
            If ``accelerator_type`` is set, ``image_scope`` is ignored.
        instance_type (str): The SageMaker instance type. For supported types, see
            https://aws.amazon.com/sagemaker/pricing. This is required if
            there are different images for different processor types.
        py_version (str): The Python version. This is required if there is
            more than one supported Python version for the given framework version.
        region (str): The AWS region (default: "us-west-2").
        version (str): The framework or algorithm version. If omitted, the latest
            version found in the framework config is used.
        accelerator_type (str): Elastic Inference accelerator type. For more, see
            https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
        container_version (str): the version of docker image.
            Ideally the value of parameter should be created inside the framework.
            For custom use, see the list of supported container versions:
            https://github.com/aws/deep-learning-containers/blob/master/available_images.md
            (default: None).
        distribution (dict): A dictionary with information on how to run distributed training
        base_framework_version: Forwarded to ``retrieve`` as-is (default: None).
        training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
            A configuration class for the SageMaker Training Compiler
            (default: None).
        model_id (str): The JumpStart model ID for which to retrieve the image URI
            (default: None).
        model_version (str): The version of the JumpStart model for which to retrieve the
            image URI (default: None).
        hub_arn (str): The arn of the SageMaker Hub for which to retrieve
            model details from. (Default: None).
        sdk_version (str): the version of python-sdk that will be used in the image retrieval.
            (default: None).
        inference_tool (str): the tool that will be used to aid in the inference.
            Valid values: "neuron, neuronx, None"
            (default: None).
        serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
            Specifies configuration related to serverless endpoint. Instance type is
            not provided in serverless inference. So this is used to determine processor type.
        config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).

    Returns:
        Tuple[str, str]: The image URI returned by ``retrieve`` and the version
        that was passed to it.

    Raises:
        ValueError: If the framework config cannot be found or is empty.
    """
    try:
        framework_config = config_for_framework(framework)
    except FileNotFoundError as err:
        # Keep the underlying error attached so the missing config is traceable.
        raise ValueError("Invalid framework {}".format(framework)) from err

    if not framework_config:
        raise ValueError("Invalid framework {}".format(framework))

    if not version:
        version = _fetch_latest_version_from_config(framework_config, image_scope)

    image_uri = retrieve(
        framework=framework,
        region=region,
        version=version,
        instance_type=instance_type,
        py_version=py_version,
        accelerator_type=accelerator_type,
        image_scope=image_scope,
        container_version=container_version,
        distribution=distribution,
        base_framework_version=base_framework_version,
        training_compiler_config=training_compiler_config,
        model_id=model_id,
        model_version=model_version,
        hub_arn=hub_arn,
        sdk_version=sdk_version,
        inference_tool=inference_tool,
        serverless_inference_config=serverless_inference_config,
        config_name=config_name,
    )
    return image_uri, version
844
-
845
-
846
759
def _fetch_latest_version_from_config (
847
760
framework_config : dict , image_scope : Optional [str ] = None
848
761
) -> Optional [str ]:
@@ -864,6 +777,8 @@ def _fetch_latest_version_from_config(
864
777
865
778
if "versions" in framework_config :
866
779
versions = list (framework_config ["versions" ].keys ())
780
+ if len (versions ) == 1 :
781
+ return versions [0 ]
867
782
top_version = versions [0 ]
868
783
bottom_version = versions [- 1 ]
869
784
if top_version == "latest" or bottom_version == "latest" :
@@ -880,7 +795,6 @@ def _fetch_latest_version_from_config(
880
795
versions = list (framework_config ["processing" ]["versions" ].keys ())
881
796
top_version = versions [0 ]
882
797
bottom_version = versions [- 1 ]
883
-
884
798
if top_version and bottom_version :
885
799
if top_version .endswith (".x" ) or bottom_version .endswith (".x" ):
886
800
top_number = int (top_version [:- 2 ])
0 commit comments