Skip to content

Commit 8dae626

Browse files
authored
change: enable network isolation for amazon estimators (#1293)
1 parent 1752e2f commit 8dae626

File tree

3 files changed

+37
-10
lines changed

3 files changed

+37
-10
lines changed

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,13 @@ class AmazonAlgorithmEstimatorBase(EstimatorBase):
4444
repo_version = None
4545

4646
def __init__(
47-
self, role, train_instance_count, train_instance_type, data_location=None, **kwargs
47+
self,
48+
role,
49+
train_instance_count,
50+
train_instance_type,
51+
data_location=None,
52+
enable_network_isolation=False,
53+
**kwargs
4854
):
4955
"""Initialize an AmazonAlgorithmEstimatorBase.
5056
@@ -63,6 +69,10 @@ def __init__(
6369
"s3://example-bucket/some-key-prefix/". Objects will be saved in
6470
a unique sub-directory of the specified location. If None, a
6571
default data location will be used.
72+
enable_network_isolation (bool): Specifies whether container will
73+
run in network isolation mode. Network isolation mode restricts
74+
the container access to outside networks (such as the internet).
75+
Also known as internet-free mode (default: ``False``).
6676
**kwargs: Additional parameters passed to
6777
:class:`~sagemaker.estimator.EstimatorBase`.
6878
@@ -71,14 +81,6 @@ def __init__(
7181
You can find additional parameters for initializing this class at
7282
:class:`~sagemaker.estimator.EstimatorBase`.
7383
"""
74-
75-
if "enable_network_isolation" in kwargs:
76-
logger.debug(
77-
"removing unused enable_network_isolation argument: %s",
78-
str(kwargs["enable_network_isolation"]),
79-
)
80-
del kwargs["enable_network_isolation"]
81-
8284
super(AmazonAlgorithmEstimatorBase, self).__init__(
8385
role, train_instance_count, train_instance_type, **kwargs
8486
)
@@ -87,6 +89,7 @@ def __init__(
8789
self.sagemaker_session.default_bucket()
8890
)
8991
self._data_location = data_location
92+
self._enable_network_isolation = enable_network_isolation
9093

9194
def train_image(self):
9295
"""Placeholder docstring"""
@@ -98,6 +101,14 @@ def hyperparameters(self):
98101
"""Placeholder docstring"""
99102
return hp.serialize_all(self)
100103

104+
def enable_network_isolation(self):
105+
"""If this Estimator can use network isolation when running.
106+
107+
Returns:
108+
bool: Whether this Estimator can use network isolation or not.
109+
"""
110+
return self._enable_network_isolation
111+
101112
@property
102113
def data_location(self):
103114
"""Placeholder docstring"""

tests/integ/test_pca.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def test_pca(sagemaker_session, cpu_instance_type):
4141
train_instance_type=cpu_instance_type,
4242
num_components=48,
4343
sagemaker_session=sagemaker_session,
44+
enable_network_isolation=True,
4445
)
4546

4647
pca.algorithm_mode = "randomized"
@@ -50,7 +51,10 @@ def test_pca(sagemaker_session, cpu_instance_type):
5051

5152
with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
5253
pca_model = sagemaker.amazon.pca.PCAModel(
53-
model_data=pca.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
54+
model_data=pca.model_data,
55+
role="SageMakerRole",
56+
sagemaker_session=sagemaker_session,
57+
enable_network_isolation=True,
5458
)
5559
predictor = pca_model.deploy(
5660
initial_instance_count=1, instance_type=cpu_instance_type, endpoint_name=job_name

tests/unit/test_amazon_estimator.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ def test_gov_ecr_uri():
9393
def test_init(sagemaker_session):
9494
pca = PCA(num_components=55, sagemaker_session=sagemaker_session, **COMMON_ARGS)
9595
assert pca.num_components == 55
96+
assert pca.enable_network_isolation() is False
97+
98+
99+
def test_init_enable_network_isolation(sagemaker_session):
100+
pca = PCA(
101+
num_components=55,
102+
sagemaker_session=sagemaker_session,
103+
enable_network_isolation=True,
104+
**COMMON_ARGS
105+
)
106+
assert pca.num_components == 55
107+
assert pca.enable_network_isolation() is True
96108

97109

98110
def test_init_all_pca_hyperparameters(sagemaker_session):

0 commit comments

Comments
 (0)