10
10
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11
11
# ANY KIND, either express or implied. See the License for the specific
12
12
# language governing permissions and limitations under the License.
13
- """A Python script to upgrade framework versions"""
13
+ """A Python script to upgrade framework versions and regions """
14
14
from __future__ import absolute_import
15
15
16
16
import argparse
17
17
import json
18
+ import os
18
19
20
+ from sagemaker .image_uris import config_for_framework
19
21
20
- DLC_FRAMEWORK = { "tensorflow " , "mxnet" , "pytorch" }
22
+ IMAGE_URI_CONFIG_DIR = os . path . join ( ".. " , "image_uri_config" )
21
23
22
24
23
- def _read_json_to_dict (filename ):
24
- """Read a json file into a Python dictionary
25
-
26
- Args:
27
- filename (str): Name of the json file.
28
-
29
- Returns:
30
- dict: A Python Dictionary.
31
- """
32
- with open (filename , "r" ) as f :
33
- existing_content = json .load (f )
34
- return existing_content
35
-
36
-
37
- def _get_latest_values (framework , existing_content , image_type = None ): # pylint: disable=W0621
25
+ def _get_latest_values (existing_content , scope = None ): # pylint: disable=W0621
38
26
"""Get the latest "registries", "py_versions", "repository", values
39
27
40
28
Args:
41
- framework (str): Name of the target framework.
42
29
existing_content (dict): Dictionary of complete framework image information.
43
- image_type (str): Type of the image, required if the target is DLC
30
+ scope (str): Type of the image, required if the target is DLC
44
31
framework (Default: None).
45
32
"""
46
- if framework not in DLC_FRAMEWORK :
33
+ if "scope" not in existing_content :
47
34
latest_version = list (existing_content ["versions" ].keys ())[- 1 ]
48
35
registries = existing_content ["versions" ][latest_version ]["registries" ]
49
36
py_versions = existing_content ["versions" ][latest_version ][ # pylint: disable=W0621
50
37
"py_versions"
51
38
]
52
39
repository = existing_content ["versions" ][latest_version ]["repository" ]
53
40
else :
54
- if image_type is None :
41
+ if scope is None :
55
42
raise ValueError (
56
- "Image type ('training', 'inference', 'eia') " " is required for DLC framework."
43
+ "Image type ('training', 'inference', 'eia') is required for DLC framework."
57
44
)
58
- latest_version = list (existing_content [image_type ]["versions" ].keys ())[- 1 ]
59
- registries = existing_content [image_type ]["versions" ][latest_version ]["registries" ]
60
- py_versions = existing_content [image_type ]["versions" ][latest_version ]["py_versions" ]
61
- repository = existing_content [image_type ]["versions" ][latest_version ]["repository" ]
45
+ latest_version = list (existing_content [scope ]["versions" ].keys ())[- 1 ]
46
+ registries = existing_content [scope ]["versions" ][latest_version ]["registries" ]
47
+ py_versions = existing_content [scope ]["versions" ][latest_version ]["py_versions" ]
48
+ repository = existing_content [scope ]["versions" ][latest_version ]["repository" ]
62
49
63
50
return registries , py_versions , repository
64
51
@@ -78,7 +65,7 @@ def add_dlc_framework_version(
78
65
existing_content ,
79
66
short_version ,
80
67
full_version ,
81
- image_type ,
68
+ scope ,
82
69
processors ,
83
70
py_versions ,
84
71
registries ,
@@ -92,24 +79,24 @@ def add_dlc_framework_version(
92
79
framework (str): Framework name (e.g. tensorflow, pytorch, mxnet)
93
80
short_version (str): Abbreviated framework version (e.g. 1.0, 1.5)
94
81
full_version (str): Complete framework version (e.g. 1.0.0, 1.5.2)
95
- image_type (str): Framework image type, it could be "training", "inference"
82
+ scope (str): Framework image type, it could be "training", "inference"
96
83
or "eia"
97
84
processors (list): Supported processors (e.g. ["cpu", "gpu"])
98
85
py_versions (list): Supported Python versions (e.g. ["py3", "py37"])
99
86
registries (dict): Framework image's region to account mapping.
100
87
repository (str): Framework image's ECR repository.
101
88
"""
102
89
for processor in processors :
103
- if processor not in existing_content [image_type ]["processors" ]:
104
- existing_content [image_type ]["processors" ].append (processor )
105
- existing_content [image_type ]["version_aliases" ][short_version ] = full_version
90
+ if processor not in existing_content [scope ]["processors" ]:
91
+ existing_content [scope ]["processors" ].append (processor )
92
+ existing_content [scope ]["version_aliases" ][short_version ] = full_version
106
93
107
94
add_version = {
108
95
"registries" : registries ,
109
96
"repository" : repository ,
110
97
"py_versions" : py_versions ,
111
98
}
112
- existing_content [image_type ]["versions" ][full_version ] = add_version
99
+ existing_content [scope ]["versions" ][full_version ] = add_version
113
100
114
101
115
102
def add_algo_version (
@@ -152,18 +139,17 @@ def add_algo_version(
152
139
existing_content ["versions" ][full_version ] = add_version
153
140
154
141
155
- def add_region (framework , existing_content , region , account ): # pylint: disable=W0621
142
+ def add_region (existing_content , region , account ): # pylint: disable=W0621
156
143
"""Add region account to framework/algorithm registries.
157
144
158
145
Args:
159
- framework (str): Framework name (e.g. tensorflow, pytorch, mxnet).
160
146
existing_content (dict): Existing framework/algorithm image uri information read from
161
147
json file.
162
148
region (str): New region to be added to framework/algorithm registries (e.g. us-west-2).
163
149
account (str): Region registry account number.
164
150
"""
165
- if framework in DLC_FRAMEWORK :
166
- for scope in existing_content :
151
+ if "scope" not in existing_content :
152
+ for scope in existing_content : # pylint: disable=W0621
167
153
for version in existing_content [scope ]["versions" ]:
168
154
existing_content [scope ]["versions" ][version ]["registries" ][region ] = account
169
155
else :
@@ -172,52 +158,42 @@ def add_region(framework, existing_content, region, account): # pylint: disable
172
158
173
159
174
160
def update_json (
175
- framework ,
176
- existing_content ,
177
- short_version ,
178
- full_version ,
179
- image_type ,
180
- scopes ,
181
- processors ,
182
- py_versions ,
183
- tag_prefix ,
161
+ existing_content , short_version , full_version , scope , processors , py_versions , tag_prefix ,
184
162
): # pylint: disable=W0621
185
163
"""Read framework image uri information from json file to a dictionary, update it with new
186
164
framework version information, then write the dictionary back to json file.
187
165
188
166
Args:
189
- framework (str): Framework name (e.g. tensorflow, pytorch, mxnet).
190
167
existing_content (dict): Existing framework/algorithm image uri information read from
191
168
json file.
192
169
short_version (str): Abbreviated framework version (e.g. 1.0, 1.5).
193
170
full_version (str): Complete framework version (e.g. 1.0.0, 1.5.2).
194
- image_type (str): Framework image type, it could be "training", "inference" or "eia".
195
- scopes (str): Framework image type, it could be "training", "inference".
171
+ scope (str): Framework image type, it could be "training", "inference".
196
172
processors (str): Supported processors (e.g. "cpu,gpu").
197
173
py_versions (str): Supported Python versions (e.g. "py3,py37").
198
174
tag_prefix (str): Algorithm image's tag prefix.
199
175
"""
200
176
py_versions = py_versions .split ("," )
201
177
processors = processors .split ("," )
202
178
latest_registries , latest_py_versions , latest_repository = _get_latest_values (
203
- framework , existing_content , image_type
179
+ existing_content , scope
204
180
)
205
181
if not py_versions :
206
182
py_versions = latest_py_versions
207
183
208
- if framework in DLC_FRAMEWORK :
184
+ if scope in existing_content :
209
185
add_dlc_framework_version (
210
186
existing_content ,
211
187
short_version ,
212
188
full_version ,
213
- image_type ,
189
+ scope ,
214
190
processors ,
215
191
py_versions ,
216
192
latest_registries ,
217
193
latest_repository ,
218
194
)
219
195
else :
220
- scopes = scopes .split ("," )
196
+ scopes = scope .split ("," )
221
197
add_algo_version (
222
198
existing_content ,
223
199
processors ,
@@ -232,15 +208,14 @@ def update_json(
232
208
233
209
if __name__ == "__main__" :
234
210
parser = argparse .ArgumentParser (description = "Framework upgrade tool." )
235
- parser .add_argument ("--framework" , help = "Name of the framework (e.g. tensorflow, mxnet, etc.)" )
211
+ parser .add_argument (
212
+ "--framework" , required = True , help = "Name of the framework (e.g. tensorflow, mxnet, etc.)"
213
+ )
236
214
parser .add_argument ("--short-version" , help = "Abbreviated framework version (e.g. 2.0)" )
237
215
parser .add_argument ("--full-version" , help = "Full framework version (e.g. 2.0.1)" )
238
- parser .add_argument ("--image-type" , help = "Framework image type (e.g. training, inference, eia)" )
239
216
parser .add_argument ("--processors" , help = "Suppoted processors (e.g. cpu, gpu)" )
240
217
parser .add_argument ("--py-versions" , help = "Supported Python versions (e.g. py3,py37)" )
241
- parser .add_argument (
242
- "--scopes" , help = "Scopes for the Algorithm image (e.g. inference, training)"
243
- )
218
+ parser .add_argument ("--scope" , help = "Scope for the Algorithm image (e.g. inference, training)" )
244
219
parser .add_argument (
245
220
"--tag-prefix" , help = "Tag prefix of the Algorithm image (e.g. ray-0.8.5-torch)"
246
221
)
@@ -251,34 +226,23 @@ def update_json(
251
226
framework = args .framework
252
227
short_version = args .short_version
253
228
full_version = args .full_version
254
- image_type = args .image_type
255
229
processors = args .processors
256
230
py_versions = args .py_versions
257
- scopes = args .scopes
231
+ scope = args .scope
258
232
tag_prefix = args .tag_prefix
259
233
region = args .region
260
234
account = args .account
261
235
262
- if not framework :
263
- raise ValueError ("Please specify a framework or algorithm name to upgrade" )
264
- file = "../image_uri_config/{}.json" .format (framework )
265
- content = _read_json_to_dict (file )
236
+ content = config_for_framework (framework )
266
237
267
238
if region or account :
268
239
if region and not account or account and not region :
269
240
raise ValueError ("--region and --account must be used together to expand region." )
270
- add_region (framework , content , region , account )
241
+ add_region (content , region , account )
271
242
else :
272
243
update_json (
273
- framework ,
274
- content ,
275
- short_version ,
276
- full_version ,
277
- image_type ,
278
- scopes ,
279
- processors ,
280
- py_versions ,
281
- tag_prefix ,
244
+ content , short_version , full_version , scope , processors , py_versions , tag_prefix ,
282
245
)
283
246
247
+ file = os .path .join (IMAGE_URI_CONFIG_DIR , "{}.json" .format (framework ))
284
248
_write_dict_to_json (file , content )
0 commit comments