Skip to content

Commit ad41e98

Browse files
authored
fix: Use serving_image_uri for Airflow (#1245)
* Add methods to framework models to get serving image URI * Update test_airflow_config integ test * Format with black
1 parent becf92d commit ad41e98

File tree

9 files changed

+244
-27
lines changed

9 files changed

+244
-27
lines changed

src/sagemaker/chainer/model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import logging
1717

18+
from sagemaker import fw_utils
19+
1820
import sagemaker
1921
from sagemaker.fw_utils import (
2022
create_image_uri,
@@ -154,3 +156,23 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
154156
if self.model_server_workers:
155157
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
156158
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
159+
160+
def serving_image_uri(self, region_name, instance_type):
161+
"""Create a URI for the serving image.
162+
163+
Args:
164+
region_name (str): AWS region where the image is uploaded.
165+
instance_type (str): SageMaker instance type. Used to determine device type
166+
(cpu/gpu/family-specific optimized).
167+
168+
Returns:
169+
str: The appropriate image URI based on the given parameters.
170+
171+
"""
172+
return fw_utils.create_image_uri(
173+
region_name,
174+
self.__framework_name__,
175+
instance_type,
176+
self.framework_version,
177+
self.py_version,
178+
)

src/sagemaker/mxnet/model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import logging
1717

1818
from pkg_resources import parse_version
19+
from sagemaker import fw_utils
1920

2021
import sagemaker
2122
from sagemaker.fw_utils import (
@@ -167,3 +168,23 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
167168
return sagemaker.container_def(
168169
deploy_image, self.repacked_model_data or self.model_data, deploy_env
169170
)
171+
172+
def serving_image_uri(self, region_name, instance_type):
173+
"""Create a URI for the serving image.
174+
175+
Args:
176+
region_name (str): AWS region where the image is uploaded.
177+
instance_type (str): SageMaker instance type. Used to determine device type
178+
(cpu/gpu/family-specific optimized).
179+
180+
Returns:
181+
str: The appropriate image URI based on the given parameters.
182+
183+
"""
184+
return fw_utils.create_image_uri(
185+
region_name,
186+
self.__framework_name__,
187+
instance_type,
188+
self.framework_version,
189+
self.py_version,
190+
)

src/sagemaker/pytorch/model.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717
import pkg_resources
18+
from sagemaker import fw_utils
1819

1920
import sagemaker
2021
from sagemaker.fw_utils import (
@@ -167,3 +168,23 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
167168
return sagemaker.container_def(
168169
deploy_image, self.repacked_model_data or self.model_data, deploy_env
169170
)
171+
172+
def serving_image_uri(self, region_name, instance_type):
173+
"""Create a URI for the serving image.
174+
175+
Args:
176+
region_name (str): AWS region where the image is uploaded.
177+
instance_type (str): SageMaker instance type. Used to determine device type
178+
(cpu/gpu/family-specific optimized).
179+
180+
Returns:
181+
str: The appropriate image URI based on the given parameters.
182+
183+
"""
184+
return fw_utils.create_image_uri(
185+
region_name,
186+
self.__framework_name__,
187+
instance_type,
188+
self.framework_version,
189+
self.py_version,
190+
)

src/sagemaker/sklearn/model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import logging
1717

18+
from sagemaker import fw_utils
19+
1820
import sagemaker
1921
from sagemaker.fw_utils import model_code_key_prefix, python_deprecation_warning
2022
from sagemaker.fw_registry import default_framework_uri
@@ -151,3 +153,23 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
151153
self.repacked_model_data if self.enable_network_isolation() else self.model_data
152154
)
153155
return sagemaker.container_def(deploy_image, model_data_uri, deploy_env)
156+
157+
def serving_image_uri(self, region_name, instance_type):
158+
"""Create a URI for the serving image.
159+
160+
Args:
161+
region_name (str): AWS region where the image is uploaded.
162+
instance_type (str): SageMaker instance type. Used to determine device type
163+
(cpu/gpu/family-specific optimized).
164+
165+
Returns:
166+
str: The appropriate image URI based on the given parameters.
167+
168+
"""
169+
return fw_utils.create_image_uri(
170+
region_name,
171+
self.__framework_name__,
172+
instance_type,
173+
self.framework_version,
174+
self.py_version,
175+
)

src/sagemaker/tensorflow/model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import logging
1717

18+
from sagemaker import fw_utils
19+
1820
import sagemaker
1921
from sagemaker.fw_utils import (
2022
create_image_uri,
@@ -157,3 +159,23 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
157159
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
158160

159161
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
162+
163+
def serving_image_uri(self, region_name, instance_type):
164+
"""Create a URI for the serving image.
165+
166+
Args:
167+
region_name (str): AWS region where the image is uploaded.
168+
instance_type (str): SageMaker instance type. Used to determine device type
169+
(cpu/gpu/family-specific optimized).
170+
171+
Returns:
172+
str: The appropriate image URI based on the given parameters.
173+
174+
"""
175+
return fw_utils.create_image_uri(
176+
region_name,
177+
self.__framework_name__,
178+
instance_type,
179+
self.framework_version,
180+
self.py_version,
181+
)

src/sagemaker/tensorflow/serving.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,3 +275,17 @@ def _get_image_uri(self, instance_type, accelerator_type=None):
275275
self._framework_version,
276276
accelerator_type=accelerator_type,
277277
)
278+
279+
def serving_image_uri(self, region_name, instance_type): # pylint: disable=unused-argument
280+
"""Create a URI for the serving image.
281+
282+
Args:
283+
region_name (str): AWS region where the image is uploaded.
284+
instance_type (str): SageMaker instance type. Used to determine device type
285+
(cpu/gpu/family-specific optimized).
286+
287+
Returns:
288+
str: The appropriate image URI based on the given parameters.
289+
290+
"""
291+
return self._get_image_uri(instance_type=instance_type)

src/sagemaker/workflow/airflow.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -516,13 +516,7 @@ def prepare_framework_container_def(model, instance_type, s3_operations):
516516
deploy_image = model.image
517517
if not deploy_image:
518518
region_name = model.sagemaker_session.boto_session.region_name
519-
deploy_image = fw_utils.create_image_uri(
520-
region_name,
521-
model.__framework_name__,
522-
instance_type,
523-
model.framework_version,
524-
model.py_version,
525-
)
519+
deploy_image = model.serving_image_uri(region_name, instance_type)
526520

527521
base_name = utils.base_name_from_image(deploy_image)
528522
model.name = model.name or utils.name_from_base(base_name)

src/sagemaker/xgboost/model.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import logging
1717

18+
from sagemaker import fw_utils
19+
1820
import sagemaker
1921
from sagemaker.fw_utils import model_code_key_prefix, python_deprecation_warning
2022
from sagemaker.fw_registry import default_framework_uri
@@ -135,3 +137,23 @@ def prepare_container_def(self, instance_type, accelerator_type=None):
135137
if self.model_server_workers:
136138
deploy_env[MODEL_SERVER_WORKERS_PARAM_NAME.upper()] = str(self.model_server_workers)
137139
return sagemaker.container_def(deploy_image, self.model_data, deploy_env)
140+
141+
def serving_image_uri(self, region_name, instance_type):
142+
"""Create a URI for the serving image.
143+
144+
Args:
145+
region_name (str): AWS region where the image is uploaded.
146+
instance_type (str): SageMaker instance type. Used to determine device type
147+
(cpu/gpu/family-specific optimized).
148+
149+
Returns:
150+
str: The appropriate image URI based on the given parameters.
151+
152+
"""
153+
return fw_utils.create_image_uri(
154+
region_name,
155+
self.__framework_name__,
156+
instance_type,
157+
self.framework_version,
158+
self.py_version,
159+
)

0 commit comments

Comments
 (0)