-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: account for EI and version-based ECR repo naming in serving_image_uri() #1273
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
105e46a
b3ecc28
ca7b716
aa88c1d
b4e2559
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,7 +15,6 @@ | |
|
||
import logging | ||
import packaging.version | ||
from sagemaker import fw_utils | ||
|
||
import sagemaker | ||
from sagemaker.fw_utils import ( | ||
|
@@ -137,34 +136,21 @@ def prepare_container_def(self, instance_type, accelerator_type=None): | |
For example, 'ml.p2.xlarge'. | ||
accelerator_type (str): The Elastic Inference accelerator type to | ||
deploy to the instance for loading and making inferences to the | ||
model. For example, 'ml.eia1.medium'. | ||
model. Currently unsupported with PyTorch. | ||
|
||
Returns: | ||
dict[str, str]: A container definition object usable with the | ||
CreateModel API. | ||
""" | ||
lowest_mms_version = packaging.version.Version(self._LOWEST_MMS_VERSION) | ||
framework_version = packaging.version.Version(self.framework_version) | ||
is_mms_version = framework_version >= lowest_mms_version | ||
|
||
deploy_image = self.image | ||
if not deploy_image: | ||
region_name = self.sagemaker_session.boto_session.region_name | ||
|
||
framework_name = self.__framework_name__ | ||
if is_mms_version: | ||
framework_name += "-serving" | ||
|
||
deploy_image = create_image_uri( | ||
region_name, | ||
framework_name, | ||
instance_type, | ||
self.framework_version, | ||
self.py_version, | ||
accelerator_type=accelerator_type, | ||
deploy_image = self.serving_image_uri( | ||
region_name, instance_type, accelerator_type=accelerator_type | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. [Minor] Keyword/Named arguments make kittens happy. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Talked offline - I don't see a significant benefit of doing |
||
) | ||
|
||
deploy_key_prefix = model_code_key_prefix(self.key_prefix, self.name, deploy_image) | ||
self._upload_code(deploy_key_prefix, repack=is_mms_version) | ||
self._upload_code(deploy_key_prefix, repack=self._is_mms_version()) | ||
deploy_env = dict(self.env) | ||
deploy_env.update(self._framework_env_vars()) | ||
|
||
|
@@ -174,22 +160,41 @@ def prepare_container_def(self, instance_type, accelerator_type=None): | |
deploy_image, self.repacked_model_data or self.model_data, deploy_env | ||
) | ||
|
||
def serving_image_uri(self, region_name, instance_type): | ||
def serving_image_uri(self, region_name, instance_type, accelerator_type=None): | ||
"""Create a URI for the serving image. | ||
|
||
Args: | ||
region_name (str): AWS region where the image is uploaded. | ||
instance_type (str): SageMaker instance type. Used to determine device type | ||
(cpu/gpu/family-specific optimized). | ||
accelerator_type (str): The Elastic Inference accelerator type to | ||
deploy to the instance for loading and making inferences to the | ||
model. Currently unsupported with PyTorch. | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Same comment as above. What happens if this is passed in? Is the image not found? |
||
|
||
Returns: | ||
str: The appropriate image URI based on the given parameters. | ||
|
||
""" | ||
return fw_utils.create_image_uri( | ||
framework_name = self.__framework_name__ | ||
if self._is_mms_version(): | ||
framework_name = "{}-serving".format(framework_name) | ||
|
||
return create_image_uri( | ||
region_name, | ||
"-".join([self.__framework_name__, "serving"]), | ||
framework_name, | ||
instance_type, | ||
self.framework_version, | ||
self.py_version, | ||
accelerator_type=accelerator_type, | ||
) | ||
|
||
def _is_mms_version(self): | ||
"""Whether the framework version corresponds to an inference image using | ||
the Multi-Model Server (https://github.com/awslabs/multi-model-server). | ||
|
||
Returns: | ||
bool: If the framework version corresponds to an image using MMS. | ||
""" | ||
lowest_mms_version = packaging.version.Version(self._LOWEST_MMS_VERSION) | ||
framework_version = packaging.version.Version(self.framework_version) | ||
return framework_version >= lowest_mms_version |
Uh oh!
There was an error while loading. Please reload this page.