Skip to content

Commit a504db4

Browse files
icywang86ruinadiaya
authored and committed
Fix script mode image async attach and deploy (#692)
Enable framework_name_from_image to parse script mode image URIs. Currently, attaching to a script mode job generates an estimator that fails to deploy, because the SDK thinks the image is a custom image and uses the training image for hosting.
1 parent 3f7bd59 commit a504db4

File tree

10 files changed

+62
-25
lines changed

10 files changed

+62
-25
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
157157
init_params[argument[len('sagemaker_'):]] = value
158158

159159
image_name = init_params.pop('image')
160-
framework, py_version, tag = framework_name_from_image(image_name)
160+
framework, py_version, tag, _ = framework_name_from_image(image_name)
161161

162162
if not framework:
163163
# If we were unable to parse the framework name from the image it is not one of our

src/sagemaker/fw_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -211,29 +211,31 @@ def framework_name_from_image(image_name):
211211
str: The framework name
212212
str: The Python version
213213
str: The image tag
214+
str: If the image is script mode
214215
"""
215216
sagemaker_pattern = re.compile(r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)amazonaws.com(/)(.*:.*)$')
216217
sagemaker_match = sagemaker_pattern.match(image_name)
217218
if sagemaker_match is None:
218-
return None, None, None
219+
return None, None, None, None
219220
else:
220221
# extract framework, python version and image tag
221222
# We must support both the legacy and current image name format.
222223
name_pattern = re.compile(
223-
r'^sagemaker(?:-rl)?-(tensorflow|mxnet|chainer|pytorch|scikit-learn):(.*)-(.*?)-(py2|py3)$')
224+
r'^sagemaker(?:-rl)?-(tensorflow|mxnet|chainer|pytorch|scikit-learn)(?:-)?(scriptmode)?:(.*)-(.*?)-(py2|py3)$') # noqa
224225
legacy_name_pattern = re.compile(
225226
r'^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
226227

227228
name_match = name_pattern.match(sagemaker_match.group(8))
228229
legacy_match = legacy_name_pattern.match(sagemaker_match.group(8))
229230

230231
if name_match is not None:
231-
fw, ver, device, py = name_match.group(1), name_match.group(2), name_match.group(3), name_match.group(4)
232-
return fw, py, '{}-{}-{}'.format(ver, device, py)
232+
fw, scriptmode, ver, device, py = name_match.group(1), name_match.group(2), name_match.group(3),\
233+
name_match.group(4), name_match.group(5)
234+
return fw, py, '{}-{}-{}'.format(ver, device, py), scriptmode
233235
elif legacy_match is not None:
234-
return legacy_match.group(1), legacy_match.group(2), legacy_match.group(4)
236+
return legacy_match.group(1), legacy_match.group(2), legacy_match.group(4), None
235237
else:
236-
return None, None, None
238+
return None, None, None, None
237239

238240

239241
def framework_version_from_tag(image_tag):

src/sagemaker/mxnet/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
132132
"""
133133
init_params = super(MXNet, cls)._prepare_init_params_from_job_description(job_details, model_channel_name)
134134
image_name = init_params.pop('image')
135-
framework, py_version, tag = framework_name_from_image(image_name)
135+
framework, py_version, tag, _ = framework_name_from_image(image_name)
136136

137137
if not framework:
138138
# If we were unable to parse the framework name from the image it is not one of our

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
114114
"""
115115
init_params = super(PyTorch, cls)._prepare_init_params_from_job_description(job_details, model_channel_name)
116116
image_name = init_params.pop('image')
117-
framework, py_version, tag = framework_name_from_image(image_name)
117+
framework, py_version, tag, _ = framework_name_from_image(image_name)
118118

119119
if not framework:
120120
# If we were unable to parse the framework name from the image it is not one of our

src/sagemaker/rl/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
253253
._prepare_init_params_from_job_description(job_details, model_channel_name)
254254

255255
image_name = init_params.pop('image')
256-
framework, _, tag = fw_utils.framework_name_from_image(image_name)
256+
framework, _, tag, _ = fw_utils.framework_name_from_image(image_name)
257257

258258
if not framework:
259259
# If we were unable to parse the framework name from the image it is not one of our

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
136136
init_params = super(SKLearn, cls)._prepare_init_params_from_job_description(job_details)
137137

138138
image_name = init_params.pop('image')
139-
framework, py_version, _ = framework_name_from_image(image_name)
139+
framework, py_version, _, _ = framework_name_from_image(image_name)
140140
init_params['py_version'] = py_version
141141

142142
if framework and framework != cls.__framework_name__:

src/sagemaker/tensorflow/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,13 +356,16 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
356356
init_params[argument] = value
357357

358358
image_name = init_params.pop('image')
359-
framework, py_version, tag = fw.framework_name_from_image(image_name)
359+
framework, py_version, tag, script_mode = fw.framework_name_from_image(image_name)
360360
if not framework:
361361
# If we were unable to parse the framework name from the image it is not one of our
362362
# officially supported images, in this case just add the image to the init params.
363363
init_params['image_name'] = image_name
364364
return init_params
365365

366+
if script_mode:
367+
init_params['script_mode'] = True
368+
366369
init_params['py_version'] = py_version
367370

368371
# We switched image tagging scheme from regular image version (e.g. '1.0') to more expressive

tests/data/tensorflow_mnist/mnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,6 @@ def serving_input_fn():
187187
tf.estimator.train_and_evaluate(mnist_classifier, train_spec, eval_spec)
188188

189189
if args.current_host == args.hosts[0]:
190-
mnist_classifier.export_savedmodel(args.model_dir, serving_input_fn)
190+
mnist_classifier.export_savedmodel('/opt/ml/model', serving_input_fn)
191191

192192
tf_logger.info('====== Training finished =========')

tests/integ/test_tf_script_mode.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,10 @@
1212
# language governing permissions and limitations under the License.
1313
from __future__ import absolute_import
1414

15+
import numpy as np
1516
import os
1617
import pytest
18+
import time
1719

1820
import boto3
1921
from sagemaker.tensorflow import TensorFlow
@@ -40,7 +42,7 @@ def test_mnist(sagemaker_session, instance_type):
4042
train_instance_type=instance_type,
4143
sagemaker_session=sagemaker_session,
4244
py_version='py3',
43-
framework_version='1.11',
45+
framework_version=TensorFlow.LATEST_VERSION,
4446
base_job_name='test-tf-sm-mnist')
4547
inputs = estimator.sagemaker_session.upload_data(
4648
path=os.path.join(RESOURCE_PATH, 'data'),
@@ -49,7 +51,7 @@ def test_mnist(sagemaker_session, instance_type):
4951
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
5052
estimator.fit(inputs)
5153
_assert_s3_files_exist(estimator.model_dir,
52-
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta', 'saved_model.pb'])
54+
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta'])
5355

5456

5557
@pytest.mark.canary_quick
@@ -63,7 +65,7 @@ def test_mnist_distributed(sagemaker_session, instance_type):
6365
sagemaker_session=sagemaker_session,
6466
py_version=integ.PYTHON_VERSION,
6567
script_mode=True,
66-
framework_version='1.11',
68+
framework_version=TensorFlow.LATEST_VERSION,
6769
distributions=PARAMETER_SERVER_DISTRIBUTION,
6870
base_job_name='test-tf-sm-mnist')
6971
inputs = estimator.sagemaker_session.upload_data(
@@ -73,7 +75,32 @@ def test_mnist_distributed(sagemaker_session, instance_type):
7375
with timeout.timeout(minutes=integ.TRAINING_DEFAULT_TIMEOUT_MINUTES):
7476
estimator.fit(inputs)
7577
_assert_s3_files_exist(estimator.model_dir,
76-
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta', 'saved_model.pb'])
78+
['graph.pbtxt', 'model.ckpt-0.index', 'model.ckpt-0.meta'])
79+
80+
81+
def test_mnist_async(sagemaker_session):
82+
estimator = TensorFlow(entry_point=SCRIPT,
83+
role='SageMakerRole',
84+
train_instance_count=1,
85+
train_instance_type='ml.c5.4xlarge',
86+
sagemaker_session=sagemaker_session,
87+
py_version='py3',
88+
framework_version=TensorFlow.LATEST_VERSION,
89+
base_job_name='test-tf-sm-mnist')
90+
inputs = estimator.sagemaker_session.upload_data(
91+
path=os.path.join(RESOURCE_PATH, 'data'),
92+
key_prefix='scriptmode/mnist')
93+
estimator.fit(inputs, wait=False)
94+
training_job_name = estimator.latest_training_job.name
95+
time.sleep(20)
96+
endpoint_name = training_job_name
97+
with timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
98+
estimator = TensorFlow.attach(training_job_name=training_job_name, sagemaker_session=sagemaker_session)
99+
predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge',
100+
endpoint_name=endpoint_name)
101+
102+
result = predictor.predict(np.zeros(784))
103+
print('predict result: {}'.format(result))
77104

78105

79106
def _assert_s3_files_exist(s3_url, files):

tests/unit/test_fw_utils.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -354,45 +354,50 @@ def walk():
354354

355355
def test_framework_name_from_image_mxnet():
356356
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.1-gpu-py3'
357-
assert ('mxnet', 'py3', '1.1-gpu-py3') == fw_utils.framework_name_from_image(image_name)
357+
assert ('mxnet', 'py3', '1.1-gpu-py3', None) == fw_utils.framework_name_from_image(image_name)
358358

359359

360360
def test_framework_name_from_image_tf():
361361
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.6-cpu-py2'
362-
assert ('tensorflow', 'py2', '1.6-cpu-py2') == fw_utils.framework_name_from_image(image_name)
362+
assert ('tensorflow', 'py2', '1.6-cpu-py2', None) == fw_utils.framework_name_from_image(image_name)
363+
364+
365+
def test_framework_name_from_image_tf_scriptmode():
366+
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-scriptmode:1.12-cpu-py3'
367+
assert ('tensorflow', 'py3', '1.12-cpu-py3', 'scriptmode') == fw_utils.framework_name_from_image(image_name)
363368

364369

365370
def test_framework_name_from_image_rl():
366371
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-rl-mxnet:toolkit1.1-gpu-py3'
367-
assert ('mxnet', 'py3', 'toolkit1.1-gpu-py3') == fw_utils.framework_name_from_image(image_name)
372+
assert ('mxnet', 'py3', 'toolkit1.1-gpu-py3', None) == fw_utils.framework_name_from_image(image_name)
368373

369374

370375
def test_legacy_name_from_framework_image():
371376
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py3-gpu:2.5.6-gpu-py2'
372-
framework, py_ver, tag = fw_utils.framework_name_from_image(image_name)
377+
framework, py_ver, tag, _ = fw_utils.framework_name_from_image(image_name)
373378
assert framework == 'mxnet'
374379
assert py_ver == 'py3'
375380
assert tag == '2.5.6-gpu-py2'
376381

377382

378383
def test_legacy_name_from_wrong_framework():
379-
framework, py_ver, tag = fw_utils.framework_name_from_image(
384+
framework, py_ver, tag, _ = fw_utils.framework_name_from_image(
380385
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1')
381386
assert framework is None
382387
assert py_ver is None
383388
assert tag is None
384389

385390

386391
def test_legacy_name_from_wrong_python():
387-
framework, py_ver, tag = fw_utils.framework_name_from_image(
392+
framework, py_ver, tag, _ = fw_utils.framework_name_from_image(
388393
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
389394
assert framework is None
390395
assert py_ver is None
391396
assert tag is None
392397

393398

394399
def test_legacy_name_from_wrong_device():
395-
framework, py_ver, tag = fw_utils.framework_name_from_image(
400+
framework, py_ver, tag, _ = fw_utils.framework_name_from_image(
396401
'123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
397402
assert framework is None
398403
assert py_ver is None
@@ -401,7 +406,7 @@ def test_legacy_name_from_wrong_device():
401406

402407
def test_legacy_name_from_image_any_tag():
403408
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:any-tag'
404-
framework, py_ver, tag = fw_utils.framework_name_from_image(image_name)
409+
framework, py_ver, tag, _ = fw_utils.framework_name_from_image(image_name)
405410
assert framework == 'tensorflow'
406411
assert py_ver == 'py2'
407412
assert tag == 'any-tag'

0 commit comments

Comments
 (0)