@@ -118,6 +118,8 @@ def __init__(self, role, train_instance_count, train_instance_type,
118
118
self .metric_definitions = metric_definitions
119
119
self .model_uri = model_uri
120
120
self .model_channel_name = model_channel_name
121
+ self .code_uri = None
122
+ self .code_channel_name = 'code'
121
123
122
124
if self .train_instance_type in ('local' , 'local_gpu' ):
123
125
if self .train_instance_type == 'local_gpu' and self .train_instance_count > 1 :
@@ -774,10 +776,11 @@ class Framework(EstimatorBase):
774
776
LAUNCH_MPI_ENV_NAME = 'sagemaker_mpi_enabled'
775
777
MPI_NUM_PROCESSES_PER_HOST = 'sagemaker_mpi_num_of_processes_per_host'
776
778
MPI_CUSTOM_MPI_OPTIONS = 'sagemaker_mpi_custom_mpi_options'
779
+ CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH = '/opt/ml/input/data/code/sourcedir.tar.gz'
777
780
778
- def __init__ (self , entry_point , source_dir = None , hyperparameters = None ,
779
- enable_cloudwatch_metrics = False , container_log_level = logging .INFO , code_location = None ,
780
- image_name = None , dependencies = None , git_config = None , ** kwargs ):
781
+ def __init__ (self , entry_point , source_dir = None , hyperparameters = None , enable_cloudwatch_metrics = False ,
782
+ container_log_level = logging .INFO , code_location = None , image_name = None , dependencies = None ,
783
+ enable_network_isolation = False , ** kwargs ):
781
784
"""Base class initializer. Subclasses which override ``__init__`` should invoke ``super()``
782
785
783
786
Args:
@@ -811,6 +814,7 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None,
811
814
the specified commit.
812
815
source_dir (str): Path (absolute or relative) to a directory with any other training
813
816
source code dependencies aside from the entry point file (default: None). Structure within this
817
+ <<<<<<< HEAD
814
818
directory are preserved when training on Amazon SageMaker. If 'git_config' is provided,
815
819
source_dir should be a relative location to a directory in the Git repo.
816
820
Example:
@@ -824,6 +828,24 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None,
824
828
825
829
and you need 'train.py' as entry point and 'test.py' as training source code as well, you can
826
830
assign entry_point='train.py', source_dir='src'.
831
+ =======
832
+ directory are preserved when training on Amazon SageMaker.
833
+ hyperparameters (dict): Hyperparameters that will be used for training (default: None).
834
+ The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
835
+ For convenience, this accepts other types for keys and values, but ``str()`` will be called
836
+ to convert them before training.
837
+ enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
838
+ training jobs. This will be ignored for now and removed in a further release.
839
+ container_log_level (int): Log level to use within the container (default: logging.INFO).
840
+ Valid values are defined in the Python logging module.
841
+ code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
842
+ The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
843
+ If not specified, the default code location is s3://default_bucket/job-name/. And code file
844
+ uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
845
+ image_name (str): An alternate image name to use instead of the official Sagemaker image
846
+ for the framework. This is useful to run one of the Sagemaker supported frameworks
847
+ with an image containing custom dependencies.
848
+ >>>>>>> f1d34ad4073f8d856ef9c596b491f8a4cd8ef31f
827
849
dependencies (list[str]): A list of paths to directories (absolute or relative) with
828
850
any additional libraries that will be exported to the container (default: []).
829
851
The library folders will be copied to SageMaker in the same folder where the entrypoint is copied.
@@ -840,21 +862,11 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None,
840
862
>>> |------ common
841
863
>>> |------ virtual-env
842
864
843
- hyperparameters (dict): Hyperparameters that will be used for training (default: None).
844
- The hyperparameters are made accessible as a dict[str, str] to the training code on SageMaker.
845
- For convenience, this accepts other types for keys and values, but ``str()`` will be called
846
- to convert them before training.
847
- enable_cloudwatch_metrics (bool): [DEPRECATED] Now there are cloudwatch metrics emitted by all SageMaker
848
- training jobs. This will be ignored for now and removed in a further release.
849
- container_log_level (int): Log level to use within the container (default: logging.INFO).
850
- Valid values are defined in the Python logging module.
851
- code_location (str): The S3 prefix URI where custom code will be uploaded (default: None).
852
- The code file uploaded in S3 is 'code_location/source/sourcedir.tar.gz'.
853
- If not specified, the default code location is s3://default_bucket/job-name/. And code file
854
- uploaded to S3 is s3://default_bucket/job-name/source/sourcedir.tar.gz
855
- image_name (str): An alternate image name to use instead of the official Sagemaker image
856
- for the framework. This is useful to run one of the Sagemaker supported frameworks
857
- with an image containing custom dependencies.
865
+ enable_network_isolation (bool): Specifies whether container will run in network isolation mode. Network
866
+ isolation mode restricts the container access to outside networks (such as the internet). The container
867
+ does not make any inbound or outbound network calls. If True, a channel named "code" will be created
868
+ for any user entry script for training. The user entry script, files in source_dir (if specified), and
869
+ dependencies will be uploaded in a tar to S3. Also known as internet-free mode (default: `False`).
858
870
**kwargs: Additional kwargs passed to the ``EstimatorBase`` constructor.
859
871
"""
860
872
super (Framework , self ).__init__ (** kwargs )
@@ -872,9 +884,18 @@ def __init__(self, entry_point, source_dir=None, hyperparameters=None,
872
884
self .container_log_level = container_log_level
873
885
self .code_location = code_location
874
886
self .image_name = image_name
887
+ self ._enable_network_isolation = enable_network_isolation
875
888
876
889
self ._hyperparameters = hyperparameters or {}
877
890
891
+ def enable_network_isolation (self ):
892
+ """Return True if this Estimator can use network isolation to run.
893
+
894
+ Returns:
895
+ bool: Whether this Estimator can use network isolation or not.
896
+ """
897
+ return self ._enable_network_isolation
898
+
878
899
def _prepare_for_training (self , job_name = None ):
879
900
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
880
901
@@ -907,6 +928,11 @@ def _prepare_for_training(self, job_name=None):
907
928
908
929
code_dir = 'file://' + self .source_dir
909
930
script = self .entry_point
931
+ elif self .enable_network_isolation () and self .entry_point :
932
+ self .uploaded_code = self ._stage_user_code_in_s3 ()
933
+ code_dir = self .CONTAINER_CODE_CHANNEL_SOURCEDIR_PATH
934
+ script = self .uploaded_code .script_name
935
+ self .code_uri = self .uploaded_code .s3_prefix
910
936
else :
911
937
self .uploaded_code = self ._stage_user_code_in_s3 ()
912
938
code_dir = self .uploaded_code .s3_prefix
@@ -930,12 +956,12 @@ def _stage_user_code_in_s3(self):
930
956
931
957
if self .code_location is None and local_mode :
932
958
code_bucket = self .sagemaker_session .default_bucket ()
933
- code_s3_prefix = '{}/source ' .format (self ._current_job_name )
959
+ code_s3_prefix = '{}/{} ' .format (self ._current_job_name , 'source' )
934
960
kms_key = None
935
961
936
962
elif self .code_location is None :
937
963
code_bucket , _ = parse_s3_url (self .output_path )
938
- code_s3_prefix = '{}/source ' .format (self ._current_job_name )
964
+ code_s3_prefix = '{}/{} ' .format (self ._current_job_name , 'source' )
939
965
kms_key = self .output_kms_key
940
966
else :
941
967
code_bucket , key_prefix = parse_s3_url (self .code_location )
0 commit comments