Skip to content

Commit a8e0eb5

Browse files
committed
Add container_default_config argument
1 parent 897cfe4 commit a8e0eb5

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

src/sagemaker/local/image.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Placeholder docstring"""
14-
from __future__ import absolute_import
14+
from __future__ import absolute_import, annotations
1515

1616
import base64
1717
import copy
@@ -32,6 +32,7 @@
3232

3333
from distutils.spawn import find_executable
3434
from threading import Thread
35+
from typing import Dict, List
3536
from six.moves.urllib.parse import urlparse
3637

3738
import sagemaker
@@ -74,6 +75,7 @@ def __init__(
7475
sagemaker_session=None,
7576
container_entrypoint=None,
7677
container_arguments=None,
78+
container_default_config=None,
7779
):
7880
"""Initialize a SageMakerContainer instance
7981
@@ -90,6 +92,8 @@ def __init__(
9092
to use when interacting with SageMaker.
9193
container_entrypoint (str): the container entrypoint to execute
9294
container_arguments (str): the container entrypoint arguments
95+
container_default_config (Dict | None): the dict of user-defined docker
96+
configuration. Defaults to ``None``
9397
"""
9498
from sagemaker.local.local_session import LocalSession
9599

@@ -102,6 +106,7 @@ def __init__(
102106
self.image = image
103107
self.container_entrypoint = container_entrypoint
104108
self.container_arguments = container_arguments
109+
self.container_default_config = container_default_config or {}
105110
# Since we are using a single docker network, Generate a random suffix to attach to the
106111
# container names. This way multiple jobs can run in parallel.
107112
suffix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(5))
@@ -768,16 +773,23 @@ def _compose(self, detached=False):
768773

769774
logger.info("docker command: %s", " ".join(compose_cmd))
770775
return compose_cmd
771-
772-
def _create_docker_host(self, host, environment, optml_subdirs, command, volumes):
776+
777+
def _create_docker_host(
778+
self,
779+
host: str,
780+
environment: List[str],
781+
optml_subdirs: set[str],
782+
command: str,
783+
volumes: List,
784+
) -> Dict:
773785
"""Creates the docker host configuration.
774786
775787
Args:
776-
host:
777-
environment:
778-
optml_subdirs:
779-
command:
780-
volumes:
788+
host (str): The host address
789+
environment (List[str]): List of environment variables
790+
optml_subdirs (Set[str]): Set of subdirs
791+
command (str): Either 'train' or 'serve'
792+
volumes (list): List of volumes that will be mapped to the containers
781793
"""
782794
optml_volumes = self._build_optml_volumes(host, optml_subdirs)
783795
optml_volumes.extend(volumes)
@@ -787,6 +799,7 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes
787799
)
788800

789801
host_config = {
802+
**self.container_default_config,
790803
"image": self.image,
791804
"container_name": f"{container_name_prefix}-{host}",
792805
"stdin_open": True,

src/sagemaker/local/local_session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def create_processing_job(
128128
sagemaker_session=self.sagemaker_session,
129129
container_entrypoint=container_entrypoint,
130130
container_arguments=container_arguments,
131+
container_default_config=self._container_default_config
131132
)
132133
processing_job = _LocalProcessingJob(container)
133134
logger.info("Starting processing job")
@@ -685,7 +686,7 @@ def _initialize(
685686
)
686687

687688
self.sagemaker_client = LocalSagemakerClient(self)
688-
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
689+
689690
self.local_mode = True
690691
sagemaker_config = kwargs.get("sagemaker_config", None)
691692
if sagemaker_config:
@@ -738,6 +739,8 @@ def _initialize(
738739
if self._disable_local_code and "local" in self.config:
739740
self.config["local"]["local_code"] = False
740741

742+
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
743+
741744
def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
742745
"""A no-op method meant to override the sagemaker client.
743746

0 commit comments

Comments
 (0)