Skip to content

Commit fa8c39b

Browse files
author
Ignacio Quintero
committed
Allow Local Mode to work without a boto session.
This change works when ~/.sagemaker/config.yaml has local: no_internet: True It depends on the container image supporting a local training script instead of an s3 location.
1 parent c1f1ab9 commit fa8c39b

File tree

13 files changed

+156
-63
lines changed

13 files changed

+156
-63
lines changed

src/sagemaker/estimator.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,12 @@
1414

1515
import json
1616
import logging
17+
import os
1718
from abc import ABCMeta
1819
from abc import abstractmethod
1920
from six import with_metaclass, string_types
2021

21-
from sagemaker.fw_utils import tar_and_upload_dir
22-
from sagemaker.fw_utils import parse_s3_url
23-
from sagemaker.fw_utils import UploadedCode
22+
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
2423
from sagemaker.local.local_session import LocalSession, file_input
2524

2625
from sagemaker.model import Model
@@ -30,7 +29,7 @@
3029
from sagemaker.predictor import RealTimePredictor
3130
from sagemaker.session import Session
3231
from sagemaker.session import s3_input
33-
from sagemaker.utils import base_name_from_image, name_from_base
32+
from sagemaker.utils import base_name_from_image, name_from_base, get_config_value
3433

3534

3635
class EstimatorBase(with_metaclass(ABCMeta, object)):
@@ -158,9 +157,14 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
158157
base_name = self.base_job_name or base_name_from_image(self.train_image())
159158
self._current_job_name = name_from_base(base_name)
160159

161-
# if output_path was specified we use it otherwise initialize here
160+
# if output_path was specified we use it otherwise initialize here.
161+
# For Local Mode with no_internet=True we don't need an explicit output_path
162162
if self.output_path is None:
163-
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
163+
if self.local_mode and get_config_value('local.no_internet',
164+
self.sagemaker_session.config):
165+
self.output_path = ''
166+
else:
167+
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
164168

165169
self.latest_training_job = _TrainingJob.start_new(self, inputs)
166170
if wait:
@@ -604,27 +608,53 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
604608
base_name = self.base_job_name or base_name_from_image(self.train_image())
605609
self._current_job_name = name_from_base(base_name)
606610

611+
# if there is no source dir, use the directory containing the entry point.
612+
if self.source_dir is None:
613+
self.source_dir = os.path.dirname(self.entry_point)
614+
self.entry_point = os.path.basename(self.entry_point)
615+
616+
# validate source dir will raise a ValueError if there is something wrong with the
617+
# source directory. We are intentionally not handling it because this is a critical error.
618+
if self.source_dir and not self.source_dir.lower().startswith('s3://'):
619+
validate_source_dir(self.entry_point, self.source_dir)
620+
621+
# if we are in local mode with no_internet=True. We want the container to just
622+
# mount the source dir instead of uploading to S3.
623+
if self.local_mode and get_config_value('local.no_internet', self.sagemaker_session.config):
624+
code_dir = 'file://' + self.source_dir
625+
script = self.entry_point
626+
else:
627+
self.uploaded_code = self._stage_user_code_in_s3()
628+
code_dir = self.uploaded_code.s3_prefix
629+
script = self.uploaded_code.script_name
630+
631+
# Modify hyperparameters in-place to point to the right code directory and script URIs
632+
self._hyperparameters[DIR_PARAM_NAME] = code_dir
633+
self._hyperparameters[SCRIPT_PARAM_NAME] = script
634+
self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
635+
self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
636+
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
637+
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.region_name
638+
super(Framework, self).fit(inputs, wait, logs, self._current_job_name)
639+
640+
def _stage_user_code_in_s3(self):
641+
""" Upload the user training script to s3 and return the location.
642+
643+
Returns: s3 uri
644+
645+
"""
607646
if self.code_location is None:
608647
code_bucket = self.sagemaker_session.default_bucket()
609648
code_s3_prefix = '{}/source'.format(self._current_job_name)
610649
else:
611650
code_bucket, key_prefix = parse_s3_url(self.code_location)
612651
code_s3_prefix = '{}/{}/source'.format(key_prefix, self._current_job_name)
613652

614-
self.uploaded_code = tar_and_upload_dir(session=self.sagemaker_session.boto_session,
615-
bucket=code_bucket,
616-
s3_key_prefix=code_s3_prefix,
617-
script=self.entry_point,
618-
directory=self.source_dir)
619-
620-
# Modify hyperparameters in-place to add the URLs to the uploaded code.
621-
self._hyperparameters[DIR_PARAM_NAME] = self.uploaded_code.s3_prefix
622-
self._hyperparameters[SCRIPT_PARAM_NAME] = self.uploaded_code.script_name
623-
self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
624-
self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
625-
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
626-
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_session.region_name
627-
super(Framework, self).fit(inputs, wait, logs, self._current_job_name)
653+
return tar_and_upload_dir(session=self.sagemaker_session.boto_session,
654+
bucket=code_bucket,
655+
s3_key_prefix=code_s3_prefix,
656+
script=self.entry_point,
657+
directory=self.source_dir)
628658

629659
def hyperparameters(self):
630660
"""Return the hyperparameters as a dictionary to use for training.

src/sagemaker/fw_utils.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,27 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
6868
.format(account, region, framework, tag)
6969

7070

71+
def validate_source_dir(script, directory):
72+
"""Validate that the source directory exists and it contains the user script
73+
74+
Args:
75+
script (str): Script filename.
76+
directory (str): Directory containing the source file.
77+
78+
Raises:
79+
ValueError: If ``directory`` does not exist, is not a directory, or does not contain ``script``.
80+
"""
81+
if directory:
82+
if not os.path.exists(directory):
83+
raise ValueError('"{}" does not exist.'.format(directory))
84+
if not os.path.isdir(directory):
85+
raise ValueError('"{}" is not a directory.'.format(directory))
86+
if script not in os.listdir(directory):
87+
raise ValueError('No file named "{}" was found in directory "{}".'.format(script, directory))
88+
89+
return True
90+
91+
7192
def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
7293
"""Pack and upload source files to S3 only if directory is empty or local.
7394
@@ -83,21 +104,13 @@ def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
83104
84105
Returns:
85106
sagemaker.fw_utils.UserCode: An object with the S3 bucket and key (S3 prefix) and script name.
86-
87-
Raises:
88-
ValueError: If ``directory`` does not exist, is not a directory, or does not contain ``script``.
89107
"""
90108
if directory:
91109
if directory.lower().startswith("s3://"):
92110
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
93-
if not os.path.exists(directory):
94-
raise ValueError('"{}" does not exist.'.format(directory))
95-
if not os.path.isdir(directory):
96-
raise ValueError('"{}" is not a directory.'.format(directory))
97-
if script not in os.listdir(directory):
98-
raise ValueError('No file named "{}" was found in directory "{}".'.format(script, directory))
99-
script_name = script
100-
source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
111+
else:
112+
script_name = script
113+
source_files = [os.path.join(directory, name) for name in os.listdir(directory)]
101114
else:
102115
# If no directory is specified, the script parameter needs to be a valid relative path.
103116
os.path.exists(script)

src/sagemaker/local/image.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929

3030
import yaml
3131

32+
import sagemaker
33+
from sagemaker.utils import get_config_value
34+
3235
CONTAINER_PREFIX = "algo"
3336
DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml'
3437

@@ -70,9 +73,7 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
7073
self.container = None
7174
# set the local config. This is optional and will use reasonable defaults
7275
# if not present.
73-
self.local_config = None
74-
if self.sagemaker_session.config and 'local' in self.sagemaker_session.config:
75-
self.local_config = self.sagemaker_session.config['local']
76+
self.local_config = get_config_value('local', self.sagemaker_session.config)
7677

7778
def train(self, input_data_config, hyperparameters):
7879
"""Run a training job locally using docker-compose.
@@ -116,6 +117,13 @@ def train(self, input_data_config, hyperparameters):
116117
else:
117118
raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))
118119

120+
# If the training script directory is a local directory, mount it to the container.
121+
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
122+
parsed_uri = urlparse(training_dir)
123+
if parsed_uri.scheme == 'file':
124+
print('appended Volume')
125+
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
126+
119127
# Create the configuration files for each container that we will create
120128
# Each container will map the additional local volumes (if any).
121129
for host in self.hosts:
@@ -366,8 +374,9 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes
366374
}
367375
}
368376

369-
serving_port = 8080 if self.local_config is None else self.local_config.get('serving_port', 8080)
370377
if command == 'serve':
378+
serving_port = get_config_value('local.serving_port',
379+
self.sagemaker_session.config) or 8080
371380
host_config.update({
372381
'ports': [
373382
'%s:8080' % serving_port
@@ -377,9 +386,9 @@ def _create_docker_host(self, host, environment, optml_subdirs, command, volumes
377386
return host_config
378387

379388
def _create_tmp_folder(self):
380-
root_dir = None
381-
if self.local_config and 'container_root' in self.local_config:
382-
root_dir = os.path.abspath(self.local_config['container_root'])
389+
root_dir = get_config_value('local.container_root', self.sagemaker_session.config)
390+
if root_dir:
391+
root_dir = os.path.abspath(root_dir)
383392

384393
dir = tempfile.mkdtemp(dir=root_dir)
385394

src/sagemaker/local/local_session.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
import platform
1616
import time
1717

18+
import boto3
1819
import urllib3
1920
from botocore.exceptions import ClientError
2021

2122
from sagemaker.local.image import _SageMakerContainer
2223
from sagemaker.session import Session
24+
from sagemaker.utils import get_config_value
2325

2426
logger = logging.getLogger(__name__)
2527
logger.setLevel(logging.WARNING)
@@ -115,9 +117,7 @@ def create_endpoint(self, EndpointName, EndpointConfigName):
115117

116118
i = 0
117119
http = urllib3.PoolManager()
118-
serving_port = 8080
119-
if self.sagemaker_session.config and 'local' in self.sagemaker_session.config:
120-
serving_port = self.sagemaker_session.config['local'].get('serving_port', 8080)
120+
serving_port = get_config_value('local.serving_port', self.sagemaker_session.config) or 8080
121121
endpoint_url = "http://localhost:%s/ping" % serving_port
122122
while True:
123123
i += 1
@@ -153,8 +153,8 @@ def __init__(self, config=None):
153153
"""
154154
self.http = urllib3.PoolManager()
155155
self.serving_port = 8080
156-
if config and 'local' in config:
157-
self.serving_port = config['local'].get('serving_port', 8080)
156+
self.config = config
157+
self.serving_port = get_config_value('local.serving_port', config) or 8080
158158

159159
def invoke_endpoint(self, Body, EndpointName, ContentType, Accept):
160160
url = "http://localhost:%s/invocations" % self.serving_port
@@ -171,6 +171,23 @@ def __init__(self, boto_session=None):
171171

172172
if platform.system() == 'Windows':
173173
logger.warning("Windows Support for Local Mode is Experimental")
174+
175+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
176+
"""Initialize a boto session for this Local SageMaker Session."""
177+
if get_config_value('local.no_internet', self.config):
178+
# if no_internet is set to True in the config file then we won't create a boto_session
179+
# this will make any component that defaults to using S3 utilize a local file instead.
180+
self.boto_session = None
181+
self.region_name = get_config_value('local.region_name', self.config)
182+
if self.region_name is None:
183+
raise ValueError('Must setup region_name in the sagemaker config file. See <Link to Readme Here>')
184+
else:
185+
self.boto_session = boto_session or boto3.Session()
186+
self.region_name = self.boto_session.region_name
187+
188+
if self.region_name is None:
189+
raise ValueError('Must setup local AWS configuration with a region supported by SageMaker.')
190+
174191
self.sagemaker_client = LocalSagemakerClient(self)
175192
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
176193

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def train_image(self):
6565
Returns:
6666
str: The URI of the Docker image.
6767
"""
68-
return create_image_uri(self.sagemaker_session.boto_session.region_name, self.__framework_name__,
68+
return create_image_uri(self.sagemaker_session.region_name, self.__framework_name__,
6969
self.train_instance_type, framework_version=self.framework_version,
7070
py_version=self.py_version)
7171

src/sagemaker/session.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,10 +70,25 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
7070
If not provided, one will be created using this instance's ``boto_session``.
7171
"""
7272
self._default_bucket = None
73+
74+
sagemaker_config_file = os.path.join(os.path.expanduser('~'), '.sagemaker', 'config.yaml')
75+
if os.path.exists(sagemaker_config_file):
76+
self.config = yaml.load(open(sagemaker_config_file, 'r'))
77+
else:
78+
self.config = None
79+
80+
self._initialize(boto_session, sagemaker_client, sagemaker_runtime_client)
81+
82+
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
83+
"""Initialize this SageMaker Session.
84+
85+
Creates or uses a boto_session, sagemaker_client and sagemaker_runtime_client.
86+
Sets the region_name.
87+
"""
7388
self.boto_session = boto_session or boto3.Session()
7489

75-
region = self.boto_session.region_name
76-
if region is None:
90+
self.region = self.boto_session.region_name
91+
if self.region is None:
7792
raise ValueError('Must setup local AWS configuration with a region supported by SageMaker.')
7893

7994
self.sagemaker_client = sagemaker_client or self.boto_session.client('sagemaker')
@@ -82,12 +97,6 @@ def __init__(self, boto_session=None, sagemaker_client=None, sagemaker_runtime_c
8297
self.sagemaker_runtime_client = sagemaker_runtime_client or self.boto_session.client('runtime.sagemaker')
8398
prepend_user_agent(self.sagemaker_runtime_client)
8499

85-
sagemaker_config_file = os.path.join(os.path.expanduser('~'), '.sagemaker', 'config.yaml')
86-
if os.path.exists(sagemaker_config_file):
87-
self.config = yaml.load(open(sagemaker_config_file, 'r'))
88-
else:
89-
self.config = None
90-
91100
@property
92101
def boto_region_name(self):
93102
return self.boto_session.region_name

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ def train_image(self):
278278
Returns:
279279
str: The URI of the Docker image.
280280
"""
281-
return create_image_uri(self.sagemaker_session.boto_session.region_name, self.__framework_name__,
281+
return create_image_uri(self.sagemaker_session.region_name, self.__framework_name__,
282282
self.train_instance_type, self.framework_version, py_version=self.py_version)
283283

284284
def create_model(self, model_server_workers=None):

src/sagemaker/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,3 +76,17 @@ def wrapper(*args, **kwargs):
7676
return func(*args, **kwargs)
7777

7878
return wrapper
79+
80+
81+
def get_config_value(key_path, config):
82+
if config is None:
83+
return None
84+
85+
current_section = config
86+
for key in key_path.split('.'):
87+
if key in current_section:
88+
current_section = current_section[key]
89+
else:
90+
return None
91+
92+
return current_section

tests/unit/test_estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def create_predictor(self, endpoint_name):
8989
@pytest.fixture()
9090
def sagemaker_session():
9191
boto_mock = Mock(name='boto_session', region_name=REGION)
92-
ims = Mock(name='sagemaker_session', boto_session=boto_mock)
92+
ims = Mock(name='sagemaker_session', boto_session=boto_mock, region_name=REGION)
9393
ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
9494
ims.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
9595
return_value=DESCRIBE_TRAINING_JOB_RESULT)

tests/unit/test_fw_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from mock import Mock
1515
import os
1616
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
17-
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode
17+
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
1818
import pytest
1919

2020

@@ -110,13 +110,11 @@ def test_tar_and_upload_dir_is_not_directory(sagemaker_session):
110110
assert 'is not a directory' in str(error)
111111

112112

113-
def test_tar_and_upload_dir_file_not_in_dir(sagemaker_session):
114-
bucket = 'mybucker'
115-
s3_key_prefix = 'something/source'
113+
def test_validate_source_dir_file_not_in_dir():
116114
script = ' !@#$%^&*() .myscript. !@#$%^&*() '
117115
directory = '.'
118116
with pytest.raises(ValueError) as error:
119-
tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory)
117+
validate_source_dir(script, directory)
120118
assert 'No file named' in str(error)
121119

122120

0 commit comments

Comments
 (0)