Skip to content

Commit 630b51c

Browse files
feat: support creating endpoints with model images from private registries (#1834)
Co-authored-by: ChoiByungWook <[email protected]>
1 parent c0f54d9 commit 630b51c

File tree

3 files changed

+29
-2
lines changed

3 files changed

+29
-2
lines changed

src/sagemaker/model.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(
5353
sagemaker_session=None,
5454
enable_network_isolation=False,
5555
model_kms_key=None,
56+
image_config=None,
5657
):
5758
"""Initialize an SageMaker ``Model``.
5859
@@ -90,6 +91,10 @@ def __init__(
9091
or from the model container.
9192
model_kms_key (str): KMS key ARN used to encrypt the repacked
9293
model archive file if the model is repacked
94+
image_config (dict[str, str]): Specifies whether the image of
95+
model container is pulled from ECR, or private registry in your
96+
VPC. By default it is set to pull model container image from
97+
ECR. (default: None).
9398
"""
9499
self.model_data = model_data
95100
self.image_uri = image_uri
@@ -106,6 +111,7 @@ def __init__(
106111
self._is_edge_packaged_model = False
107112
self._enable_network_isolation = enable_network_isolation
108113
self.model_kms_key = model_kms_key
114+
self.image_config = image_config
109115

110116
def register(
111117
self,
@@ -279,7 +285,9 @@ def prepare_container_def(
279285
Returns:
280286
dict: A container definition object usable with the CreateModel API.
281287
"""
282-
return sagemaker.container_def(self.image_uri, self.model_data, self.env)
288+
return sagemaker.container_def(
289+
self.image_uri, self.model_data, self.env, image_config=self.image_config
290+
)
283291

284292
def enable_network_isolation(self):
285293
"""Whether to enable network isolation when creating this Model

src/sagemaker/session.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4018,7 +4018,7 @@ def update_args(args: Dict[str, Any], **kwargs):
40184018
args.update({key: value})
40194019

40204020

4021-
def container_def(image_uri, model_data_url=None, env=None, container_mode=None):
4021+
def container_def(image_uri, model_data_url=None, env=None, container_mode=None, image_config=None):
40224022
"""Create a definition for executing a container as part of a SageMaker model.
40234023
40244024
Args:
@@ -4030,6 +4030,9 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None)
40304030
* MultiModel: Indicates that model container can support hosting multiple models
40314031
* SingleModel: Indicates that model container can support hosting a single model
40324032
This is the default model container mode when container_mode = None
4033+
image_config (dict[str, str]): Specifies whether the image of model container is pulled
4034+
from ECR, or private registry in your VPC. By default it is set to pull model
4035+
container image from ECR. (default: None).
40334036
40344037
Returns:
40354038
dict[str, str]: A complete container definition object usable with the CreateModel API if
@@ -4042,6 +4045,8 @@ def container_def(image_uri, model_data_url=None, env=None, container_mode=None)
40424045
c_def["ModelDataUrl"] = model_data_url
40434046
if container_mode:
40444047
c_def["Mode"] = container_mode
4048+
if image_config:
4049+
c_def["ImageConfig"] = image_config
40454050
return c_def
40464051

40474052

tests/unit/sagemaker/model/test_model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@ def test_prepare_container_def_with_model_data_and_env():
5454
assert expected == container_def
5555

5656

57+
def test_prepare_container_def_with_image_config():
58+
image_config = {"RepositoryAccessMode": "Vpc"}
59+
model = Model(MODEL_IMAGE, image_config=image_config)
60+
61+
expected = {
62+
"Image": MODEL_IMAGE,
63+
"ImageConfig": {"RepositoryAccessMode": "Vpc"},
64+
"Environment": {},
65+
}
66+
67+
container_def = model.prepare_container_def()
68+
assert expected == container_def
69+
70+
5771
def test_model_enable_network_isolation():
5872
model = Model(MODEL_IMAGE, MODEL_DATA)
5973
assert model.enable_network_isolation() is False

0 commit comments

Comments
 (0)