Skip to content

Commit 84daab6

Browse files
chad119Chad Chiangnargokul
authored andcommitted
Fix: move the functionality from latest_container_image to retrieve (#1583)
* Fix: move the functionality from latest_container_image to retrieve * address some comments from Gokul and add unit test * remove extra functions and rewrite the test * fix unit test * fix for other unit test * unit test fix * fix unit test: add one more condition * more unit tests fix * remove redundant files --------- Co-authored-by: Chad Chiang <[email protected]> Co-authored-by: Gokul Anantha Narayanan <[email protected]>
1 parent 3bb4fbb commit 84daab6

File tree

3 files changed

+128
-255
lines changed

3 files changed

+128
-255
lines changed

src/sagemaker/image_uris.py

Lines changed: 25 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def retrieve(
192192
config = _config_for_framework_and_scope(_framework, final_image_scope, accelerator_type)
193193

194194
original_version = version
195-
version = _validate_version_and_set_if_needed(version, config, framework)
195+
version = _validate_version_and_set_if_needed(version, config, framework, image_scope)
196196
version_config = config["versions"][_version_for_config(version, config)]
197197

198198
if framework == HUGGING_FACE_FRAMEWORK:
@@ -460,8 +460,24 @@ def _get_inference_tool(inference_tool, instance_type):
460460

461461
def _get_latest_versions(list_of_versions):
462462
"""Extract the latest version from the input list of available versions."""
463+
print("SORT")
463464
return sorted(list_of_versions, reverse=True)[0]
464465

466+
def _get_latest_version(framework, version, image_scope):
467+
"""Get the latest version from the input framework"""
468+
if version:
469+
return version
470+
try:
471+
framework_config = config_for_framework(framework)
472+
except FileNotFoundError:
473+
raise ValueError("Invalid framework {}".format(framework))
474+
475+
if not framework_config:
476+
raise ValueError("Invalid framework {}".format(framework))
477+
478+
if not version:
479+
version = _fetch_latest_version_from_config(framework_config, image_scope)
480+
return version
465481

466482
def _validate_accelerator_type(accelerator_type):
467483
"""Raises a ``ValueError`` if ``accelerator_type`` is invalid."""
@@ -472,32 +488,23 @@ def _validate_accelerator_type(accelerator_type):
472488
)
473489

474490

475-
def _validate_version_and_set_if_needed(version, config, framework):
491+
def _validate_version_and_set_if_needed(version, config, framework, image_scope):
476492
"""Checks if the framework/algorithm version is one of the supported versions."""
493+
if not config:
494+
config = config_for_framework(framework)
477495
available_versions = list(config["versions"].keys())
478496
aliased_versions = list(config.get("version_aliases", {}).keys())
479-
480497
if len(available_versions) == 1 and version not in aliased_versions:
481-
log_message = "Defaulting to the only supported framework/algorithm version: {}.".format(
482-
available_versions[0]
483-
)
484-
if version and version != available_versions[0]:
485-
logger.warning("%s Ignoring framework/algorithm version: %s.", log_message, version)
486-
elif not version:
487-
logger.info(log_message)
488-
489498
return available_versions[0]
490-
491-
if version is None and framework in [
499+
if not version and framework in [
492500
DATA_WRANGLER_FRAMEWORK,
493501
HUGGING_FACE_LLM_FRAMEWORK,
494502
HUGGING_FACE_TEI_GPU_FRAMEWORK,
495503
HUGGING_FACE_TEI_CPU_FRAMEWORK,
496504
HUGGING_FACE_LLM_NEURONX_FRAMEWORK,
497505
STABILITYAI_FRAMEWORK,
498506
]:
499-
version = _get_latest_versions(available_versions)
500-
507+
version = _get_latest_version(framework, version, image_scope)
501508
_validate_arg(version, available_versions + aliased_versions, "{} version".format(framework))
502509
return version
503510

@@ -609,6 +616,7 @@ def _validate_py_version_and_set_if_needed(py_version, version_config, framework
609616

610617
def _validate_arg(arg, available_options, arg_name):
611618
"""Checks if the arg is in the available options, and raises a ``ValueError`` if not."""
619+
print("VALIDATE")
612620
if arg not in available_options:
613621
raise ValueError(
614622
"Unsupported {arg_name}: {arg}. You may need to upgrade your SDK version "
@@ -748,101 +756,6 @@ def get_base_python_image_uri(region, py_version="310") -> str:
748756
return ECR_URI_TEMPLATE.format(registry=registry, hostname=hostname, repository=repo_and_tag)
749757

750758

751-
def get_latest_container_image(
752-
framework: str,
753-
image_scope: Optional[str] = None,
754-
instance_type: Optional[str] = None,
755-
py_version: Optional[str] = None,
756-
region: str = "us-west-2",
757-
version: Optional[str] = None,
758-
accelerator_type=None,
759-
container_version=None,
760-
distribution=None,
761-
base_framework_version=None,
762-
training_compiler_config=None,
763-
model_id=None,
764-
model_version=None,
765-
hub_arn=None,
766-
sdk_version=None,
767-
inference_tool=None,
768-
serverless_inference_config=None,
769-
config_name=None,
770-
) -> Tuple[str, str]:
771-
"""Retrieves the latest container image URI
772-
773-
Args:
774-
framework (str): The name of the framework or algorithm.
775-
image_scope (str): The image type, i.e. what it is used for.
776-
Valid values: "training", "inference", "inference_graviton", "eia".
777-
If ``accelerator_type`` is set, ``image_scope`` is ignored.
778-
region (str): The AWS region.
779-
version (str): The framework or algorithm version. This is required if there is
780-
more than one supported version for the given framework or algorithm.
781-
py_version (str): The Python version. This is required if there is
782-
more than one supported Python version for the given framework version.
783-
instance_type (str): The SageMaker instance type. For supported types, see
784-
https://aws.amazon.com/sagemaker/pricing. This is required if
785-
there are different images for different processor types.
786-
accelerator_type (str): Elastic Inference accelerator type. For more, see
787-
https://docs.aws.amazon.com/sagemaker/latest/dg/ei.html.
788-
container_version (str): the version of docker image.
789-
Ideally the value of parameter should be created inside the framework.
790-
For custom use, see the list of supported container versions:
791-
https://github.com/aws/deep-learning-containers/blob/master/available_images.md
792-
(default: None).
793-
distribution (dict): A dictionary with information on how to run distributed training
794-
training_compiler_config (:class:`~sagemaker.training_compiler.TrainingCompilerConfig`):
795-
A configuration class for the SageMaker Training Compiler
796-
(default: None).
797-
model_id (str): The JumpStart model ID for which to retrieve the image URI
798-
(default: None).
799-
model_version (str): The version of the JumpStart model for which to retrieve the
800-
image URI (default: None).
801-
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
802-
model details from. (Default: None).
803-
sdk_version (str): the version of python-sdk that will be used in the image retrieval.
804-
(default: None).
805-
inference_tool (str): the tool that will be used to aid in the inference.
806-
Valid values: "neuron, neuronx, None"
807-
(default: None).
808-
serverless_inference_config (sagemaker.serverless.ServerlessInferenceConfig):
809-
Specifies configuration related to serverless endpoint. Instance type is
810-
not provided in serverless inference. So this is used to determine processor type.
811-
config_name (Optional[str]): Name of the JumpStart Model config to apply. (Default: None).
812-
"""
813-
try:
814-
framework_config = config_for_framework(framework)
815-
except FileNotFoundError:
816-
raise ValueError("Invalid framework {}".format(framework))
817-
818-
if not framework_config:
819-
raise ValueError("Invalid framework {}".format(framework))
820-
821-
if not version:
822-
version = _fetch_latest_version_from_config(framework_config, image_scope)
823-
image_uri = retrieve(
824-
framework=framework,
825-
region=region,
826-
version=version,
827-
instance_type=instance_type,
828-
py_version=py_version,
829-
accelerator_type=accelerator_type,
830-
image_scope=image_scope,
831-
container_version=container_version,
832-
distribution=distribution,
833-
base_framework_version=base_framework_version,
834-
training_compiler_config=training_compiler_config,
835-
model_id=model_id,
836-
model_version=model_version,
837-
hub_arn=hub_arn,
838-
sdk_version=sdk_version,
839-
inference_tool=inference_tool,
840-
serverless_inference_config=serverless_inference_config,
841-
config_name=config_name,
842-
)
843-
return image_uri, version
844-
845-
846759
def _fetch_latest_version_from_config(
847760
framework_config: dict, image_scope: Optional[str] = None
848761
) -> Optional[str]:
@@ -864,6 +777,8 @@ def _fetch_latest_version_from_config(
864777

865778
if "versions" in framework_config:
866779
versions = list(framework_config["versions"].keys())
780+
if len(versions) == 1:
781+
return versions[0]
867782
top_version = versions[0]
868783
bottom_version = versions[-1]
869784
if top_version == "latest" or bottom_version == "latest":
@@ -880,7 +795,6 @@ def _fetch_latest_version_from_config(
880795
versions = list(framework_config["processing"]["versions"].keys())
881796
top_version = versions[0]
882797
bottom_version = versions[-1]
883-
884798
if top_version and bottom_version:
885799
if top_version.endswith(".x") or bottom_version.endswith(".x"):
886800
top_number = int(top_version[:-2])

tests/unit/sagemaker/image_uris/test_latest_container_image.py

Lines changed: 0 additions & 129 deletions
This file was deleted.

0 commit comments

Comments
 (0)