Skip to content

Commit 40dba74

Browse files
committed
Add container config to local mode config
1 parent a8e0eb5 commit 40dba74

File tree

7 files changed

+115
-36
lines changed

7 files changed

+115
-36
lines changed

src/sagemaker/config/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
"""This module configures the default values for SageMaker Python SDK."""
1414

1515
from __future__ import absolute_import
16-
from sagemaker.config.config import load_sagemaker_config, validate_sagemaker_config # noqa: F401
16+
from sagemaker.config.config import ( # noqa: F401
17+
load_local_mode_config,
18+
load_sagemaker_config,
19+
validate_sagemaker_config,
20+
)
1721
from sagemaker.config.config_schema import ( # noqa: F401
1822
KEY,
1923
TRAINING_JOB,
@@ -161,4 +165,8 @@
161165
INFERENCE_SPECIFICATION,
162166
ESTIMATOR,
163167
DEBUG_HOOK_CONFIG,
168+
SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA,
169+
LOCAL,
170+
LOCAL_CODE,
171+
CONTAINER_CONFIG,
164172
)

src/sagemaker/config/config.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
The schema of the config file is dictated in config_schema.py in the same module.
1717
1818
"""
19-
from __future__ import absolute_import
19+
from __future__ import absolute_import, annotations
2020

2121
import pathlib
2222
import os
@@ -33,12 +33,18 @@
3333
logger = get_sagemaker_config_logger()
3434

3535
_APP_NAME = "sagemaker"
36+
# The default name of the config file.
37+
_CONFIG_FILE_NAME = "config.yaml"
3638
# The default config file location of the Administrator provided config file. This path can be
3739
# overridden with `SAGEMAKER_ADMIN_CONFIG_OVERRIDE` environment variable.
38-
_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), "config.yaml")
40+
_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
3941
# The default config file location of the user provided config file. This path can be
4042
# overridden with `SAGEMAKER_USER_CONFIG_OVERRIDE` environment variable.
41-
_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), "config.yaml")
43+
_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
44+
# The default location of the local mode config file.
45+
_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH = os.path.join(
46+
os.path.expanduser("~"), ".sagemaker", _CONFIG_FILE_NAME
47+
)
4248

4349
ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE = "SAGEMAKER_ADMIN_CONFIG_OVERRIDE"
4450
ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE"
@@ -144,11 +150,21 @@ def validate_sagemaker_config(sagemaker_config: dict = None):
144150
jsonschema.validate(sagemaker_config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA)
145151

146152

153+
def load_local_mode_config() -> dict | None:
154+
"""Loads the local mode config file."""
155+
try:
156+
content = _load_config_from_file(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
157+
except ValueError:
158+
content = None
159+
160+
return content
161+
162+
147163
def _load_config_from_file(file_path: str) -> dict:
148164
"""Placeholder docstring"""
149165
inferred_file_path = file_path
150166
if os.path.isdir(file_path):
151-
inferred_file_path = os.path.join(file_path, "config.yaml")
167+
inferred_file_path = os.path.join(file_path, _CONFIG_FILE_NAME)
152168
if not os.path.exists(inferred_file_path):
153169
raise ValueError(
154170
f"Unable to load the config file from the location: {file_path}"
@@ -194,10 +210,14 @@ def _get_inferred_s3_uri(s3_uri, s3_resource_for_config):
194210
if len(s3_files_with_same_prefix) > 1:
195211
# Customer has provided us with a S3 URI which points to a directory
196212
# search for s3://<bucket>/directory-key-prefix/config.yaml
197-
inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, "config.yaml")).replace("s3:/", "s3://")
213+
inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, _CONFIG_FILE_NAME)).replace(
214+
"s3:/", "s3://"
215+
)
198216
if inferred_s3_uri not in s3_files_with_same_prefix:
199217
# We don't know which file we should be operating with.
200-
raise ValueError("Provide an S3 URI of a directory that has a config.yaml file.")
218+
raise ValueError(
219+
f"Provide an S3 URI of a directory that has a {_CONFIG_FILE_NAME} file."
220+
)
201221
# Customer has a config.yaml present in the directory that was provided as the S3 URI
202222
return inferred_s3_uri
203223
return s3_uri

src/sagemaker/config/config_schema.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@
102102
DISABLE_PROFILER = "DisableProfiler"
103103
ESTIMATOR = "Estimator"
104104
DEBUG_HOOK_CONFIG = "DebugHookConfig"
105-
105+
LOCAL = "local"
106+
LOCAL_CODE = "local_code"
107+
SERVING_PORT = "serving_port"
108+
CONTAINER_CONFIG = "container_config"
109+
REGION_NAME = "region_name"
106110

107111
def _simple_path(*args: str):
108112
"""Appends an arbitrary number of strings to use as path constants"""
@@ -1068,3 +1072,28 @@ def _simple_path(*args: str):
10681072
},
10691073
},
10701074
}
1075+
1076+
SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA = {
1077+
"$schema": "https://json-schema.org/draft/2020-12/schema",
1078+
TYPE: OBJECT,
1079+
ADDITIONAL_PROPERTIES: False,
1080+
PROPERTIES: {
1081+
LOCAL: {
1082+
TYPE: OBJECT,
1083+
ADDITIONAL_PROPERTIES: False,
1084+
PROPERTIES: {
1085+
LOCAL_CODE: {
1086+
TYPE: "boolean",
1087+
},
1088+
REGION_NAME: {TYPE: "string"},
1089+
SERVING_PORT: {
1090+
TYPE: "integer",
1091+
},
1092+
CONTAINER_CONFIG: {
1093+
TYPE: OBJECT,
1094+
},
1095+
},
1096+
},
1097+
},
1098+
"required": [LOCAL],
1099+
}

src/sagemaker/local/image.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from six.moves.urllib.parse import urlparse
3737

3838
import sagemaker
39+
from sagemaker.config.config_schema import CONTAINER_CONFIG, LOCAL
3940
import sagemaker.local.data
4041
import sagemaker.local.utils
4142
import sagemaker.utils
@@ -75,7 +76,6 @@ def __init__(
7576
sagemaker_session=None,
7677
container_entrypoint=None,
7778
container_arguments=None,
78-
container_default_config=None,
7979
):
8080
"""Initialize a SageMakerContainer instance
8181
@@ -92,8 +92,6 @@ def __init__(
9292
to use when interacting with SageMaker.
9393
container_entrypoint (str): the container entrypoint to execute
9494
container_arguments (str): the container entrypoint arguments
95-
container_default_config (Dict | None): the dict of user-defined docker
96-
configuration. Defaults to ``None``
9795
"""
9896
from sagemaker.local.local_session import LocalSession
9997

@@ -106,7 +104,6 @@ def __init__(
106104
self.image = image
107105
self.container_entrypoint = container_entrypoint
108106
self.container_arguments = container_arguments
109-
self.container_default_config = container_default_config or {}
110107
# Since we are using a single docker network, Generate a random suffix to attach to the
111108
# container names. This way multiple jobs can run in parallel.
112109
suffix = "".join(random.choice(string.ascii_lowercase + string.digits) for _ in range(5))
@@ -773,7 +770,7 @@ def _compose(self, detached=False):
773770

774771
logger.info("docker command: %s", " ".join(compose_cmd))
775772
return compose_cmd
776-
773+
777774
def _create_docker_host(
778775
self,
779776
host: str,
@@ -785,8 +782,8 @@ def _create_docker_host(
785782
"""Creates the docker host configuration.
786783
787784
Args:
788-
host (str): The host address
789-
environment (List[str]): List of environment variables
785+
host (str): The host address
786+
environment (List[str]): List of environment variables
790787
optml_subdirs (Set[str]): Set of subdirs
791788
command (str): Either 'train' or 'serve'
792789
volumes (list): List of volumes that will be mapped to the containers
@@ -797,9 +794,15 @@ def _create_docker_host(
797794
container_name_prefix = "".join(
798795
random.choice(string.ascii_lowercase + string.digits) for _ in range(10)
799796
)
797+
container_default_config = (
798+
sagemaker.utils.get_config_value(
799+
f"{LOCAL}.{CONTAINER_CONFIG}", self.sagemaker_session.config
800+
)
801+
or {}
802+
)
800803

801804
host_config = {
802-
**self.container_default_config,
805+
**container_default_config,
803806
"image": self.image,
804807
"container_name": f"{container_name_prefix}-{host}",
805808
"stdin_open": True,

src/sagemaker/local/local_session.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,24 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Placeholder docstring"""
14-
from __future__ import absolute_import
14+
from __future__ import absolute_import, annotations
1515

1616
import logging
17-
import os
1817
import platform
1918
from datetime import datetime
19+
from typing import Dict
2020

2121
import boto3
2222
from botocore.exceptions import ClientError
23+
import jsonschema
2324

2425
from sagemaker.config import (
25-
load_sagemaker_config,
26-
validate_sagemaker_config,
26+
SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA,
2727
SESSION_DEFAULT_S3_BUCKET_PATH,
2828
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
29+
load_local_mode_config,
30+
load_sagemaker_config,
31+
validate_sagemaker_config,
2932
)
3033
from sagemaker.local.image import _SageMakerContainer
3134
from sagemaker.local.utils import get_docker_host
@@ -83,7 +86,7 @@ def create_processing_job(
8386
Environment=None,
8487
ProcessingInputs=None,
8588
ProcessingOutputConfig=None,
86-
**kwargs
89+
**kwargs,
8790
):
8891
"""Creates a processing job in Local Mode
8992
@@ -128,7 +131,6 @@ def create_processing_job(
128131
sagemaker_session=self.sagemaker_session,
129132
container_entrypoint=container_entrypoint,
130133
container_arguments=container_arguments,
131-
container_default_config=self._container_default_config
132134
)
133135
processing_job = _LocalProcessingJob(container)
134136
logger.info("Starting processing job")
@@ -166,7 +168,7 @@ def create_training_job(
166168
ResourceConfig,
167169
InputDataConfig=None,
168170
Environment=None,
169-
**kwargs
171+
**kwargs,
170172
):
171173
"""Create a training job in Local Mode.
172174
@@ -231,7 +233,7 @@ def create_transform_job(
231233
TransformInput,
232234
TransformOutput,
233235
TransformResources,
234-
**kwargs
236+
**kwargs,
235237
):
236238
"""Create the transform job.
237239
@@ -727,19 +729,25 @@ def _initialize(
727729
sagemaker_session=self,
728730
)
729731

730-
local_mode_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
731-
if os.path.exists(local_mode_config_file):
732+
self.config = load_local_mode_config()
733+
if self._disable_local_code and self.config and "local" in self.config:
734+
self.config["local"]["local_code"] = False
735+
736+
@Session.config.setter
737+
def config(self, value: Dict | None):
738+
"""Setter of the local mode config"""
739+
if value is not None:
732740
try:
733-
import yaml
734-
except ImportError as e:
735-
logger.error(_module_import_error("yaml", "Local mode", "local"))
741+
jsonschema.validate(value, SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA)
742+
except jsonschema.ValidationError as e:
743+
logger.error("Failed to validate the local mode config")
736744
raise e
745+
self._config = value
746+
else:
747+
self._config = value
737748

738-
self.config = yaml.safe_load(open(local_mode_config_file, "r"))
739-
if self._disable_local_code and "local" in self.config:
740-
self.config["local"]["local_code"] = False
741-
742-
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
749+
# Update the runtime client whenever the config changes.
750+
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self._config)
743751

744752
def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
745753
"""A no-op method meant to override the sagemaker client.

src/sagemaker/session.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Placeholder docstring"""
14-
from __future__ import absolute_import, print_function
14+
from __future__ import absolute_import, annotations, print_function
1515

1616
import json
1717
import logging
@@ -239,7 +239,7 @@ def __init__(
239239
self.s3_client = None
240240
self.resource_groups_client = None
241241
self.resource_group_tagging_client = None
242-
self.config = None
242+
self._config = None
243243
self.lambda_client = None
244244
self.settings = settings
245245

@@ -326,6 +326,16 @@ def _initialize(
326326
sagemaker_session=self,
327327
)
328328

329+
@property
330+
def config(self) -> Dict | None:
331+
"""The config for the local mode, unused in a normal session"""
332+
return self._config
333+
334+
@config.setter
335+
def config(self, value: Dict | None):
336+
"""The config for the local mode, unused in a normal session"""
337+
self._config = value
338+
329339
@property
330340
def boto_region_name(self):
331341
"""Placeholder docstring"""

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ ignore =
4242
FI55,
4343
FI56,
4444
FI57,
45+
FI58,
4546
W503
4647

4748
require-code = True

0 commit comments

Comments
 (0)