13
13
from __future__ import absolute_import
14
14
15
15
import inspect
16
- from mock import Mock , patch
17
16
import os
18
- from sagemaker .fw_utils import create_image_uri , framework_name_from_image , framework_version_from_tag , \
19
- model_code_key_prefix
20
- from sagemaker .fw_utils import tar_and_upload_dir , parse_s3_url , UploadedCode , validate_source_dir
17
+
21
18
import pytest
19
+ from mock import Mock , patch
22
20
21
+ from sagemaker .fw_utils import create_image_uri , framework_name_from_image , \
22
+ framework_version_from_tag , \
23
+ model_code_key_prefix
24
+ from sagemaker .fw_utils import tar_and_upload_dir , parse_s3_url , UploadedCode , validate_source_dir
23
25
from sagemaker .utils import name_from_image
24
26
25
27
DATA_DIR = 'data_dir'
33
35
@pytest .fixture ()
34
36
def sagemaker_session ():
35
37
boto_mock = Mock (name = 'boto_session' , region_name = REGION )
36
- ims = Mock (name = 'sagemaker_session' , boto_session = boto_mock )
37
- ims .default_bucket = Mock (name = 'default_bucket' , return_value = BUCKET_NAME )
38
- ims .expand_role = Mock (name = "expand_role" , return_value = ROLE )
39
- ims .sagemaker_client .describe_training_job = Mock ( return_value = { 'ModelArtifacts' :
40
- {'S3ModelArtifacts' : 's3://m/m.tar.gz' }})
41
- return ims
38
+ session_mock = Mock (name = 'sagemaker_session' , boto_session = boto_mock )
39
+ session_mock .default_bucket = Mock (name = 'default_bucket' , return_value = BUCKET_NAME )
40
+ session_mock .expand_role = Mock (name = "expand_role" , return_value = ROLE )
41
+ session_mock .sagemaker_client .describe_training_job = \
42
+ Mock ( return_value = { 'ModelArtifacts' : {'S3ModelArtifacts' : 's3://m/m.tar.gz' }})
43
+ return session_mock
42
44
43
45
44
46
def test_create_image_uri_cpu ():
@@ -49,6 +51,16 @@ def test_create_image_uri_cpu():
49
51
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu-py2'
50
52
51
53
54
+ def test_create_image_uri_no_python ():
55
+ image_uri = create_image_uri ('mars-south-3' , 'mlfw' , 'ml.c4.large' , '1.0rc' , account = '23' )
56
+ assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-cpu'
57
+
58
+
59
+ def test_create_image_uri_bad_python ():
60
+ with pytest .raises (ValueError ):
61
+ create_image_uri ('mars-south-3' , 'mlfw' , 'ml.c4.large' , '1.0rc' , 'py0' )
62
+
63
+
52
64
def test_create_image_uri_gpu ():
53
65
image_uri = create_image_uri ('mars-south-3' , 'mlfw' , 'ml.p3.2xlarge' , '1.0rc' , 'py3' , '23' )
54
66
assert image_uri == '23.dkr.ecr.mars-south-3.amazonaws.com/sagemaker-mlfw:1.0rc-gpu-py3'
@@ -127,7 +139,8 @@ def test_tar_and_upload_dir_not_s3(sagemaker_session):
127
139
script = os .path .basename (__file__ )
128
140
directory = os .path .dirname (os .path .abspath (inspect .getfile (inspect .currentframe ())))
129
141
result = tar_and_upload_dir (sagemaker_session , bucket , s3_key_prefix , script , directory )
130
- assert result == UploadedCode ('s3://{}/{}/sourcedir.tar.gz' .format (bucket , s3_key_prefix ), script )
142
+ assert result == UploadedCode ('s3://{}/{}/sourcedir.tar.gz' .format (bucket , s3_key_prefix ),
143
+ script )
131
144
132
145
133
146
def test_framework_name_from_image_mxnet ():
@@ -149,21 +162,24 @@ def test_legacy_name_from_framework_image():
149
162
150
163
151
164
def test_legacy_name_from_wrong_framework ():
152
- framework , py_ver , tag = framework_name_from_image ('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1' )
165
+ framework , py_ver , tag = framework_name_from_image (
166
+ '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py2-gpu:1' )
153
167
assert framework is None
154
168
assert py_ver is None
155
169
assert tag is None
156
170
157
171
158
172
def test_legacy_name_from_wrong_python ():
159
- framework , py_ver , tag = framework_name_from_image ('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1' )
173
+ framework , py_ver , tag = framework_name_from_image (
174
+ '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1' )
160
175
assert framework is None
161
176
assert py_ver is None
162
177
assert tag is None
163
178
164
179
165
180
def test_legacy_name_from_wrong_device ():
166
- framework , py_ver , tag = framework_name_from_image ('123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1' )
181
+ framework , py_ver , tag = framework_name_from_image (
182
+ '123.dkr.ecr.us-west-2.amazonaws.com/sagemaker-myown-py4-gpu:1' )
167
183
assert framework is None
168
184
assert py_ver is None
169
185
assert tag is None
0 commit comments