Skip to content

Commit 41a9ddb

Browse files
author
Ignacio Quintero
committed
Address Owen's feedback.
In general boto_session is not None anymore. no_internet renamed to local_code. Added more unit tests to bring coverage up.
1 parent 2d42219 commit 41a9ddb

16 files changed

+147
-68
lines changed

src/sagemaker/estimator.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from six import with_metaclass, string_types
2121

2222
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url, UploadedCode, validate_source_dir
23-
from sagemaker.local.local_session import LocalSession, file_input
23+
from sagemaker.local import LocalSession, file_input
2424

2525
from sagemaker.model import Model
2626
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
@@ -155,10 +155,10 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
155155
self._current_job_name = name_from_base(base_name)
156156

157157
# if output_path was specified we use it otherwise initialize here.
158-
# For Local Mode with no_internet=True we don't need an explicit output_path
158+
# For Local Mode with local_code=True we don't need an explicit output_path
159159
if self.output_path is None:
160-
no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
161-
if self.sagemaker_session.local_mode and no_internet:
160+
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
161+
if self.sagemaker_session.local_mode and local_code:
162162
self.output_path = ''
163163
else:
164164
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
@@ -610,10 +610,10 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
610610
if self.source_dir and not self.source_dir.lower().startswith('s3://'):
611611
validate_source_dir(self.entry_point, self.source_dir)
612612

613-
# if we are in local mode with no_internet=True. We want the container to just
613+
# if we are in local mode with local_code=True. We want the container to just
614614
# mount the source dir instead of uploading to S3.
615-
no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
616-
if self.sagemaker_session.local_mode and no_internet:
615+
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
616+
if self.sagemaker_session.local_mode and local_code:
617617
# if there is no source dir, use the directory containing the entry point.
618618
if self.source_dir is None:
619619
self.source_dir = os.path.dirname(self.entry_point)
@@ -632,7 +632,7 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
632632
self._hyperparameters[CLOUDWATCH_METRICS_PARAM_NAME] = self.enable_cloudwatch_metrics
633633
self._hyperparameters[CONTAINER_LOG_LEVEL_PARAM_NAME] = self.container_log_level
634634
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
635-
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.region_name
635+
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
636636
super(Framework, self).fit(inputs, wait, logs, self._current_job_name)
637637

638638
def _stage_user_code_in_s3(self):

src/sagemaker/local/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,7 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13+
from .local_session import (file_input, LocalSession, LocalSagemakerRuntimeClient,
14+
LocalSagemakerClient)
15+
16+
# Public API of the ``sagemaker.local`` package. ``__all__`` must contain the
# exported names as strings; listing the objects themselves makes
# ``from sagemaker.local import *`` fail with a TypeError.
__all__ = [
    'file_input',
    'LocalSession',
    'LocalSagemakerClient',
    'LocalSagemakerRuntimeClient',
]

src/sagemaker/local/image.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,6 @@ def __init__(self, instance_type, instance_count, image, sagemaker_session=None)
7171
self.hosts = ['{}-{}-{}'.format(CONTAINER_PREFIX, i, suffix) for i in range(1, self.instance_count + 1)]
7272
self.container_root = None
7373
self.container = None
74-
# set the local config. This is optional and will use reasonable defaults
75-
# if not present.
76-
self.local_config = get_config_value('local', self.sagemaker_session.config)
7774

7875
def train(self, input_data_config, hyperparameters):
7976
"""Run a training job locally using docker-compose.
@@ -86,7 +83,10 @@ def train(self, input_data_config, hyperparameters):
8683
"""
8784
self.container_root = self._create_tmp_folder()
8885
os.mkdir(os.path.join(self.container_root, 'output'))
89-
os.mkdir(os.path.join(self.container_root, 'shared'))
86+
# A shared directory for all the containers. It is only mounted if the training script is
87+
# Local.
88+
shared_dir = os.path.join(self.container_root, 'shared')
89+
os.mkdir(shared_dir)
9090

9191
data_dir = self._create_tmp_folder()
9292
volumes = []
@@ -123,7 +123,8 @@ def train(self, input_data_config, hyperparameters):
123123
parsed_uri = urlparse(training_dir)
124124
if parsed_uri.scheme == 'file':
125125
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
126-
volumes.append(_Volume(os.path.join(self.container_root, 'shared'), '/opt/ml/shared'))
126+
# Also mount a directory that all the containers can access.
127+
volumes.append(_Volume(shared_dir, '/opt/ml/shared'))
127128

128129
# Create the configuration files for each container that we will create
129130
# Each container will map the additional local volumes (if any).
@@ -144,6 +145,7 @@ def train(self, input_data_config, hyperparameters):
144145
# lots of data downloaded from S3. This doesn't delete any local
145146
# data that was just mounted to the container.
146147
_delete_tree(data_dir)
148+
_delete_tree(shared_dir)
147149
# Also free the container config files.
148150
for host in self.hosts:
149151
container_config_path = os.path.join(self.container_root, host)

src/sagemaker/local/local_session.py

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -173,20 +173,13 @@ def __init__(self, boto_session=None):
173173
logger.warning("Windows Support for Local Mode is Experimental")
174174

175175
def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
176-
"""Initialize a boto session for this Local SageMaker Session."""
177-
if get_config_value('local.no_internet', self.config):
178-
# if no_internet is set to True in the config file then we won't create a boto_session
179-
# this will make any component that defaults to using S3 utilize a local file instead.
180-
self.boto_session = None
181-
self.region_name = get_config_value('local.region_name', self.config)
182-
if self.region_name is None:
183-
raise ValueError('Must setup region_name in the sagemaker config file. See <Link to Readme Here>')
184-
else:
185-
self.boto_session = boto_session or boto3.Session()
186-
self.region_name = self.boto_session.region_name
176+
"""Initialize this Local SageMaker Session."""
177+
178+
self.boto_session = boto_session or boto3.Session()
179+
self._region_name = self.boto_session.region_name
187180

188-
if self.region_name is None:
189-
raise ValueError('Must setup local AWS configuration with a region supported by SageMaker.')
181+
if self._region_name is None:
182+
raise ValueError('Must setup local AWS configuration with a region supported by SageMaker.')
190183

191184
self.sagemaker_client = LocalSagemakerClient(self)
192185
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)

src/sagemaker/model.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -160,21 +160,21 @@ def prepare_container_def(self, instance_type):
160160
Returns:
161161
dict[str, str]: A container definition object usable with the CreateModel API.
162162
"""
163-
no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
164-
if self.sagemaker_session.local_mode and no_internet:
165-
self.uploaded_code = None
166-
else:
167-
self._upload_code(self.key_prefix or self.name or name_from_image(self.image))
163+
self._upload_code(self.key_prefix or self.name or name_from_image(self.image))
168164
deploy_env = dict(self.env)
169165
deploy_env.update(self._framework_env_vars())
170166
return sagemaker.container_def(self.image, self.model_data, deploy_env)
171167

172168
def _upload_code(self, key_prefix):
173-
self.uploaded_code = tar_and_upload_dir(session=self.sagemaker_session.boto_session,
174-
bucket=self.bucket or self.sagemaker_session.default_bucket(),
175-
s3_key_prefix=key_prefix,
176-
script=self.entry_point,
177-
directory=self.source_dir)
169+
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
170+
if self.sagemaker_session.local_mode and local_code:
171+
self.uploaded_code = None
172+
else:
173+
self.uploaded_code = tar_and_upload_dir(session=self.sagemaker_session.boto_session,
174+
bucket=self.bucket or self.sagemaker_session.default_bucket(),
175+
s3_key_prefix=key_prefix,
176+
script=self.entry_point,
177+
directory=self.source_dir)
178178

179179
def _framework_env_vars(self):
180180
if self.uploaded_code:
@@ -189,5 +189,5 @@ def _framework_env_vars(self):
189189
DIR_PARAM_NAME.upper(): dir_name,
190190
CLOUDWATCH_METRICS_PARAM_NAME.upper(): str(self.enable_cloudwatch_metrics).lower(),
191191
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level),
192-
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.region_name
192+
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_region_name
193193
}

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def train_image(self):
6565
Returns:
6666
str: The URI of the Docker image.
6767
"""
68-
return create_image_uri(self.sagemaker_session.region_name, self.__framework_name__,
68+
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__,
6969
self.train_instance_type, framework_version=self.framework_version,
7070
py_version=self.py_version)
7171

src/sagemaker/session.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
8787
"""
8888
self.boto_session = boto_session or boto3.Session()
8989

90-
self.region_name = self.boto_session.region_name
91-
if self.region_name is None:
90+
self._region_name = self.boto_session.region_name
91+
if self._region_name is None:
9292
raise ValueError('Must setup local AWS configuration with a region supported by SageMaker.')
9393

9494
self.sagemaker_client = sagemaker_client or self.boto_session.client('sagemaker')
@@ -101,7 +101,7 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
101101

102102
@property
103103
def boto_region_name(self):
104-
return self.boto_session.region_name
104+
return self._region_name
105105

106106
def upload_data(self, path, bucket=None, key_prefix='data'):
107107
"""Upload local file or directory to S3.

src/sagemaker/tensorflow/estimator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,7 @@ def train_image(self):
279279
Returns:
280280
str: The URI of the Docker image.
281281
"""
282-
return create_image_uri(self.sagemaker_session.region_name, self.__framework_name__,
282+
return create_image_uri(self.sagemaker_session.boto_region_name, self.__framework_name__,
283283
self.train_instance_type, self.framework_version, py_version=self.py_version)
284284

285285
def create_model(self, model_server_workers=None):
@@ -306,8 +306,8 @@ def hyperparameters(self):
306306
hyperparameters = super(TensorFlow, self).hyperparameters()
307307

308308
if not self.checkpoint_path:
309-
no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
310-
if self.sagemaker_session.local_mode and no_internet:
309+
local_code = get_config_value('local.local_code', self.sagemaker_session.config)
310+
if self.sagemaker_session.local_mode and local_code:
311311
self.checkpoint_path = '/opt/ml/shared/checkpoints'
312312
else:
313313
self.checkpoint_path = os.path.join(self.output_path,

src/sagemaker/tensorflow/model.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sagemaker.predictor import RealTimePredictor
1717
from sagemaker.tensorflow.defaults import TF_VERSION
1818
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer
19-
from sagemaker.utils import name_from_image, get_config_value
19+
from sagemaker.utils import name_from_image
2020

2121

2222
class TensorFlowPredictor(RealTimePredictor):
@@ -83,18 +83,13 @@ def prepare_container_def(self, instance_type):
8383
"""
8484
deploy_image = self.image
8585
if not deploy_image:
86-
region_name = self.sagemaker_session.region_name
86+
region_name = self.sagemaker_session.boto_region_name
8787
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
8888
self.framework_version, self.py_version)
8989

9090
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
9191

92-
no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
93-
print('no_internet: %s local_mode: %s' % (no_internet, self.sagemaker_session.local_mode))
94-
if self.sagemaker_session.local_mode and no_internet:
95-
self.uploaded_code = None
96-
else:
97-
self._upload_code(deploy_key_prefix)
92+
self._upload_code(deploy_key_prefix)
9893
deploy_env = dict(self.env)
9994
deploy_env.update(self._framework_env_vars())
10095

tests/unit/test_estimator.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from mock import Mock, patch
1818

1919
from sagemaker.estimator import Estimator, Framework, _TrainingJob
20+
from sagemaker.local import LocalSession
2021
from sagemaker.session import s3_input
2122
from sagemaker.model import FrameworkModel
2223
from sagemaker.predictor import RealTimePredictor
@@ -89,12 +90,12 @@ def create_predictor(self, endpoint_name):
8990
@pytest.fixture()
9091
def sagemaker_session():
9192
boto_mock = Mock(name='boto_session', region_name=REGION)
92-
ims = Mock(name='sagemaker_session', boto_session=boto_mock, region_name=REGION,
93-
config=None, local_mode=False)
94-
ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
95-
ims.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
93+
sms = Mock(name='sagemaker_session', boto_session=boto_mock,
94+
boto_region_name=REGION, config=None, local_mode=False)
95+
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
96+
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
9697
return_value=DESCRIBE_TRAINING_JOB_RESULT)
97-
return ims
98+
return sms
9899

99100

100101
def test_sagemaker_s3_uri_invalid(sagemaker_session):
@@ -142,6 +143,24 @@ def test_invalid_custom_code_bucket(sagemaker_session):
142143
}
143144

144145

146+
def test_local_code_location():
    """Fit a local-mode estimator with ``local.local_code`` enabled and check the code location.

    The session is a mock in local mode with no boto session. After ``fit``, the
    estimator's ``source_dir`` must equal ``DATA_DIR`` and its ``entry_point`` must be
    the bare script file name.
    """
    local_settings = {'local_code': True, 'region': 'us-west-2'}
    session = Mock(name='sagemaker_session',
                   boto_session=None,
                   boto_region_name=REGION,
                   config={'local': local_settings},
                   local_mode=True)

    estimator = DummyFramework(
        entry_point=SCRIPT_PATH,
        role=ROLE,
        sagemaker_session=session,
        train_instance_count=1,
        train_instance_type='local',
        base_job_name=IMAGE_NAME,
        hyperparameters={123: [456], 'learning_rate': 0.1},
    )
    estimator.fit('file:///data/file')

    assert estimator.source_dir == DATA_DIR
    assert estimator.entry_point == 'dummy_script.py'
162+
163+
145164
@patch('time.strftime', return_value=TIMESTAMP)
146165
def test_start_new_convert_hyperparameters_to_str(strftime, sagemaker_session):
147166
uri = 'bucket/mydata'

tests/unit/test_image.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,9 @@
4949
'b': 'bee',
5050
'sagemaker_submit_directory': json.dumps('s3://my_bucket/code')}
5151

52+
LOCAL_CODE_HYPERPARAMETERS = {'a': 1,
53+
'b': 2,
54+
'sagemaker_submit_directory': json.dumps('file:///tmp/code')}
5255

5356
@pytest.fixture()
5457
def sagemaker_session():
@@ -216,6 +219,37 @@ def test_train(_download_folder, _cleanup, _execute_and_stream_output, LocalSess
216219
assert config['services'][h]['command'] == 'train'
217220

218221

222+
@patch('sagemaker.local.local_session.LocalSession')
@patch('sagemaker.local.image._execute_and_stream_output')
@patch('sagemaker.local.image._SageMakerContainer._cleanup')
@patch('sagemaker.local.image._SageMakerContainer._download_folder')
def test_train_local_code(_download_folder, _cleanup, _execute_and_stream_output,
                          _local_session, tmpdir, sagemaker_session):
    """Train with a ``file://`` submit directory and verify the generated compose file.

    ``LocalSession``, ``_execute_and_stream_output``, ``_cleanup`` and
    ``_download_folder`` are replaced with mocks by the decorators. The test then
    checks that every host in the generated ``docker-compose.yaml`` uses the
    expected image and command and mounts both the local code directory and the
    shared directory.
    """
    # _create_tmp_folder is called once for the container root and once for the
    # data directory, so hand back the two temp directories in that order.
    directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
               side_effect=directories):
        instance_count = 2
        image = 'my-image'
        sagemaker_container = _SageMakerContainer('local', instance_count, image,
                                                  sagemaker_session=sagemaker_session)

        sagemaker_container.train(INPUT_DATA_CONFIG, LOCAL_CODE_HYPERPARAMETERS)

        docker_compose_file = os.path.join(sagemaker_container.container_root,
                                           'docker-compose.yaml')
        shared_folder_path = os.path.join(sagemaker_container.container_root, 'shared')

        with open(docker_compose_file, 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is deprecated and is
            # rejected by newer PyYAML releases; the compose file is plain data.
            config = yaml.safe_load(f)

        assert len(config['services']) == instance_count
        for h in sagemaker_container.hosts:
            assert config['services'][h]['image'] == image
            assert config['services'][h]['command'] == 'train'
            volumes = config['services'][h]['volumes']
            assert '%s:/opt/ml/code' % '/tmp/code' in volumes
            assert '%s:/opt/ml/shared' % shared_folder_path in volumes
251+
252+
219253
@patch('sagemaker.local.image._HostingContainer.up')
220254
@patch('shutil.copy')
221255
@patch('shutil.copytree')
@@ -244,6 +278,38 @@ def test_serve(up, copy, copytree, tmpdir, sagemaker_session):
244278
assert config['services'][h]['command'] == 'serve'
245279

246280

281+
@patch('sagemaker.local.image._HostingContainer.up')
@patch('shutil.copy')
@patch('shutil.copytree')
def test_serve_local_code(copytree, copy, up, tmpdir, sagemaker_session):
    """Serve a model whose submit directory is a ``file://`` URI and verify the compose file.

    ``_HostingContainer.up``, ``shutil.copy`` and ``shutil.copytree`` are mocked.
    The test checks that every host in the generated ``docker-compose.yaml`` uses
    the expected image and the ``serve`` command, and mounts the local code
    directory at ``/opt/ml/code``.
    """
    # NOTE: @patch decorators inject their mocks bottom-up, so the first mock
    # argument belongs to the decorator closest to the function (shutil.copytree).
    with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
               return_value=str(tmpdir.mkdir('container-root'))):

        image = 'my-image'
        sagemaker_container = _SageMakerContainer('local', 1, image,
                                                  sagemaker_session=sagemaker_session)
        primary_container = {
            'ModelDataUrl': '/some/model/path',
            'Environment': {
                'env1': 1,
                'env2': 'b',
                'SAGEMAKER_SUBMIT_DIRECTORY': 'file:///tmp/code',
            },
        }

        sagemaker_container.serve(primary_container)
        docker_compose_file = os.path.join(sagemaker_container.container_root,
                                           'docker-compose.yaml')

        with open(docker_compose_file, 'r') as f:
            # safe_load: yaml.load() without an explicit Loader is deprecated and is
            # rejected by newer PyYAML releases; the compose file is plain data.
            config = yaml.safe_load(f)

        for h in sagemaker_container.hosts:
            assert config['services'][h]['image'] == image
            assert config['services'][h]['command'] == 'serve'

            volumes = config['services'][h]['volumes']
            assert '%s:/opt/ml/code' % '/tmp/code' in volumes
311+
312+
247313
@patch('os.makedirs')
248314
def test_download_folder(makedirs):
249315
boto_mock = Mock(name='boto_session')

tests/unit/test_lda.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@
3737
@pytest.fixture()
3838
def sagemaker_session():
3939
boto_mock = Mock(name='boto_session', region_name=REGION)
40-
sms = Mock(name='sagemaker_session', boto_session=boto_mock,
41-
region_name=REGION, config=None, local_mode=False)
40+
sms = Mock(name='sagemaker_session', boto_session=boto_mock, config=None, local_mode=False)
4241
sms.boto_region_name = REGION
4342
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4443
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',

0 commit comments

Comments
 (0)