Skip to content

Allow Local Mode to work with a local training script. #178

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 10, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 53 additions & 25 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@

import json
import logging
import os
from abc import ABCMeta
from abc import abstractmethod
from six import with_metaclass, string_types

from sagemaker.fw_utils import tar_and_upload_dir
from sagemaker.fw_utils import parse_s3_url
from sagemaker.fw_utils import UploadedCode
from sagemaker.local.local_session import LocalSession, file_input
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
from sagemaker.local import LocalSession, file_input

from sagemaker.model import Model
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
Expand All @@ -30,7 +29,7 @@
from sagemaker.predictor import RealTimePredictor
from sagemaker.session import Session
from sagemaker.session import s3_input
from sagemaker.utils import base_name_from_image, name_from_base
from sagemaker.utils import base_name_from_image, name_from_base, get_config_value


class EstimatorBase(with_metaclass(ABCMeta, object)):
Expand Down Expand Up @@ -83,13 +82,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
self.input_mode = input_mode

if self.train_instance_type in ('local', 'local_gpu'):
self.local_mode = True
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
raise RuntimeError("Distributed Training in Local GPU is not supported")

self.sagemaker_session = sagemaker_session or LocalSession()
else:
self.local_mode = False
self.sagemaker_session = sagemaker_session or Session()

self.base_job_name = base_job_name
Expand Down Expand Up @@ -158,9 +154,14 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
base_name = self.base_job_name or base_name_from_image(self.train_image())
self._current_job_name = name_from_base(base_name)

# if output_path was specified we use it otherwise initialize here
# if output_path was specified we use it otherwise initialize here.
# For Local Mode with local_code=True we don't need an explicit output_path
if self.output_path is None:
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
if self.sagemaker_session.local_mode and local_code:
self.output_path = ''
else:
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())

self.latest_training_job = _TrainingJob.start_new(self, inputs)
if wait:
Expand Down Expand Up @@ -323,7 +324,7 @@ def start_new(cls, estimator, inputs):
sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
"""

local_mode = estimator.local_mode
local_mode = estimator.sagemaker_session.local_mode

# Allow file:// input only in local mode
if isinstance(inputs, str) and inputs.startswith('file://'):
Expand Down Expand Up @@ -604,27 +605,54 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
base_name = self.base_job_name or base_name_from_image(self.train_image())
self._current_job_name = name_from_base(base_name)

# validate source dir will raise a ValueError if there is something wrong with the
# source directory. We are intentionally not handling it because this is a critical error.
if self.source_dir and not self.source_dir.lower().startswith('s3://'):
validate_source_dir(self.entry_point, self.source_dir)

# If we are in local mode with local_code=True, we want the container to just
# mount the source dir instead of uploading it to S3.
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
if self.sagemaker_session.local_mode and local_code:
# if there is no source dir, use the directory containing the entry point.
if self.source_dir is None:
self.source_dir = os.path.dirname(self.entry_point)
self.entry_point = os.path.basename(self.entry_point)

code_dir = 'file://' + self.source_dir
script = self.entry_point
else:
self.uploaded_code = self._stage_user_code_in_s3()
code_dir = self.uploaded_code.s3_prefix
script = self.uploaded_code.script_name

# Modify hyperparameters in-place to point to the right code directory and script URIs
self._hyperparameters[DIR_PARAM_NAME] = code_dir
self._hyperparameters[SCRIPT_PARAM_NAME] = script
self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
super(Framework, self).fit(inputs, wait, logs, self._current_job_name)

def _stage_user_code_in_s3(self):
    """Tar and upload the user training code to S3.

    Uses ``self.code_location`` (an ``s3://`` URI) as the upload destination
    when provided; otherwise falls back to the session's default bucket. The
    current job name is appended so each job's code is staged under its own
    ``…/source`` prefix.

    Returns:
        sagemaker.fw_utils.UploadedCode: The S3 prefix of the uploaded code
            and the entry-point script name.
    """
    if self.code_location is None:
        code_bucket = self.sagemaker_session.default_bucket()
        code_s3_prefix = '{}/source'.format(self._current_job_name)
    else:
        code_bucket, key_prefix = parse_s3_url(self.code_location)
        code_s3_prefix = '{}/{}/source'.format(key_prefix, self._current_job_name)

    return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
                              bucket=code_bucket,
                              s3_key_prefix=code_s3_prefix,
                              script=self.entry_point,
                              directory=self.source_dir)

def hyperparameters(self):
"""Return the hyperparameters as a dictionary to use for training.
Expand Down
31 changes: 20 additions & 11 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,23 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
.format(account, region, framework, tag)


def validate_source_dir(script, directory):
    """Validate that the source directory contains the user script.

    A falsy ``directory`` (e.g. ``None`` or ``''``) is accepted without any
    check, matching the caller's "no source dir" case.

    Args:
        script (str): Script filename.
        directory (str): Directory containing the source file.

    Returns:
        bool: True when validation passes.

    Raises:
        ValueError: If ``directory`` does not contain a file named ``script``.
    """
    if directory:
        if not os.path.isfile(os.path.join(directory, script)):
            raise ValueError('No file named "{}" was found in directory "{}".'.format(script, directory))

    return True


def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
"""Pack and upload source files to S3 only if directory is empty or local.

Expand All @@ -83,21 +100,13 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):

Returns:
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name.

Raises:
ValueError: If ``directory`` does not exist, is not a directory, or does not contain ``script``.
"""
if directory:
if directory.lower().startswith("s3://"):
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
if not os.path.exists(directory):
raise ValueError('"{}" does not exist.'.format(directory))
if not os.path.isdir(directory):
raise ValueError('"{}" is not a directory.'.format(directory))
if script not in os.listdir(directory):
raise ValueError('No file named "{}" was found in directory "{}".'.format(script, directory))
script_name = script
source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
else:
script_name = script
source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
else:
# If no directory is specified, the script parameter needs to be a valid relative path.
os.path.exists(script)
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/local/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,7 @@
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from .local_session import (file_input, LocalSession, LocalSagemakerRuntimeClient,
                            LocalSagemakerClient)

# ``__all__`` entries must be attribute *names* (strings); listing the objects
# themselves breaks ``from sagemaker.local import *`` with a TypeError.
__all__ = ['file_input', 'LocalSession', 'LocalSagemakerClient', 'LocalSagemakerRuntimeClient']
45 changes: 35 additions & 10 deletions src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@

import yaml

import sagemaker
from sagemaker.utils import get_config_value

CONTAINER_PREFIX = "algo"
DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml'

Expand Down Expand Up @@ -68,11 +71,6 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
self.hosts = ['{}-{}-{}'.format(CONTAINER_PREFIX, i, suffix) for i in range(1, self.instance_count + 1)]
self.container_root = None
self.container = None
# set the local config. This is optional and will use reasonable defaults
# if not present.
self.local_config = None
if self.sagemaker_session.config and 'local' in self.sagemaker_session.config:
self.local_config = self.sagemaker_session.config['local']

def train(self, input_data_config, hyperparameters):
"""Run a training job locally using docker-compose.
Expand All @@ -85,6 +83,10 @@ def train(self, input_data_config, hyperparameters):
"""
self.container_root = self._create_tmp_folder()
os.mkdir(os.path.join(self.container_root, 'output'))
# A shared directory for all the containers. It is only mounted if the
# training script is local.
shared_dir = os.path.join(self.container_root, 'shared')
os.mkdir(shared_dir)

data_dir = self._create_tmp_folder()
volumes = []
Expand Down Expand Up @@ -116,6 +118,14 @@ def train(self, input_data_config, hyperparameters):
else:
raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))

# If the training script directory is a local directory, mount it to the container.
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
parsed_uri = urlparse(training_dir)
if parsed_uri.scheme == 'file':
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
# Also mount a directory that all the containers can access.
volumes.append(_Volume(shared_dir, '/opt/ml/shared'))

# Create the configuration files for each container that we will create
# Each container will map the additional local volumes (if any).
for host in self.hosts:
Expand All @@ -135,6 +145,7 @@ def train(self, input_data_config, hyperparameters):
# lots of data downloaded from S3. This doesn't delete any local
# data that was just mounted to the container.
_delete_tree(data_dir)
_delete_tree(shared_dir)
# Also free the container config files.
for host in self.hosts:
container_config_path = os.path.join(self.container_root, host)
Expand Down Expand Up @@ -171,7 +182,16 @@ def serve(self, primary_container):

_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)

self._generate_compose_file('serve', additional_env_vars=env_vars)
# If the user script was passed as a file:// mount it to the container.
script_dir = primary_container['Environment'][sagemaker.estimator.DIR_PARAM_NAME.upper()]
parsed_uri = urlparse(script_dir)
volumes = []
if parsed_uri.scheme == 'file':
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))

self._generate_compose_file('serve',
additional_env_vars=env_vars,
additional_volumes=volumes)
compose_command = self._compose()
self.container = _HostingContainer(compose_command)
self.container.up()
Expand Down Expand Up @@ -366,8 +386,9 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes
}
}

serving_port = 8080 if self.local_config is None else self.local_config.get('serving_port', 8080)
if command == 'serve':
serving_port = get_config_value('local.serving_port',
self.sagemaker_session.config) or 8080
host_config.update({
'ports': [
'%s:8080' % serving_port
Expand All @@ -377,9 +398,9 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes
return host_config

def _create_tmp_folder(self):
root_dir = None
if self.local_config and 'container_root' in self.local_config:
root_dir = os.path.abspath(self.local_config['container_root'])
root_dir = get_config_value('local.container_root', self.sagemaker_session.config)
if root_dir:
root_dir = os.path.abspath(root_dir)

dir = tempfile.mkdtemp(dir=root_dir)

Expand Down Expand Up @@ -565,6 +586,10 @@ def _ecr_login_if_needed(boto_session, image):
if _check_output('docker images -q %s' % image).strip():
return

if not boto_session:
raise RuntimeError('A boto session is required to login to ECR.'
'Please pull the image: %s manually.' % image)

ecr = boto_session.client('ecr')
auth = ecr.get_authorization_token(registryIds=[image.split('.')[0]])
authorization_data = auth['authorizationData'][0]
Expand Down
21 changes: 16 additions & 5 deletions src/sagemaker/local/local_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import platform
import time

import boto3
import urllib3
from botocore.exceptions import ClientError

from sagemaker.local.image import _SageMakerContainer
from sagemaker.session import Session
from sagemaker.utils import get_config_value

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)
Expand Down Expand Up @@ -115,9 +117,7 @@ def create_endpoint(self, EndpointName, EndpointConfigName):

i = 0
http = urllib3.PoolManager()
serving_port = 8080
if self.sagemaker_session.config and 'local' in self.sagemaker_session.config:
serving_port = self.sagemaker_session.config['local'].get('serving_port', 8080)
serving_port = get_config_value('local.serving_port', self.sagemaker_session.config) or 8080
endpoint_url = "http://localhost:%s/ping" % serving_port
while True:
i += 1
Expand Down Expand Up @@ -153,8 +153,8 @@ def __init__(self, config=None):
"""
self.http = urllib3.PoolManager()
self.serving_port = 8080
if config and 'local' in config:
self.serving_port = config['local'].get('serving_port', 8080)
self.config = config
self.serving_port = get_config_value('local.serving_port', config) or 8080

def invoke_endpoint(self, Body, EndpointName, ContentType, Accept):
url = "http://localhost:%s/invocations" % self.serving_port
Expand All @@ -171,8 +171,19 @@ def __init__(self, boto_session=None):

if platform.system() == 'Windows':
logger.warning("Windows Support for Local Mode is Experimental")

def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
    """Initialize this Local SageMaker Session.

    Args:
        boto_session (boto3.Session): Boto session to use; a default
            ``boto3.Session()`` is created when ``None``.
        sagemaker_client: Unused — a ``LocalSagemakerClient`` is substituted.
        sagemaker_runtime_client: Unused — a ``LocalSagemakerRuntimeClient``
            is substituted.

    Raises:
        ValueError: If the boto session has no region configured.
    """

    self.boto_session = boto_session or boto3.Session()
    self._region_name = self.boto_session.region_name

    # Local Mode still requires a region for the SageMaker interactions that
    # remain region-specific.
    if self._region_name is None:
        raise ValueError('Must setup local AWS configuration with a region supported by SageMaker.')

    # Replace the real service clients with local implementations.
    # NOTE(review): assumes self.config is set before _initialize runs — confirm.
    self.sagemaker_client = LocalSagemakerClient(self)
    self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
    self.local_mode = True

def logs_for_job(self, job_name, wait=False, poll=5):
# override logs_for_job() as it doesn't need to perform any action
Expand Down
Loading