Skip to content

Commit 2d42219

Browse files
author
Ignacio Quintero
committed
Fix unit tests; add changes for serving.
This fixes the unit tests that were broken. It also adds the remaining work needed for TF serving to work. Pending: MXNet.
1 parent fa8c39b commit 2d42219

23 files changed

+132
-66
lines changed

src/sagemaker/estimator.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,10 @@ def __init__(self, role, train_instance_count, train_instance_type,
8282
self.input_mode = input_mode
8383

8484
if self.train_instance_type in ('local', 'local_gpu'):
85-
self.local_mode = True
8685
if self.train_instance_type == 'local_gpu' and self.train_instance_count > 1:
8786
raise RuntimeError("Distributed Training in Local GPU is not supported")
88-
8987
self.sagemaker_session = sagemaker_session or LocalSession()
9088
else:
91-
self.local_mode = False
9289
self.sagemaker_session = sagemaker_session or Session()
9390

9491
self.base_job_name = base_job_name
@@ -160,8 +157,8 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
160157
# if output_path was specified we use it otherwise initialize here.
161158
# For Local Mode with no_internet=True we don't need an explicit output_path
162159
if self.output_path is None:
163-
if self.local_mode and get_config_value('local.no_internet',
164-
self.sagemaker_session.config):
160+
no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
161+
if self.sagemaker_session.local_mode and no_internet:
165162
self.output_path = ''
166163
else:
167164
self.output_path = 's3://{}/'.format(self.sagemaker_session.default_bucket())
@@ -327,7 +324,7 @@ def start_new(cls, estimator, inputs):
327324
sagemaker.estimator.Framework: Constructed object that captures all information about the started job.
328325
"""
329326

330-
local_mode = estimator.local_mode
327+
local_mode = estimator.sagemaker_session.local_mode
331328

332329
# Allow file:// input only in local mode
333330
if isinstance(inputs, str) and inputs.startswith('file://'):
@@ -608,19 +605,20 @@ def fit(self, inputs, wait=True, logs=True, job_name=None):
608605
base_name = self.base_job_name or base_name_from_image(self.train_image())
609606
self._current_job_name = name_from_base(base_name)
610607

611-
# if there is no source dir, use the directory containing the entry point.
612-
if self.source_dir is None:
613-
self.source_dir = os.path.dirname(self.entry_point)
614-
self.entry_point = os.path.basename(self.entry_point)
615-
616608
# validate source dir will raise a ValueError if there is something wrong with the
617609
# source directory. We are intentionally not handling it because this is a critical error.
618610
if self.source_dir and not self.source_dir.lower().startswith('s3://'):
619611
validate_source_dir(self.entry_point, self.source_dir)
620612

621613
# if we are in local mode with no_internet=True. We want the container to just
622614
# mount the source dir instead of uploading to S3.
623-
if self.local_mode and get_config_value('local.no_internet', self.sagemaker_session.config):
615+
no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
616+
if self.sagemaker_session.local_mode and no_internet:
617+
# if there is no source dir, use the directory containing the entry point.
618+
if self.source_dir is None:
619+
self.source_dir = os.path.dirname(self.entry_point)
620+
self.entry_point = os.path.basename(self.entry_point)
621+
624622
code_dir = 'file://' + self.source_dir
625623
script = self.entry_point
626624
else:

src/sagemaker/local/image.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ def train(self, input_data_config, hyperparameters):
8686
"""
8787
self.container_root = self._create_tmp_folder()
8888
os.mkdir(os.path.join(self.container_root, 'output'))
89+
os.mkdir(os.path.join(self.container_root, 'shared'))
8990

9091
data_dir = self._create_tmp_folder()
9192
volumes = []
@@ -121,8 +122,8 @@ def train(self, input_data_config, hyperparameters):
121122
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
122123
parsed_uri = urlparse(training_dir)
123124
if parsed_uri.scheme == 'file':
124-
print('appended Volume')
125125
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
126+
volumes.append(_Volume(os.path.join(self.container_root, 'shared'), '/opt/ml/shared'))
126127

127128
# Create the configuration files for each container that we will create
128129
# Each container will map the additional local volumes (if any).
@@ -179,7 +180,16 @@ def serve(self, primary_container):
179180

180181
_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
181182

182-
self._generate_compose_file('serve', additional_env_vars=env_vars)
183+
# If the user script was passed as a file:// mount it to the container.
184+
script_dir = primary_container['Environment'][sagemaker.estimator.DIR_PARAM_NAME.upper()]
185+
parsed_uri = urlparse(script_dir)
186+
volumes = []
187+
if parsed_uri.scheme == 'file':
188+
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
189+
190+
self._generate_compose_file('serve',
191+
additional_env_vars=env_vars,
192+
additional_volumes=volumes)
183193
compose_command = self._compose()
184194
self.container = _HostingContainer(compose_command)
185195
self.container.up()
@@ -574,6 +584,10 @@ def _ecr_login_if_needed(boto_session, image):
574584
if _check_output('docker images -q %s' % image).strip():
575585
return
576586

587+
if not boto_session:
588+
raise RuntimeError('A boto session is required to login to ECR.'
589+
'Please pull the image: %s manually.' % image)
590+
577591
ecr = boto_session.client('ecr')
578592
auth = ecr.get_authorization_token(registryIds=[image.split('.')[0]])
579593
authorization_data = auth['authorizationData'][0]

src/sagemaker/local/local_session.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
190190

191191
self.sagemaker_client = LocalSagemakerClient(self)
192192
self.sagemaker_runtime_client = LocalSagemakerRuntimeClient(self.config)
193+
self.local_mode = True
193194

194195
def logs_for_job(self, job_name, wait=False, poll=5):
195196
# override logs_for_job() as it doesn't need to perform any action

src/sagemaker/model.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from sagemaker.fw_utils import tar_and_upload_dir, parse_s3_url
1818
from sagemaker.session import Session
19-
from sagemaker.utils import name_from_image
19+
from sagemaker.utils import name_from_image, get_config_value
2020

2121

2222
class Model(object):
@@ -160,7 +160,11 @@ def prepare_container_def(self, instance_type):
160160
Returns:
161161
dict[str, str]: A container definition object usable with the CreateModel API.
162162
"""
163-
self._upload_code(self.key_prefix or self.name or name_from_image(self.image))
163+
no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
164+
if self.sagemaker_session.local_mode and no_internet:
165+
self.uploaded_code = None
166+
else:
167+
self._upload_code(self.key_prefix or self.name or name_from_image(self.image))
164168
deploy_env = dict(self.env)
165169
deploy_env.update(self._framework_env_vars())
166170
return sagemaker.container_def(self.image, self.model_data, deploy_env)
@@ -173,8 +177,17 @@ def _upload_code(self, key_prefix):
173177
directory=self.source_dir)
174178

175179
def _framework_env_vars(self):
176-
return {SCRIPT_PARAM_NAME.upper(): self.uploaded_code.script_name,
177-
DIR_PARAM_NAME.upper(): self.uploaded_code.s3_prefix,
178-
CLOUDWATCH_METRICS_PARAM_NAME.upper(): str(self.enable_cloudwatch_metrics).lower(),
179-
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level),
180-
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.boto_session.region_name}
180+
if self.uploaded_code:
181+
script_name = self.uploaded_code.script_name
182+
dir_name = self.uploaded_code.s3_prefix
183+
else:
184+
script_name = self.entry_point
185+
dir_name = 'file://' + self.source_dir
186+
187+
return {
188+
SCRIPT_PARAM_NAME.upper(): script_name,
189+
DIR_PARAM_NAME.upper(): dir_name,
190+
CLOUDWATCH_METRICS_PARAM_NAME.upper(): str(self.enable_cloudwatch_metrics).lower(),
191+
CONTAINER_LOG_LEVEL_PARAM_NAME.upper(): str(self.container_log_level),
192+
SAGEMAKER_REGION_PARAM_NAME.upper(): self.sagemaker_session.region_name
193+
}

src/sagemaker/session.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
8787
"""
8888
self.boto_session = boto_session or boto3.Session()
8989

90-
self.region = self.boto_session.region_name
91-
if self.region is None:
90+
self.region_name = self.boto_session.region_name
91+
if self.region_name is None:
9292
raise ValueError('Must setup local AWS configuration with a region supported by SageMaker.')
9393

9494
self.sagemaker_client = sagemaker_client or self.boto_session.client('sagemaker')
@@ -97,6 +97,8 @@ def _initialize(self, boto_session, sagemaker_client, sagemaker_runtime_client):
9797
self.sagemaker_runtime_client = sagemaker_runtime_client or self.boto_session.client('runtime.sagemaker')
9898
prepend_user_agent(self.sagemaker_runtime_client)
9999

100+
self.local_mode = False
101+
100102
@property
101103
def boto_region_name(self):
102104
return self.boto_session.region_name

src/sagemaker/tensorflow/estimator.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from sagemaker.estimator import Framework
2222
from sagemaker.fw_utils import create_image_uri, framework_name_from_image, framework_version_from_tag
23+
from sagemaker.utils import get_config_value
2324

2425
from sagemaker.tensorflow.defaults import TF_VERSION
2526
from sagemaker.tensorflow.model import TensorFlowModel
@@ -305,7 +306,12 @@ def hyperparameters(self):
305306
hyperparameters = super(TensorFlow, self).hyperparameters()
306307

307308
if not self.checkpoint_path:
308-
self.checkpoint_path = os.path.join(self.output_path, self._current_job_name, 'checkpoints')
309+
no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
310+
if self.sagemaker_session.local_mode and no_internet:
311+
self.checkpoint_path = '/opt/ml/shared/checkpoints'
312+
else:
313+
self.checkpoint_path = os.path.join(self.output_path,
314+
self._current_job_name, 'checkpoints')
309315

310316
additional_hyperparameters = {'checkpoint_path': self.checkpoint_path,
311317
'training_steps': self.training_steps,

src/sagemaker/tensorflow/model.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sagemaker.predictor import RealTimePredictor
1717
from sagemaker.tensorflow.defaults import TF_VERSION
1818
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer
19-
from sagemaker.utils import name_from_image
19+
from sagemaker.utils import name_from_image, get_config_value
2020

2121

2222
class TensorFlowPredictor(RealTimePredictor):
@@ -83,11 +83,18 @@ def prepare_container_def(self, instance_type):
8383
"""
8484
deploy_image = self.image
8585
if not deploy_image:
86-
region_name = self.sagemaker_session.boto_session.region_name
86+
region_name = self.sagemaker_session.region_name
8787
deploy_image = create_image_uri(region_name, self.__framework_name__, instance_type,
8888
self.framework_version, self.py_version)
89+
8990
deploy_key_prefix = self.key_prefix or self.name or name_from_image(deploy_image)
90-
self._upload_code(deploy_key_prefix)
91+
92+
no_internet = get_config_value('local.no_internet', self.sagemaker_session.config)
93+
print('no_internet: %s local_mode: %s' % (no_internet, self.sagemaker_session.local_mode))
94+
if self.sagemaker_session.local_mode and no_internet:
95+
self.uploaded_code = None
96+
else:
97+
self._upload_code(deploy_key_prefix)
9198
deploy_env = dict(self.env)
9299
deploy_env.update(self._framework_env_vars())
93100

tests/component/test_mxnet_estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
@pytest.fixture()
3535
def sagemaker_session():
3636
boto_mock = Mock(name='boto_session', region_name=REGION)
37-
ims = Mock(name='sagemaker_session', boto_session=boto_mock)
37+
ims = Mock(name='sagemaker_session', boto_session=boto_mock,
38+
config=None, local_mode=False, region_name=REGION)
39+
3840
ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
3941
ims.expand_role = Mock(name="expand_role", return_value=ROLE)
4042
ims.sagemaker_client.describe_training_job = Mock(return_value={'ModelArtifacts':

tests/component/test_tf_estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
@pytest.fixture()
3535
def sagemaker_session():
3636
boto_mock = Mock(name='boto_session', region_name=REGION)
37-
ims = Mock(name='sagemaker_session', boto_session=boto_mock)
37+
ims = Mock(name='sagemaker_session', boto_session=boto_mock, config=None,
38+
local_mode=False, region_name=REGION)
3839
ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
3940
ims.expand_role = Mock(name="expand_role", return_value=ROLE)
4041
ims.sagemaker_client.describe_training_job = Mock(return_value={'ModelArtifacts':
@@ -62,6 +63,7 @@ def test_deploy(sagemaker_session, tf_version):
6263
{'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
6364
'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
6465
'SAGEMAKER_SUBMIT_DIRECTORY': SOURCE_DIR,
66+
'SAGEMAKER_REQUIREMENTS': '',
6567
'SAGEMAKER_REGION': REGION,
6668
'SAGEMAKER_PROGRAM': SCRIPT},
6769
'Image': image,

tests/unit/test_amazon_estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828
@pytest.fixture()
2929
def sagemaker_session():
3030
boto_mock = Mock(name='boto_session', region_name=REGION)
31-
sms = Mock(name='sagemaker_session', boto_session=boto_mock)
31+
sms = Mock(name='sagemaker_session', boto_session=boto_mock,
32+
region_name=REGION, config=None, local_mode=False)
3233
sms.boto_region_name = REGION
3334
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
3435
returned_job_description = {'AlgorithmSpecification': {'TrainingInputMode': 'File',

tests/unit/test_estimator.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@ def create_predictor(self, endpoint_name):
8989
@pytest.fixture()
9090
def sagemaker_session():
9191
boto_mock = Mock(name='boto_session', region_name=REGION)
92-
ims = Mock(name='sagemaker_session', boto_session=boto_mock, region_name=REGION)
92+
ims = Mock(name='sagemaker_session', boto_session=boto_mock, region_name=REGION,
93+
config=None, local_mode=False)
9394
ims.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
9495
ims.sagemaker_client.describe_training_job = Mock(name='describe_training_job',
9596
return_value=DESCRIBE_TRAINING_JOB_RESULT)
@@ -533,18 +534,26 @@ def test_generic_to_deploy(sagemaker_session):
533534

534535

535536
@patch('sagemaker.estimator.LocalSession')
536-
def test_local_mode(sagemaker_session):
537-
e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, 'local', output_path='s3://bucket/prefix',
538-
sagemaker_session=sagemaker_session)
539-
assert e.local_mode is True
537+
@patch('sagemaker.estimator.Session')
538+
def test_local_mode(session_class, local_session_class):
539+
local_session = Mock()
540+
local_session.local_mode = True
541+
542+
session = Mock()
543+
session.local_mode = False
544+
545+
local_session_class.return_value = local_session
546+
session_class.return_value = session
547+
548+
e = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, 'local')
549+
print(e.sagemaker_session.local_mode)
550+
assert e.sagemaker_session.local_mode is True
540551

541-
e2 = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, 'local_gpu', output_path='s3://bucket/prefix',
542-
sagemaker_session=sagemaker_session)
543-
assert e2.local_mode is True
552+
e2 = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, 'local_gpu')
553+
assert e2.sagemaker_session.local_mode is True
544554

545-
e3 = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, output_path='s3://bucket/prefix',
546-
sagemaker_session=sagemaker_session)
547-
assert e3.local_mode is False
555+
e3 = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE)
556+
assert e3.sagemaker_session.local_mode is False
548557

549558

550559
@patch('sagemaker.estimator.LocalSession')

tests/unit/test_fm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@
3939
@pytest.fixture()
4040
def sagemaker_session():
4141
boto_mock = Mock(name='boto_session', region_name=REGION)
42-
sms = Mock(name='sagemaker_session', boto_session=boto_mock)
42+
sms = Mock(name='sagemaker_session', boto_session=boto_mock,
43+
region_name=REGION, config=None, local_mode=False)
4344
sms.boto_region_name = REGION
4445
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4546
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',

tests/unit/test_fw_utils.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,23 +90,19 @@ def test_tar_and_upload_dir_s3(sagemaker_session):
9090
assert result == UploadedCode('s3://m', 'mnist.py')
9191

9292

93-
def test_tar_and_upload_dir_does_not_exits(sagemaker_session):
94-
bucket = 'mybucker'
95-
s3_key_prefix = 'something/source'
93+
def test_validate_source_dir_does_not_exits(sagemaker_session):
9694
script = 'mnist.py'
9795
directory = ' !@#$%^&*()path probably in not there.!@#$%^&*()'
9896
with pytest.raises(ValueError) as error:
99-
tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory)
97+
validate_source_dir(script, directory)
10098
assert 'does not exist' in str(error)
10199

102100

103-
def test_tar_and_upload_dir_is_not_directory(sagemaker_session):
104-
bucket = 'mybucker'
105-
s3_key_prefix = 'something/source'
101+
def test_validate_source_dir_is_not_directory(sagemaker_session):
106102
script = 'mnist.py'
107103
directory = inspect.getfile(inspect.currentframe())
108104
with pytest.raises(ValueError) as error:
109-
tar_and_upload_dir(sagemaker_session, bucket, s3_key_prefix, script, directory)
105+
validate_source_dir(script, directory)
110106
assert 'is not a directory' in str(error)
111107

112108

tests/unit/test_image.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,12 @@ def test_serve(up, copy, copytree, tmpdir, sagemaker_session):
226226

227227
image = 'my-image'
228228
sagemaker_container = _SageMakerContainer('local', 1, image, sagemaker_session=sagemaker_session)
229-
primary_container = {'ModelDataUrl': '/some/model/path', 'Environment': {'env1': 1, 'env2': 'b'}}
229+
primary_container = {'ModelDataUrl': '/some/model/path',
230+
'Environment': {'env1': 1,
231+
'env2': 'b',
232+
'SAGEMAKER_SUBMIT_DIRECTORY': 's3://some/path'
233+
}
234+
}
230235

231236
sagemaker_container.serve(primary_container)
232237
docker_compose_file = os.path.join(sagemaker_container.container_root, 'docker-compose.yaml')

tests/unit/test_kmeans.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
@pytest.fixture()
3939
def sagemaker_session():
4040
boto_mock = Mock(name='boto_session', region_name=REGION)
41-
sms = Mock(name='sagemaker_session', boto_session=boto_mock)
41+
sms = Mock(name='sagemaker_session', boto_session=boto_mock,
42+
region_name=REGION, config=None, local_mode=False)
4243
sms.boto_region_name = REGION
4344
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4445
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',

tests/unit/test_lda.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737
@pytest.fixture()
3838
def sagemaker_session():
3939
boto_mock = Mock(name='boto_session', region_name=REGION)
40-
sms = Mock(name='sagemaker_session', boto_session=boto_mock)
40+
sms = Mock(name='sagemaker_session', boto_session=boto_mock,
41+
region_name=REGION, config=None, local_mode=False)
4142
sms.boto_region_name = REGION
4243
sms.default_bucket = Mock(name='default_bucket', return_value=BUCKET_NAME)
4344
sms.sagemaker_client.describe_training_job = Mock(name='describe_training_job',

0 commit comments

Comments
 (0)