Skip to content

Commit 4f92fbd

Browse files
authored
Use new repository names for Framework images (#114)
* Use new repository names for Framework images
* Update changelog
* Fix flake8
* Fix framework_name_from_image to support both old and new image names
* Add unit test on tensorflow attach for new repo name
* Update docstring on legacy ecr repo naming
1 parent 0bbd9ba commit 4f92fbd

File tree

5 files changed

+143
-36
lines changed

5 files changed

+143
-36
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
CHANGELOG
33
=========
44

5+
1.1.dev4
6+
========
7+
* feature: Frameworks: Use more idiomatic ECR repository naming scheme
8+
59
1.1.3
610
========
711

src/sagemaker/fw_utils.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,28 +28,40 @@
2828
"""
2929

3030

31-
def create_image_uri(region, framework, instance_type, framework_version, py_version, account='520713654638'):
31+
def create_image_uri(region, framework, instance_type, framework_version, py_version, account='520713654638',
32+
optimized_families=[]):
3233
"""Return the ECR URI of an image.
3334
3435
Args:
3536
region (str): AWS region where the image is uploaded.
3637
framework (str): framework used by the image.
37-
instance_type (str): EC2 instance type. Used to determine whether to use the CPU image or GPU image.
38+
instance_type (str): SageMaker instance type. Used to determine device type (cpu/gpu/family-specific optimized).
3839
framework_version (str): The version of the framework.
3940
py_version (str): Python version. One of 'py2' or 'py3'.
4041
account (str): AWS account that contains the image. (default: '520713654638')
42+
optimized_families (list[str]): Instance families for which there exist specific optimized images. (default: [])
4143
4244
Returns:
4345
str: The appropriate image URI based on the given parameters.
4446
"""
45-
device_type = 'cpu'
46-
# Instance types that start with G, P are GPU powered: https://aws.amazon.com/sagemaker/pricing/instance-types/
47-
if instance_type[3] in ['g', 'p']:
47+
48+
if not instance_type.startswith('ml.'):
49+
raise ValueError('{} is not a valid SageMaker instance type. See: '
50+
'https://aws.amazon.com/sagemaker/pricing/instance-types/'.format(instance_type))
51+
family = instance_type.split('.')[1]
52+
53+
# For some frameworks, we have optimized images for specific families, e.g. c5 or p3. In those cases,
54+
# we use the family name in the image tag. In other cases, we use 'cpu' or 'gpu'.
55+
if family in optimized_families:
56+
device_type = family
57+
elif family[0] in ['g', 'p']:
4858
device_type = 'gpu'
59+
else:
60+
device_type = 'cpu'
4961

5062
tag = "{}-{}-{}".format(framework_version, device_type, py_version)
51-
return "{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}-{}-{}:{}" \
52-
.format(account, region, framework, py_version, device_type, tag)
63+
return "{}.dkr.ecr.{}.amazonaws.com/sagemaker-{}:{}" \
64+
.format(account, region, framework, tag)
5365

5466

5567
def tar_and_upload_dir(session, bucket, s3_key_prefix, script, directory):
@@ -107,8 +119,13 @@ def framework_name_from_image(image_name):
107119
"""Extract the framework and Python version from the image name.
108120
109121
Args:
110-
image_name (str): Image URI, which should take the form
111-
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<framework>-<py_ver>-<device>:<tag>'
122+
image_name (str): Image URI, which should be one of the following forms:
123+
legacy (tagged by container version):
124+
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>-<py_ver>-<device>:<container_version>'
125+
legacy (tagged by framework version, device and Python version):
126+
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>-<py_ver>-<device>:<fw_version>-<device>-<py_ver>'
127+
current:
128+
'<account>.dkr.ecr.<region>.amazonaws.com/sagemaker-<fw>:<fw_version>-<device>-<py_ver>'
112129
113130
Returns:
114131
tuple: A tuple containing:
@@ -123,14 +140,19 @@ def framework_name_from_image(image_name):
123140
return None, None, None
124141
else:
125142
# extract framework, python version and image tag
126-
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
127-
143+
# We must support both the legacy and current image name format.
144+
name_pattern = re.compile('^sagemaker-(tensorflow|mxnet):(.*?)-(.*?)-(py2|py3)$')
145+
legacy_name_pattern = re.compile('^sagemaker-(tensorflow|mxnet)-(py2|py3)-(cpu|gpu):(.*)$')
128146
name_match = name_pattern.match(sagemaker_match.group(8))
147+
legacy_match = legacy_name_pattern.match(sagemaker_match.group(8))
129148

130-
if name_match is None:
131-
return None, None, None
149+
if name_match is not None:
150+
fw, ver, device, py = name_match.group(1), name_match.group(2), name_match.group(3), name_match.group(4)
151+
return fw, py, '{}-{}-{}'.format(ver, device, py)
152+
elif legacy_match is not None:
153+
return legacy_match.group(1), legacy_match.group(2), legacy_match.group(4)
132154
else:
133-
return name_match.group(1), name_match.group(2), name_match.group(4)
155+
return None, None, None
134156

135157

136158
def framework_version_from_tag(image_tag):

tests/unit/test_fw_utils.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,18 +37,42 @@ def sagemaker_session():
3737

3838

3939
def test_create_image_uri_cpu():
40-
image_uri = create_image_uri('mars-south-3', 'mlfw', 'any-non-gpu-device', '1.0rc', 'py2', '23')
41-
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw-py2-cpu:1.0rc-cpu-py2'
40+
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.c4.large', '1.0rc', 'py2', '23')
41+
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'
4242

4343

4444
def test_create_image_uri_gpu():
4545
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.p3.2xlarge', '1.0rc', 'py3', '23')
46-
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw-py3-gpu:1.0rc-gpu-py3'
46+
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3'
4747

4848

4949
def test_create_image_uri_default_account():
5050
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.p3.2xlarge', '1.0rc', 'py3')
51-
assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw-py3-gpu:1.0rc-gpu-py3'
51+
assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3'
52+
53+
54+
def test_invalid_instance_type():
55+
# instance type is missing 'ml.' prefix
56+
with pytest.raises(ValueError):
57+
create_image_uri('mars-south-3', 'mlfw', 'p3.2xlarge', '1.0.0', 'py3')
58+
59+
60+
def test_optimized_family():
61+
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.p3.2xlarge', '1.0.0', 'py3',
62+
optimized_families=['c5', 'p3'])
63+
assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0.0-p3-py3'
64+
65+
66+
def test_unoptimized_cpu_family():
67+
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.m4.xlarge', '1.0.0', 'py3',
68+
optimized_families=['c5', 'p3'])
69+
assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0.0-cpu-py3'
70+
71+
72+
def test_unoptimized_gpu_family():
73+
image_uri = create_image_uri('mars-south-3', 'mlfw', 'ml.p2.xlarge', '1.0.0', 'py3',
74+
optimized_families=['c5', 'p3'])
75+
assert image_uri == '520713654638.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0.0-gpu-py3'
5276

5377

5478
def test_tar_and_upload_dir_s3(sagemaker_session):
@@ -99,36 +123,46 @@ def test_tar_and_upload_dir_not_s3(sagemaker_session):
99123
assert result == UploadedCode('s3://{}/{}/sourcedir.tar.gz'.format(bucket, s3_key_prefix), script)
100124

101125

102-
def test_framework_name_from_framework_image():
126+
def test_framework_name_from_image_mxnet():
127+
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:1.1-gpu-py3'
128+
assert ('mxnet', 'py3', '1.1-gpu-py3') == framework_name_from_image(image_name)
129+
130+
131+
def test_framework_name_from_image_tf():
132+
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:1.6-cpu-py2'
133+
assert ('tensorflow', 'py2', '1.6-cpu-py2') == framework_name_from_image(image_name)
134+
135+
136+
def test_legacy_name_from_framework_image():
103137
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py3-gpu:2.5.6-gpu-py2'
104138
framework, py_ver, tag = framework_name_from_image(image_name)
105139
assert framework == 'mxnet'
106140
assert py_ver == 'py3'
107141
assert tag == '2.5.6-gpu-py2'
108142

109143

110-
def test_framework_name_from_wrong_framework():
144+
def test_legacy_name_from_wrong_framework():
111145
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1')
112146
assert framework is None
113147
assert py_ver is None
114148
assert tag is None
115149

116150

117-
def test_framework_name_from_wrong_python():
151+
def test_legacy_name_from_wrong_python():
118152
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
119153
assert framework is None
120154
assert py_ver is None
121155
assert tag is None
122156

123157

124-
def test_framework_name_from_wrong_device():
158+
def test_legacy_name_from_wrong_device():
125159
framework, py_ver, tag = framework_name_from_image('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1')
126160
assert framework is None
127161
assert py_ver is None
128162
assert tag is None
129163

130164

131-
def test_framework_name_from_image_any_tag():
165+
def test_legacy_name_from_image_any_tag():
132166
image_name = '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:any-tag'
133167
framework, py_ver, tag = framework_name_from_image(image_name)
134168
assert framework == 'tensorflow'

tests/unit/test_mxnet.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
BUCKET_NAME = 'mybucket'
3131
INSTANCE_COUNT = 1
3232
INSTANCE_TYPE = 'ml.c4.4xlarge'
33-
IMAGE_CPU_NAME = 'sagemaker-mxnet-py2-cpu'
33+
IMAGE_CPU_NAME = 'sagemaker-mxnet'
3434
JOB_NAME = '{}-{}'.format(IMAGE_CPU_NAME, TIMESTAMP)
3535
FULL_IMAGE_URI = '520713654638.dkr.ecr.us-west-2.amazonaws.com/{}:{}-cpu-py2'
3636
ROLE = 'Dummy'
@@ -138,10 +138,10 @@ def test_mxnet(strftime, sagemaker_session, mxnet_version):
138138

139139
model = mx.create_model()
140140

141-
expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet-py2-gpu:{}-gpu-py2'
141+
expected_image_base = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-mxnet:{}-gpu-py2'
142142
assert {'Environment':
143143
{'SAGEMAKER_SUBMIT_DIRECTORY':
144-
's3://mybucket/sagemaker-mxnet-py2-cpu-{}/sourcedir.tar.gz'.format(TIMESTAMP),
144+
's3://mybucket/sagemaker-mxnet-{}/sourcedir.tar.gz'.format(TIMESTAMP),
145145
'SAGEMAKER_PROGRAM': 'dummy_script.py',
146146
'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false',
147147
'SAGEMAKER_REGION': 'us-west-2',

tests/unit/test_tf_estimator.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@
3131
BUCKET_NAME = 'mybucket'
3232
INSTANCE_COUNT = 1
3333
INSTANCE_TYPE = 'ml.c4.4xlarge'
34-
CPU_IMAGE_NAME = 'sagemaker-tensorflow-py2-cpu'
35-
GPU_IMAGE_NAME = 'sagemaker-tensorflow-py2-gpu'
36-
JOB_NAME = '{}-{}'.format(CPU_IMAGE_NAME, TIMESTAMP)
34+
IMAGE_REPO_NAME = 'sagemaker-tensorflow'
35+
JOB_NAME = '{}-{}'.format(IMAGE_REPO_NAME, TIMESTAMP)
3736
ROLE = 'Dummy'
3837
REGION = 'us-west-2'
3938
DOCKER_TAG = '1.0'
@@ -53,11 +52,11 @@ def sagemaker_session():
5352

5453

5554
def _get_full_cpu_image_uri(version):
56-
return IMAGE_URI_FORMAT_STRING.format(REGION, CPU_IMAGE_NAME, version, 'cpu', 'py2')
55+
return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_REPO_NAME, version, 'cpu', 'py2')
5756

5857

5958
def _get_full_gpu_image_uri(version):
60-
return IMAGE_URI_FORMAT_STRING.format(REGION, GPU_IMAGE_NAME, version, 'gpu', 'py2')
59+
return IMAGE_URI_FORMAT_STRING.format(REGION, IMAGE_REPO_NAME, version, 'gpu', 'py2')
6160

6261

6362
def _create_train_job(tf_version):
@@ -231,11 +230,11 @@ def test_tf(time, strftime, sagemaker_session, tf_version):
231230
'SAGEMAKER_REGION': 'us-west-2',
232231
'SAGEMAKER_CONTAINER_LOG_LEVEL': '20'
233232
},
234-
'Image': create_image_uri('us-west-2', "tensorflow", GPU_IMAGE_NAME, tf_version, "py2"),
235-
'ModelDataUrl': 's3://m/m.tar.gz'} == model.prepare_container_def(GPU_IMAGE_NAME)
233+
'Image': create_image_uri('us-west-2', "tensorflow", INSTANCE_TYPE, tf_version, "py2"),
234+
'ModelDataUrl': 's3://m/m.tar.gz'} == model.prepare_container_def(INSTANCE_TYPE)
236235

237-
assert 'cpu' in model.prepare_container_def(CPU_IMAGE_NAME)['Image']
238-
predictor = tf.deploy(1, GPU_IMAGE_NAME)
236+
assert 'cpu' in model.prepare_container_def(INSTANCE_TYPE)['Image']
237+
predictor = tf.deploy(1, INSTANCE_TYPE)
239238
assert isinstance(predictor, TensorFlowPredictor)
240239

241240

@@ -257,7 +256,7 @@ def test_run_tensorboard_locally_without_tensorboard_binary(time, strftime, pope
257256
def test_model(sagemaker_session, tf_version):
258257
model = TensorFlowModel("s3://some/data.tar.gz", role=ROLE, entry_point=SCRIPT_PATH,
259258
sagemaker_session=sagemaker_session)
260-
predictor = model.deploy(1, GPU_IMAGE_NAME)
259+
predictor = model.deploy(1, INSTANCE_TYPE)
261260
assert isinstance(predictor, TensorFlowPredictor)
262261

263262

@@ -410,6 +409,54 @@ def test_attach(sagemaker_session, tf_version):
410409
assert estimator.checkpoint_path == 's3://other/1508872349'
411410

412411

412+
def test_attach_new_repo_name(sagemaker_session, tf_version):
413+
training_image = '520713654638.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow:{}-cpu-py2'.format(tf_version)
414+
rjd = {'AlgorithmSpecification':
415+
{'TrainingInputMode': 'File',
416+
'TrainingImage': training_image},
417+
'HyperParameters':
418+
{'sagemaker_submit_directory': '"s3://some/sourcedir.tar.gz"',
419+
'checkpoint_path': '"s3://other/1508872349"',
420+
'sagemaker_program': '"iris-dnn-classifier.py"',
421+
'sagemaker_enable_cloudwatch_metrics': 'false',
422+
'sagemaker_container_log_level': '"logging.INFO"',
423+
'sagemaker_job_name': '"neo"',
424+
'training_steps': '100',
425+
'evaluation_steps': '10'},
426+
'RoleArn': 'arn:aws:iam::366:role/SageMakerRole',
427+
'ResourceConfig':
428+
{'VolumeSizeInGB': 30,
429+
'InstanceCount': 1,
430+
'InstanceType': 'ml.c4.xlarge'},
431+
'StoppingCondition': {'MaxRuntimeInSeconds': 24 * 60 * 60},
432+
'TrainingJobName': 'neo',
433+
'TrainingJobStatus': 'Completed',
434+
'OutputDataConfig': {'KmsKeyId': '',
435+
'S3OutputPath': 's3://place/output/neo'},
436+
'TrainingJobOutput': {'S3TrainingJobOutput': 's3://here/output.tar.gz'}}
437+
sagemaker_session.sagemaker_client.describe_training_job = Mock(name='describe_training_job', return_value=rjd)
438+
439+
estimator = TensorFlow.attach(training_job_name='neo', sagemaker_session=sagemaker_session)
440+
assert estimator.latest_training_job.job_name == 'neo'
441+
assert estimator.py_version == 'py2'
442+
assert estimator.framework_version == tf_version
443+
assert estimator.role == 'arn:aws:iam::366:role/SageMakerRole'
444+
assert estimator.train_instance_count == 1
445+
assert estimator.train_max_run == 24 * 60 * 60
446+
assert estimator.input_mode == 'File'
447+
assert estimator.training_steps == 100
448+
assert estimator.evaluation_steps == 10
449+
assert estimator.input_mode == 'File'
450+
assert estimator.base_job_name == 'neo'
451+
assert estimator.output_path == 's3://place/output/neo'
452+
assert estimator.output_kms_key == ''
453+
assert estimator.hyperparameters()['training_steps'] == '100'
454+
assert estimator.source_dir == 's3://some/sourcedir.tar.gz'
455+
assert estimator.entry_point == 'iris-dnn-classifier.py'
456+
assert estimator.checkpoint_path == 's3://other/1508872349'
457+
assert estimator.train_image() == training_image
458+
459+
413460
def test_attach_old_container(sagemaker_session):
414461
training_image = '1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-tensorflow-py2-cpu:1.0'
415462
rjd = {'AlgorithmSpecification':

0 commit comments

Comments
 (0)