Skip to content

Commit 2aecf36

Browse files
authored
Merge branch 'master' into feat/jumpstart-default-payloads
2 parents c869f30 + a9ac311 commit 2aecf36

File tree

21 files changed

+420
-49
lines changed

21 files changed

+420
-49
lines changed

doc/overview.rst

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,9 +1560,20 @@ You can install necessary dependencies for this feature using pip.
15601560
Additionally, Local Mode also requires Docker Compose V2. Follow the guidelines in https://docs.docker.com/compose/install/ to install.
15611561
Make sure to have a Compose Version compatible with your Docker Engine installation. Check Docker Engine release notes https://docs.docker.com/engine/release-notes to find a compatible version.
15621562

1563-
If you want to keep everything local, and not use Amazon S3 either, you can enable "local code" in one of two ways:
1563+
Local mode configuration
1564+
========================
15641565

1565-
- Create a file at ``~/.sagemaker/config.yaml`` that contains:
1566+
Local mode uses a YAML configuration file located at ``~/.sagemaker/config.yaml`` to define the default values that are automatically passed to the ``config`` attribute of ``LocalSession``. The following is an example configuration; for the full schema, see `sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA <https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/config/config_schema.py>`_.
1567+
1568+
.. code:: yaml
1569+
1570+
local:
1571+
local_code: true # Using everything locally
1572+
region_name: "us-west-2" # Name of the region
1573+
container_config: # Additional docker container config
1574+
shm_size: "128M"
1575+
1576+
If you want to keep everything local, and not use Amazon S3 either, you can enable "local code" in one of two ways:
15661577
15671578
.. code:: yaml
15681579
@@ -1583,6 +1594,9 @@ If you want to keep everything local, and not use Amazon S3 either, you can enab
15831594
.. note::
15841595
If you enable "local code," then you cannot use the ``dependencies`` parameter in your estimator or model.
15851596

1597+
Activating local mode by ``instance_type`` argument
1598+
====================================================
1599+
15861600
We can take the example in `Using Estimators <#using-estimators>`__ , and use either ``local`` or ``local_gpu`` as the instance type.
15871601

15881602
.. code:: python

src/sagemaker/config/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,11 @@
1313
"""This module configures the default values for SageMaker Python SDK."""
1414

1515
from __future__ import absolute_import
16-
from sagemaker.config.config import load_sagemaker_config, validate_sagemaker_config # noqa: F401
16+
from sagemaker.config.config import ( # noqa: F401
17+
load_local_mode_config,
18+
load_sagemaker_config,
19+
validate_sagemaker_config,
20+
)
1721
from sagemaker.config.config_schema import ( # noqa: F401
1822
KEY,
1923
TRAINING_JOB,
@@ -161,4 +165,8 @@
161165
INFERENCE_SPECIFICATION,
162166
ESTIMATOR,
163167
DEBUG_HOOK_CONFIG,
168+
SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA,
169+
LOCAL,
170+
LOCAL_CODE,
171+
CONTAINER_CONFIG,
164172
)

src/sagemaker/config/config.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
The schema of the config file is dictated in config_schema.py in the same module.
1717
1818
"""
19-
from __future__ import absolute_import
19+
from __future__ import absolute_import, annotations
2020

2121
import pathlib
2222
import os
@@ -33,12 +33,18 @@
3333
logger = get_sagemaker_config_logger()
3434

3535
_APP_NAME = "sagemaker"
36+
# The default name of the config file.
37+
_CONFIG_FILE_NAME = "config.yaml"
3638
# The default config file location of the Administrator provided config file. This path can be
3739
# overridden with `SAGEMAKER_ADMIN_CONFIG_OVERRIDE` environment variable.
38-
_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), "config.yaml")
40+
_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
3941
# The default config file location of the user provided config file. This path can be
4042
# overridden with `SAGEMAKER_USER_CONFIG_OVERRIDE` environment variable.
41-
_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), "config.yaml")
43+
_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
44+
# The default config file location of the local mode.
45+
_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH = os.path.join(
46+
os.path.expanduser("~"), ".sagemaker", _CONFIG_FILE_NAME
47+
)
4248

4349
ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE = "SAGEMAKER_ADMIN_CONFIG_OVERRIDE"
4450
ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE"
@@ -144,11 +150,21 @@ def validate_sagemaker_config(sagemaker_config: dict = None):
144150
jsonschema.validate(sagemaker_config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA)
145151

146152

153+
def load_local_mode_config() -> dict | None:
154+
"""Loads the local mode config file."""
155+
try:
156+
content = _load_config_from_file(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
157+
except ValueError:
158+
content = None
159+
160+
return content
161+
162+
147163
def _load_config_from_file(file_path: str) -> dict:
148164
"""Placeholder docstring"""
149165
inferred_file_path = file_path
150166
if os.path.isdir(file_path):
151-
inferred_file_path = os.path.join(file_path, "config.yaml")
167+
inferred_file_path = os.path.join(file_path, _CONFIG_FILE_NAME)
152168
if not os.path.exists(inferred_file_path):
153169
raise ValueError(
154170
f"Unable to load the config file from the location: {file_path}"
@@ -194,10 +210,14 @@ def _get_inferred_s3_uri(s3_uri, s3_resource_for_config):
194210
if len(s3_files_with_same_prefix) > 1:
195211
# Customer has provided us with a S3 URI which points to a directory
196212
# search for s3://<bucket>/directory-key-prefix/config.yaml
197-
inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, "config.yaml")).replace("s3:/", "s3://")
213+
inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, _CONFIG_FILE_NAME)).replace(
214+
"s3:/", "s3://"
215+
)
198216
if inferred_s3_uri not in s3_files_with_same_prefix:
199217
# We don't know which file we should be operating with.
200-
raise ValueError("Provide an S3 URI of a directory that has a config.yaml file.")
218+
raise ValueError(
219+
f"Provide an S3 URI of a directory that has a {_CONFIG_FILE_NAME} file."
220+
)
201221
# Customer has a config.yaml present in the directory that was provided as the S3 URI
202222
return inferred_s3_uri
203223
return s3_uri

src/sagemaker/config/config_schema.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,11 @@
102102
DISABLE_PROFILER = "DisableProfiler"
103103
ESTIMATOR = "Estimator"
104104
DEBUG_HOOK_CONFIG = "DebugHookConfig"
105+
LOCAL = "local"
106+
LOCAL_CODE = "local_code"
107+
SERVING_PORT = "serving_port"
108+
CONTAINER_CONFIG = "container_config"
109+
REGION_NAME = "region_name"
105110

106111

107112
def _simple_path(*args: str):
@@ -1068,3 +1073,28 @@ def _simple_path(*args: str):
10681073
},
10691074
},
10701075
}
1076+
1077+
SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA = {
1078+
"$schema": "https://json-schema.org/draft/2020-12/schema",
1079+
TYPE: OBJECT,
1080+
ADDITIONAL_PROPERTIES: False,
1081+
PROPERTIES: {
1082+
LOCAL: {
1083+
TYPE: OBJECT,
1084+
ADDITIONAL_PROPERTIES: False,
1085+
PROPERTIES: {
1086+
LOCAL_CODE: {
1087+
TYPE: "boolean",
1088+
},
1089+
REGION_NAME: {TYPE: "string"},
1090+
SERVING_PORT: {
1091+
TYPE: "integer",
1092+
},
1093+
CONTAINER_CONFIG: {
1094+
TYPE: OBJECT,
1095+
},
1096+
},
1097+
},
1098+
},
1099+
"required": [LOCAL],
1100+
}

src/sagemaker/feature_store/feature_processor/_config_uploader.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# language governing permissions and limitations under the License.
1313
"""Contains classes for preparing and uploading configs for a scheduled feature processor."""
1414
from __future__ import absolute_import
15-
from typing import Callable, Dict, Tuple, List
15+
from typing import Callable, Dict, Optional, Tuple, List
1616

1717
import attr
1818

@@ -70,6 +70,7 @@ def prepare_step_input_channel_for_spark_mode(
7070
s3_base_uri,
7171
self.remote_decorator_config.s3_kms_key,
7272
sagemaker_session,
73+
self.remote_decorator_config.custom_file_filter,
7374
)
7475

7576
(
@@ -134,6 +135,7 @@ def _prepare_and_upload_dependencies(
134135
s3_base_uri: str,
135136
s3_kms_key: str,
136137
sagemaker_session: Session,
138+
custom_file_filter: Optional[Callable[[str, List], List]] = None,
137139
) -> str:
138140
"""Upload the training step dependencies to S3 if present"""
139141
return _prepare_and_upload_dependencies(
@@ -144,6 +146,7 @@ def _prepare_and_upload_dependencies(
144146
s3_base_uri=s3_base_uri,
145147
s3_kms_key=s3_kms_key,
146148
sagemaker_session=sagemaker_session,
149+
custom_file_filter=custom_file_filter,
147150
)
148151

149152
def _prepare_and_upload_runtime_scripts(

src/sagemaker/local/data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def get_splitter_instance(split_type):
6969
Returns
7070
:class:`sagemaker.local.data.Splitter`: an Instance of a Splitter
7171
"""
72-
if split_type is None:
72+
if split_type == "None" or split_type is None:
7373
return NoneSplitter()
7474
if split_type == "Line":
7575
return LineSplitter()

src/sagemaker/local/image.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Placeholder docstring"""
14-
from __future__ import absolute_import
14+
from __future__ import absolute_import, annotations
1515

1616
import base64
1717
import copy
@@ -32,9 +32,11 @@
3232

3333
from distutils.spawn import find_executable
3434
from threading import Thread
35+
from typing import Dict, List
3536
from six.moves.urllib.parse import urlparse
3637

3738
import sagemaker
39+
from sagemaker.config.config_schema import CONTAINER_CONFIG, LOCAL
3840
import sagemaker.local.data
3941
import sagemaker.local.utils
4042
import sagemaker.utils
@@ -769,24 +771,38 @@ def _compose(self, detached=False):
769771
logger.info("docker command: %s", " ".join(compose_cmd))
770772
return compose_cmd
771773

772-
def _create_docker_host(self, host, environment, optml_subdirs, command, volumes):
774+
def _create_docker_host(
775+
self,
776+
host: str,
777+
environment: List[str],
778+
optml_subdirs: set[str],
779+
command: str,
780+
volumes: List,
781+
) -> Dict:
773782
"""Creates the docker host configuration.
774783
775784
Args:
776-
host:
777-
environment:
778-
optml_subdirs:
779-
command:
780-
volumes:
785+
host (str): The host address
786+
environment (List[str]): List of environment variables
787+
optml_subdirs (Set[str]): Set of subdirs
788+
command (str): Either 'train' or 'serve'
789+
volumes (list): List of volumes that will be mapped to the containers
781790
"""
782791
optml_volumes = self._build_optml_volumes(host, optml_subdirs)
783792
optml_volumes.extend(volumes)
784793

785794
container_name_prefix = "".join(
786795
random.choice(string.ascii_lowercase + string.digits) for _ in range(10)
787796
)
797+
container_default_config = (
798+
sagemaker.utils.get_config_value(
799+
f"{LOCAL}.{CONTAINER_CONFIG}", self.sagemaker_session.config
800+
)
801+
or {}
802+
)
788803

789804
host_config = {
805+
**container_default_config,
790806
"image": self.image,
791807
"container_name": f"{container_name_prefix}-{host}",
792808
"stdin_open": True,

src/sagemaker/local/local_session.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,24 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Placeholder docstring"""
14-
from __future__ import absolute_import
14+
from __future__ import absolute_import, annotations
1515

1616
import logging
17-
import os
1817
import platform
1918
from datetime import datetime
19+
from typing import Dict
2020

2121
import boto3
2222
from botocore.exceptions import ClientError
23+
import jsonschema
2324

2425
from sagemaker.config import (
25-
load_sagemaker_config,
26-
validate_sagemaker_config,
26+
SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA,
2727
SESSION_DEFAULT_S3_BUCKET_PATH,
2828
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
29+
load_local_mode_config,
30+
load_sagemaker_config,
31+
validate_sagemaker_config,
2932
)
3033
from sagemaker.local.image import _SageMakerContainer
3134
from sagemaker.local.utils import get_docker_host
@@ -83,7 +86,7 @@ def create_processing_job(
8386
Environment=None,
8487
ProcessingInputs=None,
8588
ProcessingOutputConfig=None,
86-
**kwargs
89+
**kwargs,
8790
):
8891
"""Creates a processing job in Local Mode
8992
@@ -165,7 +168,7 @@ def create_training_job(
165168
ResourceConfig,
166169
InputDataConfig=None,
167170
Environment=None,
168-
**kwargs
171+
**kwargs,
169172
):
170173
"""Create a training job in Local Mode.
171174
@@ -230,7 +233,7 @@ def create_transform_job(
230233
TransformInput,
231234
TransformOutput,
232235
TransformResources,
233-
**kwargs
236+
**kwargs,
234237
):
235238
"""Create the transform job.
236239
@@ -537,7 +540,21 @@ def __init__(self, config=None):
537540
self.http = urllib3.PoolManager()
538541
self.serving_port = 8080
539542
self.config = config
540-
self.serving_port = get_config_value("local.serving_port", config) or 8080
543+
544+
@property
545+
def config(self) -> dict:
546+
"""Local config getter"""
547+
return self._config
548+
549+
@config.setter
550+
def config(self, value: dict):
551+
"""Local config setter, this method also updates the `serving_port` attribute.
552+
553+
Args:
554+
value (dict): the new config value
555+
"""
556+
self._config = value
557+
self.serving_port = get_config_value("local.serving_port", self._config) or 8080
541558

542559
def invoke_endpoint(
543560
self,
@@ -686,6 +703,7 @@ def _initialize(
686703

687704
self.sagemaker_client = LocalSagemakerClient(self)
688705
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
706+
689707
self.local_mode = True
690708
sagemaker_config = kwargs.get("sagemaker_config", None)
691709
if sagemaker_config:
@@ -726,17 +744,26 @@ def _initialize(
726744
sagemaker_session=self,
727745
)
728746

729-
local_mode_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
730-
if os.path.exists(local_mode_config_file):
747+
self.config = load_local_mode_config()
748+
if self._disable_local_code and self.config and "local" in self.config:
749+
self.config["local"]["local_code"] = False
750+
751+
@Session.config.setter
752+
def config(self, value: Dict | None):
753+
"""Setter of the local mode config"""
754+
if value is not None:
731755
try:
732-
import yaml
733-
except ImportError as e:
734-
logger.error(_module_import_error("yaml", "Local mode", "local"))
756+
jsonschema.validate(value, SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA)
757+
except jsonschema.ValidationError as e:
758+
logger.error("Failed to validate the local mode config")
735759
raise e
760+
self._config = value
761+
else:
762+
self._config = value
736763

737-
self.config = yaml.safe_load(open(local_mode_config_file, "r"))
738-
if self._disable_local_code and "local" in self.config:
739-
self.config["local"]["local_code"] = False
764+
# update the runtime client on config changed
765+
if getattr(self, "sagemaker_runtime_client", None):
766+
self.sagemaker_runtime_client.config = self._config
740767

741768
def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
742769
"""A no-op method meant to override the sagemaker client.

0 commit comments

Comments
 (0)