10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
- """Placeholder docstring """
13
+ """Utility methods used by framework classes """
14
14
from __future__ import absolute_import
15
15
16
- from collections import namedtuple
17
-
18
16
import os
19
17
import re
20
18
import shutil
21
19
import tempfile
20
+ from collections import namedtuple
22
21
23
22
import sagemaker .utils
24
- from sagemaker .utils import get_ecr_image_uri_prefix , ECR_URI_PATTERN
25
23
from sagemaker import s3
24
+ from sagemaker .utils import get_ecr_image_uri_prefix , ECR_URI_PATTERN
26
25
27
26
_TAR_SOURCE_FILENAME = "source.tar.gz"
28
27
69
68
"tensorflow-serving-eia" : "tensorflow-inference-eia" ,
70
69
"mxnet" : "mxnet-training" ,
71
70
"mxnet-serving" : "mxnet-inference" ,
71
+ "mxnet-serving-eia" : "mxnet-inference-eia" ,
72
72
"pytorch" : "pytorch-training" ,
73
73
"pytorch-serving" : "pytorch-inference" ,
74
- "mxnet-serving-eia" : "mxnet-inference-eia" ,
75
74
}
76
75
77
76
MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
78
- "tensorflow-scriptmode" : [1 , 13 , 1 ],
77
+ "tensorflow-scriptmode" : { "py3" : [1 , 13 , 1 ], "py2" : [ 1 , 14 , 0 ]} ,
79
78
"tensorflow-serving" : [1 , 13 , 0 ],
80
79
"tensorflow-serving-eia" : [1 , 14 , 0 ],
81
- "mxnet" : [1 , 4 , 1 ],
82
- "mxnet-serving" : [1 , 4 , 1 ],
80
+ "mxnet" : {"py3" : [1 , 4 , 1 ], "py2" : [1 , 6 , 0 ]},
81
+ "mxnet-serving" : {"py3" : [1 , 4 , 1 ], "py2" : [1 , 6 , 0 ]},
82
+ "mxnet-serving-eia" : [1 , 4 , 1 ],
83
83
"pytorch" : [1 , 2 , 0 ],
84
84
"pytorch-serving" : [1 , 2 , 0 ],
85
- "mxnet-serving-eia" : [1 , 4 , 1 ],
86
85
}
87
86
88
87
89
88
def is_version_equal_or_higher (lowest_version , framework_version ):
90
89
"""Determine whether the ``framework_version`` is equal to or higher than
91
90
``lowest_version``
91
+
92
92
Args:
93
93
lowest_version (List[int]): lowest version represented in an integer
94
94
list
95
95
framework_version (str): framework version string
96
+
96
97
Returns:
97
- bool: Whether or not framework_version is equal to or higher than
98
- lowest_version
98
+ bool: Whether or not `` framework_version`` is equal to or higher than
99
+ `` lowest_version``
99
100
"""
100
101
version_list = [int (s ) for s in framework_version .split ("." )]
101
102
return version_list >= lowest_version [0 : len (version_list )]
102
103
103
104
104
- def _is_merged_versions (framework , framework_version ):
105
- """
105
+ def _is_dlc_version (framework , framework_version , py_version ):
106
+ """Return if the framework's version uses the corresponding DLC image.
107
+
106
108
Args:
107
- framework:
108
- framework_version:
109
+ framework (str): The framework name, e.g. "tensorflow-scriptmode"
110
+ framework_version (str): The framework version
111
+ py_version (str): The Python version, e.g. "py3"
112
+
113
+ Returns:
114
+ bool: Whether or not the framework's version uses the DLC image.
109
115
"""
110
116
lowest_version_list = MERGED_FRAMEWORKS_LOWEST_VERSIONS .get (framework )
117
+ if isinstance (lowest_version_list , dict ):
118
+ lowest_version_list = lowest_version_list [py_version ]
119
+
111
120
if lowest_version_list :
112
121
return is_version_equal_or_higher (lowest_version_list , framework_version )
113
122
return False
114
123
115
124
116
- def _using_merged_images (region , framework , py_version , framework_version ):
117
- """
118
- Args:
119
- region:
120
- framework:
121
- py_version:
122
- accelerator_type:
123
- framework_version:
124
- """
125
- is_gov_region = region in VALID_ACCOUNTS_BY_REGION
126
- not_py2 = py_version == "py3" or py_version is None
127
- is_merged_versions = _is_merged_versions (framework , framework_version )
128
-
129
- return (
130
- ((not is_gov_region ) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION )
131
- and is_merged_versions
132
- # TODO: should be not mxnet-1.14.1-py2 instead?
133
- and (
134
- not_py2
135
- or _is_tf_14_or_later (framework , framework_version )
136
- or _is_pt_12_or_later (framework , framework_version )
137
- or _is_mxnet_16_or_later (framework , framework_version )
138
- )
139
- )
140
-
125
+ def _use_dlc_image (region , framework , py_version , framework_version ):
126
+ """Return if the DLC image should be used for the given framework,
127
+ framework version, Python version, and region.
141
128
142
- def _is_tf_14_or_later (framework , framework_version ):
143
- """
144
129
Args:
145
- framework:
146
- framework_version:
147
- """
148
- # Asimov team now owns Tensorflow 1.14.0 py2 and py3
149
- asimov_lowest_tf_py2 = [1 , 14 , 0 ]
150
- version = [int (s ) for s in framework_version .split ("." )]
151
- return (
152
- framework == "tensorflow-scriptmode" and version >= asimov_lowest_tf_py2 [0 : len (version )]
153
- )
154
-
130
+ region (str): The AWS region.
131
+ framework (str): The framework name, e.g. "tensorflow-scriptmode".
132
+ py_version (str): The Python version, e.g. "py3".
133
+ framework_version (str): The framework version.
155
134
156
- def _is_pt_12_or_later (framework , framework_version ):
157
- """
158
- Args:
159
- framework: Name of the frameowork
160
- framework_version: framework version
135
+ Returns:
136
+ bool: Whether or not to use the corresponding DLC image.
161
137
"""
162
- # Asimov team now owns PyTorch 1.2.0 py2 and py3
163
- asimov_lowest_pt = [1 , 2 , 0 ]
164
- version = [int (s ) for s in framework_version .split ("." )]
165
- is_pytorch = framework in ("pytorch" , "pytorch-serving" )
166
- return is_pytorch and version >= asimov_lowest_pt [0 : len (version )]
138
+ is_gov_region = region in VALID_ACCOUNTS_BY_REGION
139
+ is_dlc_version = _is_dlc_version (framework , framework_version , py_version )
167
140
168
-
169
- def _is_mxnet_16_or_later (framework , framework_version ):
170
- """
171
- Args:
172
- framework: Name of the frameowork
173
- framework_version: framework version
174
- """
175
- # Asimov team now owns MXNet 1.6.0 py2 and py3
176
- asimov_lowest_pt = [1 , 6 , 0 ]
177
- version = [int (s ) for s in framework_version .split ("." )]
178
- is_mxnet = framework in ("mxnet" , "mxnet-serving" )
179
- return is_mxnet and version >= asimov_lowest_pt [0 : len (version )]
141
+ return ((not is_gov_region ) or region in ASIMOV_VALID_ACCOUNTS_BY_REGION ) and is_dlc_version
180
142
181
143
182
144
def _registry_id (region , framework , py_version , account , framework_version ):
183
- """
145
+ """Return the Amazon ECR registry number (or AWS account ID) for
146
+ the given framework, framework version, Python version, and region.
147
+
184
148
Args:
185
- region:
186
- framework:
187
- py_version:
188
- account:
189
- accelerator_type:
190
- framework_version:
149
+ region (str): The AWS region.
150
+ framework (str): The framework name, e.g. "tensorflow-scriptmode".
151
+ py_version (str): The Python version, e.g. "py3".
152
+ account (str): The AWS account ID to use as a default.
153
+ framework_version (str): The framework version.
154
+
155
+ Returns:
156
+ str: The appropriate Amazon ECR registry number. If there is no
157
+ specific one for the framework, framework version, Python version,
158
+ and region, then ``account`` is returned.
191
159
"""
192
- if _using_merged_images (region , framework , py_version , framework_version ):
160
+ if _use_dlc_image (region , framework , py_version , framework_version ):
193
161
if region in ASIMOV_OPT_IN_ACCOUNTS_BY_REGION :
194
162
return ASIMOV_OPT_IN_ACCOUNTS_BY_REGION .get (region )
195
163
if region in ASIMOV_VALID_ACCOUNTS_BY_REGION :
@@ -211,6 +179,7 @@ def create_image_uri(
211
179
optimized_families = None ,
212
180
):
213
181
"""Return the ECR URI of an image.
182
+
214
183
Args:
215
184
region (str): AWS region where the image is uploaded.
216
185
framework (str): framework used by the image.
@@ -225,6 +194,7 @@ def create_image_uri(
225
194
accelerator_type (str): SageMaker Elastic Inference accelerator type.
226
195
optimized_families (str): Instance families for which there exist
227
196
specific optimized images.
197
+
228
198
Returns:
229
199
str: The appropriate image URI based on the given parameters.
230
200
"""
@@ -240,7 +210,7 @@ def create_image_uri(
240
210
):
241
211
framework += "-eia"
242
212
243
- # Handle Account Number for Gov Cloud and frameworks with DLC merged images
213
+ # Handle account number for specific cases (e.g. GovCloud, opt-in regions, DLC images etc.)
244
214
if account is None :
245
215
account = _registry_id (
246
216
region = region ,
@@ -271,18 +241,19 @@ def create_image_uri(
271
241
else :
272
242
device_type = "cpu"
273
243
274
- using_merged_images = _using_merged_images (region , framework , py_version , framework_version )
244
+ use_dlc_image = _use_dlc_image (region , framework , py_version , framework_version )
275
245
276
- if not py_version or (using_merged_images and framework == "tensorflow-serving-eia" ):
246
+ if not py_version or (use_dlc_image and framework == "tensorflow-serving-eia" ):
277
247
tag = "{}-{}" .format (framework_version , device_type )
278
248
else :
279
249
tag = "{}-{}-{}" .format (framework_version , device_type , py_version )
280
250
281
- if using_merged_images :
282
- return "{}/{}:{}" .format (
283
- get_ecr_image_uri_prefix (account , region ), MERGED_FRAMEWORKS_REPO_MAP [framework ], tag
284
- )
285
- return "{}/sagemaker-{}:{}" .format (get_ecr_image_uri_prefix (account , region ), framework , tag )
251
+ if use_dlc_image :
252
+ ecr_repo = MERGED_FRAMEWORKS_REPO_MAP [framework ]
253
+ else :
254
+ ecr_repo = "sagemaker-{}" .format (framework )
255
+
256
+ return "{}/{}:{}" .format (get_ecr_image_uri_prefix (account , region ), ecr_repo , tag )
286
257
287
258
288
259
def _accelerator_type_valid_for_framework (
0 commit comments