@@ -98,6 +98,7 @@ def __init__(
98
98
debugger_hook_config = None ,
99
99
tensorboard_output_config = None ,
100
100
enable_sagemaker_metrics = None ,
101
+ enable_network_isolation = False ,
101
102
):
102
103
"""Initialize an ``EstimatorBase`` instance.
103
104
@@ -199,6 +200,11 @@ def __init__(
199
200
Series. For more information see:
200
201
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
201
202
(default: ``None``).
203
+ enable_network_isolation (bool): Specifies whether container will
204
+ run in network isolation mode (default: ``False``). Network
205
+ isolation mode restricts the container access to outside networks
206
+ (such as the Internet). The container does not make any inbound or
207
+ outbound network calls. Also known as Internet-free mode.
202
208
"""
203
209
self .role = role
204
210
self .train_instance_count = train_instance_count
@@ -260,6 +266,7 @@ def __init__(
260
266
self .collection_configs = None
261
267
262
268
self .enable_sagemaker_metrics = enable_sagemaker_metrics
269
+ self ._enable_network_isolation = enable_network_isolation
263
270
264
271
@abstractmethod
265
272
def train_image (self ):
@@ -290,7 +297,7 @@ def enable_network_isolation(self):
290
297
Returns:
291
298
bool: Whether this Estimator needs network isolation or not.
292
299
"""
293
- return False
300
+ return self . _enable_network_isolation
294
301
295
302
def prepare_workflow_for_training (self , job_name = None ):
296
303
"""Calls _prepare_for_training. Used when setting up a workflow.
@@ -818,6 +825,8 @@ def transformer(
818
825
role = None ,
819
826
volume_kms_key = None ,
820
827
vpc_config_override = vpc_utils .VPC_CONFIG_DEFAULT ,
828
+ enable_network_isolation = None ,
829
+ model_name = None ,
821
830
):
822
831
"""Return a ``Transformer`` that uses a SageMaker Model based on the
823
832
training job. It reuses the SageMaker Session and base job name used by
@@ -856,8 +865,20 @@ def transformer(
856
865
vpc_config_override (dict[str, list[str]]): Optional override for the
857
866
VpcConfig set on the model.
858
867
Default: use subnets and security groups from this Estimator.
868
+
859
869
* 'Subnets' (list[str]): List of subnet ids.
860
870
* 'SecurityGroupIds' (list[str]): List of security group ids.
871
+
872
+ enable_network_isolation (bool): Specifies whether container will
873
+ run in network isolation mode. Network isolation mode restricts
874
+ the container access to outside networks (such as the internet).
875
+ The container does not make any inbound or outbound network
876
+ calls. If True, a channel named "code" will be created for any
877
+ user entry script for inference. Also known as Internet-free mode.
878
+ If not specified, this setting is taken from the estimator's
879
+ current configuration.
880
+ model_name (str): Name to use for creating an Amazon SageMaker
881
+ model. If not specified, the name of the training job is used.
861
882
"""
862
883
tags = tags or self .tags
863
884
@@ -866,11 +887,16 @@ def transformer(
866
887
"No finished training job found associated with this estimator. Please make sure "
867
888
"this estimator is only used for building workflow config"
868
889
)
869
- model_name = self ._current_job_name
890
+ model_name = model_name or self ._current_job_name
870
891
else :
871
- model_name = self .latest_training_job .name
892
+ model_name = model_name or self .latest_training_job .name
893
+ if enable_network_isolation is None :
894
+ enable_network_isolation = self .enable_network_isolation ()
895
+
872
896
model = self .create_model (
873
- vpc_config_override = vpc_config_override , model_kms_key = self .output_kms_key
897
+ vpc_config_override = vpc_config_override ,
898
+ model_kms_key = self .output_kms_key ,
899
+ enable_network_isolation = enable_network_isolation ,
874
900
)
875
901
876
902
# not all create_model() implementations have the same kwargs
@@ -1219,21 +1245,17 @@ def __init__(
1219
1245
checkpoints will be provided under `/opt/ml/checkpoints/`.
1220
1246
(default: ``None``).
1221
1247
enable_network_isolation (bool): Specifies whether container will
1222
- run in network isolation mode. Network isolation mode restricts
1223
- the container access to outside networks (such as the Internet).
1224
- The container does not make any inbound or outbound network
1225
- calls. If ``True``, a channel named "code" will be created for any
1226
- user entry script for training. The user entry script, files in
1227
- source_dir (if specified), and dependencies will be uploaded in
1228
- a tar to S3. Also known as internet-free mode (default: ``False``).
1248
+ run in network isolation mode (default: ``False``). Network
1249
+ isolation mode restricts the container access to outside networks
1250
+ (such as the Internet). The container does not make any inbound or
1251
+ outbound network calls. Also known as Internet-free mode.
1229
1252
enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
1230
1253
Series. For more information see:
1231
1254
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
1232
1255
(default: ``None``).
1233
1256
"""
1234
1257
self .image_name = image_name
1235
1258
self .hyperparam_dict = hyperparameters .copy () if hyperparameters else {}
1236
- self ._enable_network_isolation = enable_network_isolation
1237
1259
super (Estimator , self ).__init__ (
1238
1260
role ,
1239
1261
train_instance_count ,
@@ -1261,16 +1283,9 @@ def __init__(
1261
1283
debugger_hook_config = debugger_hook_config ,
1262
1284
tensorboard_output_config = tensorboard_output_config ,
1263
1285
enable_sagemaker_metrics = enable_sagemaker_metrics ,
1286
+ enable_network_isolation = enable_network_isolation ,
1264
1287
)
1265
1288
1266
- def enable_network_isolation (self ):
1267
- """If this Estimator can use network isolation when running.
1268
-
1269
- Returns:
1270
- bool: Whether this Estimator can use network isolation or not.
1271
- """
1272
- return self ._enable_network_isolation
1273
-
1274
1289
def train_image (self ):
1275
1290
"""Returns the docker image to use for training.
1276
1291
@@ -1358,14 +1373,16 @@ def predict_wrapper(endpoint, session):
1358
1373
1359
1374
role = role or self .role
1360
1375
1376
+ if "enable_network_isolation" not in kwargs :
1377
+ kwargs ["enable_network_isolation" ] = self .enable_network_isolation ()
1378
+
1361
1379
return Model (
1362
1380
self .model_data ,
1363
1381
image or self .train_image (),
1364
1382
role ,
1365
1383
vpc_config = self .get_vpc_config (vpc_config_override ),
1366
1384
sagemaker_session = self .sagemaker_session ,
1367
1385
predictor_cls = predictor_cls ,
1368
- enable_network_isolation = self .enable_network_isolation (),
1369
1386
** kwargs
1370
1387
)
1371
1388
@@ -1498,15 +1515,15 @@ def __init__(
1498
1515
>>> |------ train.py
1499
1516
>>> |------ common
1500
1517
>>> |------ virtual-env
1518
+
1501
1519
enable_network_isolation (bool): Specifies whether container will
1502
1520
run in network isolation mode. Network isolation mode restricts
1503
1521
the container access to outside networks (such as the internet).
1504
1522
The container does not make any inbound or outbound network
1505
1523
calls. If True, a channel named "code" will be created for any
1506
1524
user entry script for training. The user entry script, files in
1507
1525
source_dir (if specified), and dependencies will be uploaded in
1508
- a tar to S3. Also known as internet-free mode (default: `False`
1509
- ).
1526
+ a tar to S3. Also known as internet-free mode (default: `False`).
1510
1527
git_config (dict[str, str]): Git configurations used for cloning
1511
1528
files, including ``repo``, ``branch``, ``commit``,
1512
1529
``2FA_enabled``, ``username``, ``password`` and ``token``. The
@@ -1579,7 +1596,7 @@ def __init__(
1579
1596
You can find additional parameters for initializing this class at
1580
1597
:class:`~sagemaker.estimator.EstimatorBase`.
1581
1598
"""
1582
- super (Framework , self ).__init__ (** kwargs )
1599
+ super (Framework , self ).__init__ (enable_network_isolation = enable_network_isolation , ** kwargs )
1583
1600
if entry_point .startswith ("s3://" ):
1584
1601
raise ValueError (
1585
1602
"Invalid entry point script: {}. Must be a path to a local file." .format (
@@ -1599,7 +1616,6 @@ def __init__(
1599
1616
self .container_log_level = container_log_level
1600
1617
self .code_location = code_location
1601
1618
self .image_name = image_name
1602
- self ._enable_network_isolation = enable_network_isolation
1603
1619
1604
1620
self .uploaded_code = None
1605
1621
@@ -1608,14 +1624,6 @@ def __init__(
1608
1624
self .checkpoint_local_path = checkpoint_local_path
1609
1625
self .enable_sagemaker_metrics = enable_sagemaker_metrics
1610
1626
1611
- def enable_network_isolation (self ):
1612
- """Return True if this Estimator can use network isolation to run.
1613
-
1614
- Returns:
1615
- bool: Whether this Estimator can use network isolation or not.
1616
- """
1617
- return self ._enable_network_isolation
1618
-
1619
1627
def _prepare_for_training (self , job_name = None ):
1620
1628
"""Set hyperparameters needed for training. This method will also
1621
1629
validate ``source_dir``.
@@ -1891,6 +1899,8 @@ def transformer(
1891
1899
volume_kms_key = None ,
1892
1900
entry_point = None ,
1893
1901
vpc_config_override = vpc_utils .VPC_CONFIG_DEFAULT ,
1902
+ enable_network_isolation = None ,
1903
+ model_name = None ,
1894
1904
):
1895
1905
"""Return a ``Transformer`` that uses a SageMaker Model based on the
1896
1906
training job. It reuses the SageMaker Session and base job name used by
@@ -1935,9 +1945,21 @@ def transformer(
1935
1945
vpc_config_override (dict[str, list[str]]): Optional override for
1936
1946
the VpcConfig set on the model.
1937
1947
Default: use subnets and security groups from this Estimator.
1948
+
1938
1949
* 'Subnets' (list[str]): List of subnet ids.
1939
1950
* 'SecurityGroupIds' (list[str]): List of security group ids.
1940
1951
1952
+ enable_network_isolation (bool): Specifies whether container will
1953
+ run in network isolation mode. Network isolation mode restricts
1954
+ the container access to outside networks (such as the internet).
1955
+ The container does not make any inbound or outbound network
1956
+ calls. If True, a channel named "code" will be created for any
1957
+ user entry script for inference. Also known as Internet-free mode.
1958
+ If not specified, this setting is taken from the estimator's
1959
+ current configuration.
1960
+ model_name (str): Name to use for creating an Amazon SageMaker
1961
+ model. If not specified, the name of the training job is used.
1962
+
1941
1963
Returns:
1942
1964
sagemaker.transformer.Transformer: a ``Transformer`` object that can be used to start a
1943
1965
SageMaker Batch Transform job.
@@ -1946,12 +1968,17 @@ def transformer(
1946
1968
tags = tags or self .tags
1947
1969
1948
1970
if self .latest_training_job is not None :
1971
+ if enable_network_isolation is None :
1972
+ enable_network_isolation = self .enable_network_isolation ()
1973
+
1949
1974
model = self .create_model (
1950
1975
role = role ,
1951
1976
model_server_workers = model_server_workers ,
1952
1977
entry_point = entry_point ,
1953
1978
vpc_config_override = vpc_config_override ,
1954
1979
model_kms_key = self .output_kms_key ,
1980
+ enable_network_isolation = enable_network_isolation ,
1981
+ name = model_name ,
1955
1982
)
1956
1983
model ._create_sagemaker_model (instance_type , tags = tags )
1957
1984
@@ -1964,7 +1991,7 @@ def transformer(
1964
1991
"No finished training job found associated with this estimator. Please make sure "
1965
1992
"this estimator is only used for building workflow config"
1966
1993
)
1967
- model_name = self ._current_job_name
1994
+ model_name = model_name or self ._current_job_name
1968
1995
transform_env = env or {}
1969
1996
1970
1997
return Transformer (
0 commit comments