11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
13
"""Placeholder docstring"""
14
- from __future__ import absolute_import
14
+ from __future__ import absolute_import , annotations
15
15
16
16
import base64
17
17
import copy
32
32
33
33
from distutils .spawn import find_executable
34
34
from threading import Thread
35
+ from typing import Dict , List
35
36
from six .moves .urllib .parse import urlparse
36
37
37
38
import sagemaker
@@ -74,6 +75,7 @@ def __init__(
74
75
sagemaker_session = None ,
75
76
container_entrypoint = None ,
76
77
container_arguments = None ,
78
+ container_default_config = None ,
77
79
):
78
80
"""Initialize a SageMakerContainer instance
79
81
@@ -90,6 +92,8 @@ def __init__(
90
92
to use when interacting with SageMaker.
91
93
container_entrypoint (str): the container entrypoint to execute
92
94
container_arguments (str): the container entrypoint arguments
95
+ container_default_config (Dict | None): the dict of user-defined docker
96
+ configuration. Defaults to ``None``
93
97
"""
94
98
from sagemaker .local .local_session import LocalSession
95
99
@@ -102,6 +106,7 @@ def __init__(
102
106
self .image = image
103
107
self .container_entrypoint = container_entrypoint
104
108
self .container_arguments = container_arguments
109
+ self .container_default_config = container_default_config or {}
105
110
# Since we are using a single docker network, Generate a random suffix to attach to the
106
111
# container names. This way multiple jobs can run in parallel.
107
112
suffix = "" .join (random .choice (string .ascii_lowercase + string .digits ) for _ in range (5 ))
@@ -768,16 +773,23 @@ def _compose(self, detached=False):
768
773
769
774
logger .info ("docker command: %s" , " " .join (compose_cmd ))
770
775
return compose_cmd
771
-
772
- def _create_docker_host (self , host , environment , optml_subdirs , command , volumes ):
776
+
777
+ def _create_docker_host (
778
+ self ,
779
+ host : str ,
780
+ environment : List [str ],
781
+ optml_subdirs : set [str ],
782
+ command : str ,
783
+ volumes : List ,
784
+ ) -> Dict :
773
785
"""Creates the docker host configuration.
774
786
775
787
Args:
776
- host:
777
- environment:
778
- optml_subdirs:
779
- command:
780
- volumes:
788
+ host (str): The host address
789
+ environment (List[str]): List of environment variables
790
+ optml_subdirs (Set[str]): Set of subdirs
791
+ command (str): Either 'train' or 'serve'
792
+ volumes (list): List of volumes that will be mapped to the containers
781
793
"""
782
794
optml_volumes = self ._build_optml_volumes (host , optml_subdirs )
783
795
optml_volumes .extend (volumes )
@@ -787,6 +799,7 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes
787
799
)
788
800
789
801
host_config = {
802
+ ** self .container_default_config ,
790
803
"image" : self .image ,
791
804
"container_name" : f"{ container_name_prefix } -{ host } " ,
792
805
"stdin_open" : True ,
0 commit comments