Skip to content

Commit 6184b22

Browse files
authored
Add Local Mode support (#115)
* Add Local Mode support. When passing "local" as the instance type for any estimator, training and deployment happens locally. Similarly, using "local_gpu" will use nvidia-docker-compose and work for GPU training.
1 parent 4f92fbd commit 6184b22

File tree

13 files changed

+1116
-14
lines changed

13 files changed

+1116
-14
lines changed

CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ CHANGELOG
55
1.1.dev4
66
========
77
* feature: Frameworks: Use more idiomatic ECR repository naming scheme
8+
* feature: Add Support for Local Mode
89

910
1.1.3
1011
========

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ def read(fname):
3232
],
3333

3434
# Declare minimal set for installation
35-
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=1.0.0'],
35+
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=1.0.0', 'urllib3>=1.2',
36+
'PyYAML>=3.2'],
3637

3738
extras_require={
3839
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist',

src/sagemaker/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from sagemaker.amazon.factorization_machines import FactorizationMachinesPredictor
2222
from sagemaker.amazon.ntm import NTM, NTMModel, NTMPredictor
2323

24+
from sagemaker.local.local_session import LocalSession
25+
2426
from sagemaker.model import Model
2527
from sagemaker.predictor import RealTimePredictor
2628
from sagemaker.session import Session
@@ -34,5 +36,5 @@
3436
LinearLearnerModel, LinearLearnerPredictor,
3537
LDA, LDAModel, LDAPredictor,
3638
FactorizationMachines, FactorizationMachinesModel, FactorizationMachinesPredictor,
37-
Model, NTM, NTMModel, NTMPredictor, RealTimePredictor, Session,
39+
Model, NTM, NTMModel, NTMPredictor, RealTimePredictor, Session, LocalSession,
3840
container_def, s3_input, production_variant, get_execution_role]

src/sagemaker/estimator.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from sagemaker.fw_utils import tar_and_upload_dir
2222
from sagemaker.fw_utils import parse_s3_url
2323
from sagemaker.fw_utils import UploadedCode
24+
from sagemaker.local.local_session import LocalSession
2425
from sagemaker.model import Model
2526
from sagemaker.model import (SCRIPT_PARAM_NAME, DIR_PARAM_NAME, CLOUDWATCH_METRICS_PARAM_NAME,
2627
CONTAINER_LOG_LEVEL_PARAM_NAME, JOB_NAME_PARAM_NAME, SAGEMAKER_REGION_PARAM_NAME)
@@ -78,7 +79,17 @@ def __init__(self, role, train_instance_count, train_instance_type,
7879
self.train_volume_size = train_volume_size
7980
self.train_max_run = train_max_run
8081
self.input_mode = input_mode
81-
self.sagemaker_session = sagemaker_session or Session()
82+
83+
if self.train_instance_type in ('local', 'local_gpu'):
84+
self.local_mode = True
85+
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
86+
raise RuntimeError("Distributed Training in Local GPU is not supported")
87+
88+
self.sagemaker_session = LocalSession()
89+
else:
90+
self.local_mode = False
91+
self.sagemaker_session = sagemaker_session or Session()
92+
8293
self.base_job_name = base_job_name
8394
self._current_job_name = None
8495
self.output_path = output_path
@@ -303,7 +314,7 @@ def start_new(cls, estimator, inputs):
303314
"""Create a new Amazon SageMaker training job from the estimator.
304315
305316
Args:
306-
estimator (sagemaker.estimator.Framework): Estimator object created by the user.
317+
estimator (sagemaker.estimator.EstimatorBase): Estimator object created by the user.
307318
inputs (str): Parameters used when called :meth:`~sagemaker.estimator.EstimatorBase.fit`.
308319
309320
Returns:

src/sagemaker/fw_utils.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,23 @@ def create_image_uri(region, framework, instance_type, framework_version, py_ver
4545
str: The appropriate image URI based on the given parameters.
4646
"""
4747

48-
if not instance_type.startswith('ml.'):
48+
# Handle Local Mode
49+
if instance_type.startswith('local'):
50+
device_type = 'cpu' if instance_type == 'local' else 'gpu'
51+
elif not instance_type.startswith('ml.'):
4952
raise ValueError('{} is not a valid SageMaker instance type. See: '
5053
'https://aws.amazon.com/sagemaker/pricing/instance-types/'.format(instance_type))
51-
family = instance_type.split('.')[1]
52-
53-
# For some frameworks, we have optimized images for specific families, e.g c5 or p3. In those cases,
54-
# we use the family name in the image tag. In other cases, we use 'cpu' or 'gpu'.
55-
if family in optimized_families:
56-
device_type = family
57-
elif family[0] in ['g', 'p']:
58-
device_type = 'gpu'
5954
else:
60-
device_type = 'cpu'
55+
family = instance_type.split('.')[1]
56+
57+
# For some frameworks, we have optimized images for specific families, e.g c5 or p3. In those cases,
58+
# we use the family name in the image tag. In other cases, we use 'cpu' or 'gpu'.
59+
if family in optimized_families:
60+
device_type = family
61+
elif family[0] in ['g', 'p']:
62+
device_type = 'gpu'
63+
else:
64+
device_type = 'cpu'
6165

6266
tag = "{}-{}-{}".format(framework_version, device_type, py_version)
6367
return "{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}:{}" \

src/sagemaker/local/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.

0 commit comments

Comments
 (0)