Skip to content

Commit a6f51b7

Browse files
chuyang-dengChuyang Denglaurenyu
authored
feature: add framework upgrade tool (#1762)
* feature: add framework upgrade tool * add framework upgrade test * add algo image upgrade logic * support for region update * move framework upgrade tool to cli folder * fix typo * fix auto changes * fix auto changes * fix auto changes * correct typos and simplify methods * refactor method and update docstring * simplify and rename method * fix dict KeyError and and unit tests Co-authored-by: Chuyang Deng <[email protected]> Co-authored-by: Lauren Yu <[email protected]>
1 parent abd873e commit a6f51b7

File tree

3 files changed

+501
-0
lines changed

3 files changed

+501
-0
lines changed

src/sagemaker/cli/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying athis file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tools for automating code updates"""
14+
from __future__ import absolute_import
Lines changed: 248 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""A Python script to upgrade framework versions and regions"""
14+
from __future__ import absolute_import
15+
16+
import argparse
17+
import json
18+
import os
19+
20+
from sagemaker.image_uris import config_for_framework
21+
22+
IMAGE_URI_CONFIG_DIR = os.path.join("..", "image_uri_config")
23+
24+
25+
def get_latest_values(existing_content, scope=None):
26+
"""Get the latest "registries", "py_versions" and "repository" values
27+
28+
Args:
29+
existing_content (dict): Dictionary of complete framework image information.
30+
scope (str): Type of the image, required if the target is DLC
31+
framework (Default: None).
32+
"""
33+
if scope in existing_content:
34+
existing_content = existing_content[scope]
35+
else:
36+
if "versions" not in existing_content:
37+
raise ValueError(
38+
"Invalid image scope: {}. Valid options: {}.".format(
39+
scope, ", ".join(existing_content.key())
40+
)
41+
)
42+
43+
latest_version = list(existing_content["versions"].keys())[-1]
44+
registries = existing_content["versions"][latest_version]["registries"]
45+
py_versions = existing_content["versions"][latest_version]["py_versions"]
46+
repository = existing_content["versions"][latest_version]["repository"]
47+
48+
return registries, py_versions, repository
49+
50+
51+
def _write_dict_to_json(filename, existing_content):
52+
"""Write a Python dictionary to a json file.
53+
54+
Args:
55+
filename (str): Name of the target json file.
56+
existing_content (dict): Dictionary to be written to the json file.
57+
"""
58+
with open(filename, "w") as f:
59+
json.dump(existing_content, f, sort_keys=True, indent=4)
60+
61+
62+
def add_dlc_framework_version(
63+
existing_content,
64+
short_version,
65+
full_version,
66+
scope,
67+
processors,
68+
py_versions,
69+
registries,
70+
repository,
71+
):
72+
"""Update DLC framework image uri json file with new version information.
73+
74+
Args:
75+
existing_content (dict): Existing framework image uri information read from
76+
"<framework>.json" file.
77+
framework (str): Framework name (e.g. tensorflow, pytorch, mxnet)
78+
short_version (str): Abbreviated framework version (e.g. 1.0, 1.5)
79+
full_version (str): Complete framework version (e.g. 1.0.0, 1.5.2)
80+
scope (str): Framework image type, it could be "training", "inference"
81+
or "eia"
82+
processors (list): Supported processors (e.g. ["cpu", "gpu"])
83+
py_versions (list): Supported Python versions (e.g. ["py3", "py37"])
84+
registries (dict): Framework image's region to account mapping.
85+
repository (str): Framework image's ECR repository.
86+
"""
87+
for processor in processors:
88+
if processor not in existing_content[scope]["processors"]:
89+
existing_content[scope]["processors"].append(processor)
90+
existing_content[scope]["version_aliases"][short_version] = full_version
91+
92+
new_version = {
93+
"registries": registries,
94+
"repository": repository,
95+
"py_versions": py_versions,
96+
}
97+
existing_content[scope]["versions"][full_version] = new_version
98+
99+
100+
def add_algo_version(
101+
existing_content,
102+
processors,
103+
scopes,
104+
full_version,
105+
py_versions,
106+
registries,
107+
repository,
108+
tag_prefix,
109+
):
110+
"""Update Algorithm image uri json file with new version information.
111+
112+
Args:
113+
existing_content (dict): Existing algorithm image uri information read from
114+
"<algorithm>.json" file.
115+
processors (list): Supported processors (e.g. ["cpu", "gpu"])
116+
scopes (list): Framework image type, it could be "training", "inference
117+
full_version (str): Complete framework version (e.g. 1.0.0, 1.5.2)
118+
py_versions (list): Supported Python versions (e.g. ["py3", "py37"])
119+
registries (dict): Algorithm image's region to account mapping.
120+
repository (str): Algorithm's corresponding repository name.
121+
tag_prefix (str): Algorithm image's tag prefix.
122+
"""
123+
for processor in processors:
124+
if processor not in existing_content["processors"]:
125+
existing_content["processors"].append(processor)
126+
for scope in scopes:
127+
if scope not in existing_content["scope"]:
128+
existing_content["scope"].append(scope)
129+
130+
new_version = {
131+
"py_versions": py_versions,
132+
"registries": registries,
133+
"repository": repository,
134+
}
135+
if tag_prefix:
136+
new_version["tag_prefix"] = tag_prefix
137+
existing_content["versions"][full_version] = new_version
138+
139+
140+
def add_region(existing_content, region, account):
141+
"""Add region account to framework/algorithm registries.
142+
143+
Args:
144+
existing_content (dict): Existing framework/algorithm image uri information read from
145+
json file.
146+
region (str): New region to be added to framework/algorithm registries (e.g. us-west-2).
147+
account (str): Region registry account number.
148+
"""
149+
if "scope" not in existing_content:
150+
for scope in existing_content:
151+
for version in existing_content[scope]["versions"]:
152+
existing_content[scope]["versions"][version]["registries"][region] = account
153+
else:
154+
for version in existing_content["versions"]:
155+
existing_content["versions"][version]["registries"][region] = account
156+
157+
158+
def add_version(
159+
existing_content, short_version, full_version, scope, processors, py_versions, tag_prefix,
160+
):
161+
"""Read framework image uri information from json file to a dictionary, update it with new
162+
framework version information, then write the dictionary back to json file.
163+
164+
Args:
165+
existing_content (dict): Existing framework/algorithm image uri information read from
166+
json file.
167+
short_version (str): Abbreviated framework version (e.g. 1.0, 1.5).
168+
full_version (str): Complete framework version (e.g. 1.0.0, 1.5.2).
169+
scope (str): Framework image type, it could be "training", "inference".
170+
processors (str): Supported processors (e.g. "cpu,gpu").
171+
py_versions (str): Supported Python versions (e.g. "py3,py37").
172+
tag_prefix (str): Algorithm image's tag prefix.
173+
"""
174+
py_versions = py_versions.split(",")
175+
processors = processors.split(",")
176+
latest_registries, latest_py_versions, latest_repository = get_latest_values(
177+
existing_content, scope
178+
)
179+
if not py_versions:
180+
py_versions = latest_py_versions
181+
182+
if scope in existing_content:
183+
add_dlc_framework_version(
184+
existing_content,
185+
short_version,
186+
full_version,
187+
scope,
188+
processors,
189+
py_versions,
190+
latest_registries,
191+
latest_repository,
192+
)
193+
else:
194+
scopes = scope.split(",")
195+
add_algo_version(
196+
existing_content,
197+
processors,
198+
scopes,
199+
full_version,
200+
py_versions,
201+
latest_registries,
202+
latest_repository,
203+
tag_prefix,
204+
)
205+
206+
207+
def main():
208+
"""Parse command line arguments, call corresponding methods."""
209+
parser = argparse.ArgumentParser(description="Framework upgrade tool.")
210+
parser.add_argument(
211+
"--framework", required=True, help="Name of the framework (e.g. tensorflow, mxnet, etc.)"
212+
)
213+
parser.add_argument("--short-version", help="Abbreviated framework version (e.g. 2.0)")
214+
parser.add_argument("--full-version", help="Full framework version (e.g. 2.0.1)")
215+
parser.add_argument("--processors", help="Suppoted processors (e.g. cpu, gpu)")
216+
parser.add_argument("--py-versions", help="Supported Python versions (e.g. py3,py37)")
217+
parser.add_argument("--scope", help="Scope for the Algorithm image (e.g. inference, training)")
218+
parser.add_argument(
219+
"--tag-prefix", help="Tag prefix of the Algorithm image (e.g. ray-0.8.5-torch)"
220+
)
221+
parser.add_argument("--region", help="New region to be added (e.g. us-west-2)")
222+
parser.add_argument("--account", help="Registry account of new region")
223+
224+
args = parser.parse_args()
225+
226+
content = config_for_framework(args.framework)
227+
228+
if args.region or args.account:
229+
if args.region and not args.account or args.account and not args.region:
230+
raise ValueError("--region and --account must be used together to expand region.")
231+
add_region(content, args.region, args.account)
232+
else:
233+
add_version(
234+
content,
235+
args.short_version,
236+
args.full_version,
237+
args.scope,
238+
args.processors,
239+
args.py_versions,
240+
args.tag_prefix,
241+
)
242+
243+
file = os.path.join(IMAGE_URI_CONFIG_DIR, "{}.json".format(args.framework))
244+
_write_dict_to_json(file, content)
245+
246+
247+
if __name__ == "__main__":
248+
main()

0 commit comments

Comments
 (0)