Skip to content

feat: allow configuring docker container in local mode #4153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions doc/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1560,9 +1560,20 @@ You can install necessary dependencies for this feature using pip.
Additionally, Local Mode also requires Docker Compose V2. Follow the guidelines in https://docs.docker.com/compose/install/ to install.
Make sure to have a Compose Version compatible with your Docker Engine installation. Check Docker Engine release notes https://docs.docker.com/engine/release-notes to find a compatible version.

If you want to keep everything local, and not use Amazon S3 either, you can enable "local code" in one of two ways:
Local mode configuration
========================

- Create a file at ``~/.sagemaker/config.yaml`` that contains:
Local mode uses a YAML configuration file located at ``~/.sagemaker/config.yaml`` to define the default values that are automatically passed to the ``config`` attribute of ``LocalSession``. The following is an example configuration; for the full schema, see `sagemaker.config.config_schema.SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA <https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/config/config_schema.py>`_.

.. code:: yaml

local:
local_code: true # Using everything locally
region_name: "us-west-2" # Name of the region
container_config: # Additional docker container config
shm_size: "128M"

If you want to keep everything local, and not use Amazon S3 either, you can enable "local code" in one of two ways:

.. code:: yaml

Expand All @@ -1583,6 +1594,9 @@ If you want to keep everything local, and not use Amazon S3 either, you can enab
.. note::
If you enable "local code," then you cannot use the ``dependencies`` parameter in your estimator or model.

Activating local mode by ``instance_type`` argument
====================================================

We can take the example in `Using Estimators <#using-estimators>`__ , and use either ``local`` or ``local_gpu`` as the instance type.

.. code:: python
Expand Down
10 changes: 9 additions & 1 deletion src/sagemaker/config/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
"""This module configures the default values for SageMaker Python SDK."""

from __future__ import absolute_import
from sagemaker.config.config import load_sagemaker_config, validate_sagemaker_config # noqa: F401
from sagemaker.config.config import ( # noqa: F401
load_local_mode_config,
load_sagemaker_config,
validate_sagemaker_config,
)
from sagemaker.config.config_schema import ( # noqa: F401
KEY,
TRAINING_JOB,
Expand Down Expand Up @@ -161,4 +165,8 @@
INFERENCE_SPECIFICATION,
ESTIMATOR,
DEBUG_HOOK_CONFIG,
SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA,
LOCAL,
LOCAL_CODE,
CONTAINER_CONFIG,
)
32 changes: 26 additions & 6 deletions src/sagemaker/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
The schema of the config file is dictated in config_schema.py in the same module.

"""
from __future__ import absolute_import
from __future__ import absolute_import, annotations

import pathlib
import os
Expand All @@ -33,12 +33,18 @@
logger = get_sagemaker_config_logger()

_APP_NAME = "sagemaker"
# The default name of the config file.
_CONFIG_FILE_NAME = "config.yaml"
# The default config file location of the Administrator provided config file. This path can be
# overridden with `SAGEMAKER_ADMIN_CONFIG_OVERRIDE` environment variable.
_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), "config.yaml")
_DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
# The default config file location of the user provided config file. This path can be
# overridden with `SAGEMAKER_USER_CONFIG_OVERRIDE` environment variable.
_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), "config.yaml")
_DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
# The default config file location of the local mode.
_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH = os.path.join(
os.path.expanduser("~"), ".sagemaker", _CONFIG_FILE_NAME
)

ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE = "SAGEMAKER_ADMIN_CONFIG_OVERRIDE"
ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE"
Expand Down Expand Up @@ -144,11 +150,21 @@ def validate_sagemaker_config(sagemaker_config: dict = None):
jsonschema.validate(sagemaker_config, SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA)


def load_local_mode_config() -> dict | None:
    """Load the local mode config from its default file location.

    The file at ``_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH`` is read through
    ``_load_config_from_file``.

    Returns:
        dict | None: The loaded config, or ``None`` when the loader raises
        ``ValueError`` (for example, because the file does not exist).
    """
    try:
        return _load_config_from_file(_DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
    except ValueError:
        # The local mode config file is optional: report its absence as None.
        return None


def _load_config_from_file(file_path: str) -> dict:
"""Placeholder docstring"""
inferred_file_path = file_path
if os.path.isdir(file_path):
inferred_file_path = os.path.join(file_path, "config.yaml")
inferred_file_path = os.path.join(file_path, _CONFIG_FILE_NAME)
if not os.path.exists(inferred_file_path):
raise ValueError(
f"Unable to load the config file from the location: {file_path}"
Expand Down Expand Up @@ -194,10 +210,14 @@ def _get_inferred_s3_uri(s3_uri, s3_resource_for_config):
if len(s3_files_with_same_prefix) > 1:
# Customer has provided us with a S3 URI which points to a directory
# search for s3://<bucket>/directory-key-prefix/config.yaml
inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, "config.yaml")).replace("s3:/", "s3://")
inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, _CONFIG_FILE_NAME)).replace(
"s3:/", "s3://"
)
if inferred_s3_uri not in s3_files_with_same_prefix:
# We don't know which file we should be operating with.
raise ValueError("Provide an S3 URI of a directory that has a config.yaml file.")
raise ValueError(
f"Provide an S3 URI of a directory that has a {_CONFIG_FILE_NAME} file."
)
# Customer has a config.yaml present in the directory that was provided as the S3 URI
return inferred_s3_uri
return s3_uri
30 changes: 30 additions & 0 deletions src/sagemaker/config/config_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@
DISABLE_PROFILER = "DisableProfiler"
ESTIMATOR = "Estimator"
DEBUG_HOOK_CONFIG = "DebugHookConfig"
LOCAL = "local"
LOCAL_CODE = "local_code"
SERVING_PORT = "serving_port"
CONTAINER_CONFIG = "container_config"
REGION_NAME = "region_name"


def _simple_path(*args: str):
Expand Down Expand Up @@ -1068,3 +1073,28 @@ def _simple_path(*args: str):
},
},
}

SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA = {
"$schema": "https://json-schema.org/draft/2020-12/schema",
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
PROPERTIES: {
LOCAL: {
TYPE: OBJECT,
ADDITIONAL_PROPERTIES: False,
PROPERTIES: {
LOCAL_CODE: {
TYPE: "boolean",
},
REGION_NAME: {TYPE: "string"},
SERVING_PORT: {
TYPE: "integer",
},
CONTAINER_CONFIG: {
TYPE: OBJECT,
},
},
},
},
"required": [LOCAL],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does required mean here? the usage of defaults config is still optional correct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added it to prevent creating an empty config.yaml file or passing an empty dict to the local session config. This config is optional as always.

}
30 changes: 23 additions & 7 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import
from __future__ import absolute_import, annotations

import base64
import copy
Expand All @@ -32,9 +32,11 @@

from distutils.spawn import find_executable
from threading import Thread
from typing import Dict, List
from six.moves.urllib.parse import urlparse

import sagemaker
from sagemaker.config.config_schema import CONTAINER_CONFIG, LOCAL
import sagemaker.local.data
import sagemaker.local.utils
import sagemaker.utils
Expand Down Expand Up @@ -769,24 +771,38 @@ def _compose(self, detached=False):
logger.info("docker command: %s", " ".join(compose_cmd))
return compose_cmd

def _create_docker_host(self, host, environment, optml_subdirs, command, volumes):
def _create_docker_host(
self,
host: str,
environment: List[str],
optml_subdirs: set[str],
command: str,
volumes: List,
) -> Dict:
"""Creates the docker host configuration.

Args:
host:
environment:
optml_subdirs:
command:
volumes:
host (str): The host address
environment (List[str]): List of environment variables
optml_subdirs (Set[str]): Set of subdirs
command (str): Either 'train' or 'serve'
volumes (list): List of volumes that will be mapped to the containers
"""
optml_volumes = self._build_optml_volumes(host, optml_subdirs)
optml_volumes.extend(volumes)

container_name_prefix = "".join(
random.choice(string.ascii_lowercase + string.digits) for _ in range(10)
)
container_default_config = (
sagemaker.utils.get_config_value(
f"{LOCAL}.{CONTAINER_CONFIG}", self.sagemaker_session.config
)
or {}
)

host_config = {
**container_default_config,
"image": self.image,
"container_name": f"{container_name_prefix}-{host}",
"stdin_open": True,
Expand Down
59 changes: 43 additions & 16 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,21 +11,24 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import
from __future__ import absolute_import, annotations

import logging
import os
import platform
from datetime import datetime
from typing import Dict

import boto3
from botocore.exceptions import ClientError
import jsonschema

from sagemaker.config import (
load_sagemaker_config,
validate_sagemaker_config,
SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA,
SESSION_DEFAULT_S3_BUCKET_PATH,
SESSION_DEFAULT_S3_OBJECT_KEY_PREFIX_PATH,
load_local_mode_config,
load_sagemaker_config,
validate_sagemaker_config,
)
from sagemaker.local.image import _SageMakerContainer
from sagemaker.local.utils import get_docker_host
Expand Down Expand Up @@ -83,7 +86,7 @@ def create_processing_job(
Environment=None,
ProcessingInputs=None,
ProcessingOutputConfig=None,
**kwargs
**kwargs,
):
"""Creates a processing job in Local Mode

Expand Down Expand Up @@ -165,7 +168,7 @@ def create_training_job(
ResourceConfig,
InputDataConfig=None,
Environment=None,
**kwargs
**kwargs,
):
"""Create a training job in Local Mode.

Expand Down Expand Up @@ -230,7 +233,7 @@ def create_transform_job(
TransformInput,
TransformOutput,
TransformResources,
**kwargs
**kwargs,
):
"""Create the transform job.

Expand Down Expand Up @@ -537,7 +540,21 @@ def __init__(self, config=None):
self.http = urllib3.PoolManager()
self.serving_port = 8080
self.config = config
self.serving_port = get_config_value("local.serving_port", config) or 8080

@property
def config(self) -> dict:
    """Return the currently stored local config."""
    return self._config

@config.setter
def config(self, value: dict):
    """Store a new local config and refresh ``serving_port`` from it.

    The port is read from the ``local.serving_port`` entry of the new
    config and falls back to 8080 when that lookup yields a falsy value.

    Args:
        value (dict): the new config value
    """
    self._config = value
    configured_port = get_config_value("local.serving_port", value)
    self.serving_port = configured_port or 8080

def invoke_endpoint(
self,
Expand Down Expand Up @@ -686,6 +703,7 @@ def _initialize(

self.sagemaker_client = LocalSagemakerClient(self)
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)

self.local_mode = True
sagemaker_config = kwargs.get("sagemaker_config", None)
if sagemaker_config:
Expand Down Expand Up @@ -726,17 +744,26 @@ def _initialize(
sagemaker_session=self,
)

local_mode_config_file = os.path.join(os.path.expanduser("~"), ".sagemaker", "config.yaml")
if os.path.exists(local_mode_config_file):
self.config = load_local_mode_config()
if self._disable_local_code and self.config and "local" in self.config:
self.config["local"]["local_code"] = False

@Session.config.setter
def config(self, value: Dict | None):
"""Setter of the local mode config"""
if value is not None:
try:
import yaml
except ImportError as e:
logger.error(_module_import_error("yaml", "Local mode", "local"))
jsonschema.validate(value, SAGEMAKER_PYTHON_SDK_LOCAL_MODE_CONFIG_SCHEMA)
except jsonschema.ValidationError as e:
logger.error("Failed to validate the local mode config")
raise e
self._config = value
else:
self._config = value

self.config = yaml.safe_load(open(local_mode_config_file, "r"))
if self._disable_local_code and "local" in self.config:
self.config["local"]["local_code"] = False
# update the runtime client on config changed
if getattr(self, "sagemaker_runtime_client", None):
self.sagemaker_runtime_client.config = self._config

def logs_for_job(self, job_name, wait=False, poll=5, log_type="All"):
"""A no-op method meant to override the sagemaker client.
Expand Down
14 changes: 12 additions & 2 deletions src/sagemaker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import, print_function
from __future__ import absolute_import, annotations, print_function

import json
import logging
Expand Down Expand Up @@ -239,7 +239,7 @@ def __init__(
self.s3_client = None
self.resource_groups_client = None
self.resource_group_tagging_client = None
self.config = None
self._config = None
self.lambda_client = None
self.settings = settings

Expand Down Expand Up @@ -326,6 +326,16 @@ def _initialize(
sagemaker_session=self,
)

@property
def config(self) -> Dict | None:
"""Return the stored local mode config (``None`` until one is set); unused in a normal session."""
return self._config

@config.setter
def config(self, value: Dict | None):
"""Store ``value`` as the local mode config without validation; unused in a normal session."""
self._config = value

@property
def boto_region_name(self):
"""Placeholder docstring"""
Expand Down
Loading