@@ -117,6 +117,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
117
117
self .metric_definitions = metric_definitions
118
118
self .model_uri = model_uri
119
119
self .model_channel_name = model_channel_name
120
+ self .code_uri = None
121
+ self .code_channel_name = 'code'
120
122
121
123
if self .train_instance_type in ('local' , 'local_gpu' ):
122
124
if self .train_instance_type == 'local_gpu' and self .train_instance_count > 1 :
@@ -773,9 +775,11 @@ class Framework(EstimatorBase):
773
775
LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
774
776
MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
775
777
MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'
778
+ CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = '/opt/ml/input/data/code/sourcedir.tar.gz'
776
779
777
780
def __init__ (self , entry_point , source_dir = None , hyperparameters = None , enable_cloudwatch_metrics = False ,
778
- container_log_level = logging .INFO , code_location = None , image_name = None , dependencies = None , ** kwargs ):
781
+ container_log_level = logging .INFO , code_location = None , image_name = None , dependencies = None ,
782
+ enable_network_isolation = False , ** kwargs ):
779
783
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
780
784
781
785
Args:
@@ -784,6 +788,21 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
784
788
source_dir (str): Path (absolute or relative) to a directory with any other training
785
789
source code dependencies aside from the entry point file (default: None). Structure within this
786
790
directory is preserved when training on Amazon SageMaker.
791
+ hyperparameters (dict): Hyperparameters that will be used for training (default: None).
792
+ The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
793
+ For convenience, this accepts other types for keys and values, but ``str()`` will be called
794
+ to convert them before training.
795
+ enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
796
+ training jobs. This will be ignored for now and removed in a future release.
797
+ container_log_level (int): Log level to use within the container (default: logging.INFO).
798
+ Valid values are defined in the Python logging module.
799
+ code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
800
+ The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
801
+ If not specified, the default code location is s3://default_bucket/job-name/ and the code file
802
+ uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
803
+ image_name (str): An alternate image name to use instead of the official SageMaker image
804
+ for the framework. This is useful to run one of the SageMaker-supported frameworks
805
+ with an image containing custom dependencies.
787
806
dependencies (list[str]): A list of paths to directories (absolute or relative) with
788
807
any additional libraries that will be exported to the container (default: []).
789
808
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
@@ -800,21 +819,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
800
819
>>> |------ common
801
820
>>> |------ virtual-env
802
821
803
- hyperparameters (dict): Hyperparameters that will be used for training (default: None).
804
- The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
805
- For convenience, this accepts other types for keys and values, but ``str()`` will be called
806
- to convert them before training.
807
- enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
808
- training jobs. This will be ignored for now and removed in a further release.
809
- container_log_level (int): Log level to use within the container (default: logging.INFO).
810
- Valid values are defined in the Python logging module.
811
- code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
812
- The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
813
- If not specified, the default code location is s3://default_bucket/job-name/. And code file
814
- uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
815
- image_name (str): An alternate image name to use instead of the official Sagemaker image
816
- for the framework. This is useful to run one of the Sagemaker supported frameworks
817
- with an image containing custom dependencies.
822
+ enable_network_isolation (bool): Specifies whether the container will run in network isolation mode. Network
823
+ isolation mode restricts the container's access to outside networks (such as the internet). The container
824
+ does not make any inbound or outbound network calls. If True, a channel named "code" will be created
825
+ for any user entry script for training. The user entry script, files in source_dir (if specified), and
826
+ dependencies will be uploaded in a tar to S3. Also known as internet-free mode (default: `False`).
818
827
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
819
828
"""
820
829
super (Framework , self ).__init__ (** kwargs )
@@ -830,9 +839,18 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None, enable_cl
830
839
self .container_log_level = container_log_level
831
840
self .code_location = code_location
832
841
self .image_name = image_name
842
+ self ._enable_network_isolation = enable_network_isolation
833
843
834
844
self ._hyperparameters = hyperparameters or {}
835
845
846
+ def enable_network_isolation (self ):
847
+ """Return the network isolation setting this Estimator was constructed with.
848
+
849
+ Returns:
850
+ bool: The ``enable_network_isolation`` value passed to ``__init__`` (default: False).
851
+ """
852
+ return self ._enable_network_isolation
853
+
836
854
def _prepare_for_training (self , job_name = None ):
837
855
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
838
856
@@ -858,6 +876,11 @@ def _prepare_for_training(self, job_name=None):
858
876
859
877
code_dir = 'file://' + self .source_dir
860
878
script = self .entry_point
879
+ elif self .enable_network_isolation () and self .entry_point :
880
+ self .uploaded_code = self ._stage_user_code_in_s3 ()
881
+ code_dir = self .CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH
882
+ script = self .uploaded_code .script_name
883
+ self .code_uri = self .uploaded_code .s3_prefix
861
884
else :
862
885
self .uploaded_code = self ._stage_user_code_in_s3 ()
863
886
code_dir = self .uploaded_code .s3_prefix
@@ -881,12 +904,12 @@ def _stage_user_code_in_s3(self):
881
904
882
905
if self .code_location is None and local_mode :
883
906
code_bucket = self .sagemaker_session .default_bucket ()
884
- code_s3_prefix = '{}/source ' .format (self ._current_job_name )
907
+ code_s3_prefix = '{}/{} ' .format (self ._current_job_name , 'source' )
885
908
kms_key = None
886
909
887
910
elif self .code_location is None :
888
911
code_bucket , _ = parse_s3_url (self .output_path )
889
- code_s3_prefix = '{}/source ' .format (self ._current_job_name )
912
+ code_s3_prefix = '{}/{} ' .format (self ._current_job_name , 'source' )
890
913
kms_key = self .output_kms_key
891
914
else :
892
915
code_bucket , key_prefix = parse_s3_url (self .code_location )
0 commit comments