Skip to content

Commit 18f95af

Browse files
authored
feat: add network isolation support for PipelineModel (#1943)
* fix: pass enable_network_isolation to create_model * add unit test for network isolation
1 parent 67ef671 commit 18f95af

File tree

2 files changed

+63
-3
lines changed

2 files changed

+63
-3
lines changed

src/sagemaker/pipeline.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,14 @@ class PipelineModel(object):
2626
"""
2727

2828
def __init__(
29-
self, models, role, predictor_cls=None, name=None, vpc_config=None, sagemaker_session=None
29+
self,
30+
models,
31+
role,
32+
predictor_cls=None,
33+
name=None,
34+
vpc_config=None,
35+
sagemaker_session=None,
36+
enable_network_isolation=False,
3037
):
3138
"""Initialize a SageMaker `Model` instance.
3239
@@ -57,13 +64,18 @@ def __init__(
5764
object, used for SageMaker interactions (default: None). If not
5865
specified, one is created using the default AWS configuration
5966
chain.
67+
enable_network_isolation (bool): Default False. if True, enables
68+
network isolation in the endpoint, isolating the model
69+
container. No inbound or outbound network calls can be made to
70+
or from the model container.Boolean
6071
"""
6172
self.models = models
6273
self.role = role
6374
self.predictor_cls = predictor_cls
6475
self.name = name
6576
self.vpc_config = vpc_config
6677
self.sagemaker_session = sagemaker_session
78+
self.enable_network_isolation = enable_network_isolation
6779
self.endpoint_name = None
6880

6981
def pipeline_container_def(self, instance_type):
@@ -157,7 +169,11 @@ def deploy(
157169

158170
self.name = self.name or name_from_image(containers[0]["Image"])
159171
self.sagemaker_session.create_model(
160-
self.name, self.role, containers, vpc_config=self.vpc_config
172+
self.name,
173+
self.role,
174+
containers,
175+
vpc_config=self.vpc_config,
176+
enable_network_isolation=self.enable_network_isolation,
161177
)
162178

163179
production_variant = sagemaker.production_variant(
@@ -214,7 +230,11 @@ def _create_sagemaker_pipeline_model(self, instance_type):
214230

215231
self.name = self.name or name_from_image(containers[0]["Image"])
216232
self.sagemaker_session.create_model(
217-
self.name, self.role, containers, vpc_config=self.vpc_config
233+
self.name,
234+
self.role,
235+
containers,
236+
vpc_config=self.vpc_config,
237+
enable_network_isolation=self.enable_network_isolation,
218238
)
219239

220240
def transformer(

tests/unit/test_pipeline_model.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,3 +298,43 @@ def test_delete_model(tfo, time, sagemaker_session):
298298

299299
pipeline_model.delete_model()
300300
sagemaker_session.delete_model.assert_called_with(pipeline_model.name)
301+
302+
303+
@patch("tarfile.open")
304+
@patch("time.strftime", return_value=TIMESTAMP)
305+
def test_network_isolation(tfo, time, sagemaker_session):
306+
framework_model = DummyFrameworkModel(sagemaker_session)
307+
sparkml_model = SparkMLModel(
308+
model_data=MODEL_DATA_2, role=ROLE, sagemaker_session=sagemaker_session
309+
)
310+
model = PipelineModel(
311+
models=[framework_model, sparkml_model],
312+
role=ROLE,
313+
sagemaker_session=sagemaker_session,
314+
enable_network_isolation=True,
315+
)
316+
model.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1)
317+
318+
sagemaker_session.create_model.assert_called_with(
319+
model.name,
320+
ROLE,
321+
[
322+
{
323+
"Image": "mi-1",
324+
"Environment": {
325+
"SAGEMAKER_PROGRAM": "blah.py",
326+
"SAGEMAKER_SUBMIT_DIRECTORY": "s3://mybucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz",
327+
"SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
328+
"SAGEMAKER_REGION": "us-west-2",
329+
},
330+
"ModelDataUrl": "s3://bucket/model_1.tar.gz",
331+
},
332+
{
333+
"Image": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.2",
334+
"Environment": {},
335+
"ModelDataUrl": "s3://bucket/model_2.tar.gz",
336+
},
337+
],
338+
vpc_config=None,
339+
enable_network_isolation=True,
340+
)

0 commit comments

Comments
 (0)