Skip to content

Commit 6f8f5bd

Browse files
author
Chuyang Deng
committed
correct typos and simplify methods
1 parent 02f3684 commit 6f8f5bd

File tree

3 files changed

+46
-80
lines changed

3 files changed

+46
-80
lines changed

src/sagemaker/cli/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Tools to upgrade framework/algorithm versions and regions"""
13+
"""Tools for automating code update"""
1414
from __future__ import absolute_import

src/sagemaker/cli/framework_upgrade.py

Lines changed: 37 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -10,55 +10,42 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""A Python script to upgrade framework versions"""
13+
"""A Python script to upgrade framework versions and regions"""
1414
from __future__ import absolute_import
1515

1616
import argparse
1717
import json
18+
import os
1819

20+
from sagemaker.image_uris import config_for_framework
1921

20-
DLC_FRAMEWORK = {"tensorflow", "mxnet", "pytorch"}
22+
IMAGE_URI_CONFIG_DIR = os.path.join("..", "image_uri_config")
2123

2224

23-
def _read_json_to_dict(filename):
24-
"""Read a json file into a Python dictionary
25-
26-
Args:
27-
filename (str): Name of the json file.
28-
29-
Returns:
30-
dict: A Python Dictionary.
31-
"""
32-
with open(filename, "r") as f:
33-
existing_content = json.load(f)
34-
return existing_content
35-
36-
37-
def _get_latest_values(framework, existing_content, image_type=None): # pylint: disable=W0621
25+
def _get_latest_values(existing_content, scope=None): # pylint: disable=W0621
3826
"""Get the latest "registries", "py_versions", "repository", values
3927
4028
Args:
41-
framework (str): Name of the target framework.
4229
existing_content (dict): Dictionary of complete framework image information.
43-
image_type (str): Type of the image, required if the target is DLC
30+
scope (str): Type of the image, required if the target is DLC
4431
framework (Default: None).
4532
"""
46-
if framework not in DLC_FRAMEWORK:
33+
if "scope" not in existing_content:
4734
latest_version = list(existing_content["versions"].keys())[-1]
4835
registries = existing_content["versions"][latest_version]["registries"]
4936
py_versions = existing_content["versions"][latest_version][ # pylint: disable=W0621
5037
"py_versions"
5138
]
5239
repository = existing_content["versions"][latest_version]["repository"]
5340
else:
54-
if image_type is None:
41+
if scope is None:
5542
raise ValueError(
56-
"Image type ('training', 'inference', 'eia') " "is required for DLC framework."
43+
"Image type ('training', 'inference', 'eia') is required for DLC framework."
5744
)
58-
latest_version = list(existing_content[image_type]["versions"].keys())[-1]
59-
registries = existing_content[image_type]["versions"][latest_version]["registries"]
60-
py_versions = existing_content[image_type]["versions"][latest_version]["py_versions"]
61-
repository = existing_content[image_type]["versions"][latest_version]["repository"]
45+
latest_version = list(existing_content[scope]["versions"].keys())[-1]
46+
registries = existing_content[scope]["versions"][latest_version]["registries"]
47+
py_versions = existing_content[scope]["versions"][latest_version]["py_versions"]
48+
repository = existing_content[scope]["versions"][latest_version]["repository"]
6249

6350
return registries, py_versions, repository
6451

@@ -78,7 +65,7 @@ def add_dlc_framework_version(
7865
existing_content,
7966
short_version,
8067
full_version,
81-
image_type,
68+
scope,
8269
processors,
8370
py_versions,
8471
registries,
@@ -92,24 +79,24 @@ def add_dlc_framework_version(
9279
framework (str): Framework name (e.g. tensorflow, pytorch, mxnet)
9380
short_version (str): Abbreviated framework version (e.g. 1.0, 1.5)
9481
full_version (str): Complete framework version (e.g. 1.0.0, 1.5.2)
95-
image_type (str): Framework image type, it could be "training", "inference"
82+
scope (str): Framework image type, it could be "training", "inference"
9683
or "eia"
9784
processors (list): Supported processors (e.g. ["cpu", "gpu"])
9885
py_versions (list): Supported Python versions (e.g. ["py3", "py37"])
9986
registries (dict): Framework image's region to account mapping.
10087
repository (str): Framework image's ECR repository.
10188
"""
10289
for processor in processors:
103-
if processor not in existing_content[image_type]["processors"]:
104-
existing_content[image_type]["processors"].append(processor)
105-
existing_content[image_type]["version_aliases"][short_version] = full_version
90+
if processor not in existing_content[scope]["processors"]:
91+
existing_content[scope]["processors"].append(processor)
92+
existing_content[scope]["version_aliases"][short_version] = full_version
10693

10794
add_version = {
10895
"registries": registries,
10996
"repository": repository,
11097
"py_versions": py_versions,
11198
}
112-
existing_content[image_type]["versions"][full_version] = add_version
99+
existing_content[scope]["versions"][full_version] = add_version
113100

114101

115102
def add_algo_version(
@@ -152,18 +139,17 @@ def add_algo_version(
152139
existing_content["versions"][full_version] = add_version
153140

154141

155-
def add_region(framework, existing_content, region, account): # pylint: disable=W0621
142+
def add_region(existing_content, region, account): # pylint: disable=W0621
156143
"""Add region account to framework/algorithm registries.
157144
158145
Args:
159-
framework (str): Framework name (e.g. tensorflow, pytorch, mxnet).
160146
existing_content (dict): Existing framework/algorithm image uri information read from
161147
json file.
162148
region (str): New region to be added to framework/algorithm registries (e.g. us-west-2).
163149
account (str): Region registry account number.
164150
"""
165-
if framework in DLC_FRAMEWORK:
166-
for scope in existing_content:
151+
if "scope" not in existing_content:
152+
for scope in existing_content: # pylint: disable=W0621
167153
for version in existing_content[scope]["versions"]:
168154
existing_content[scope]["versions"][version]["registries"][region] = account
169155
else:
@@ -172,52 +158,42 @@ def add_region(framework, existing_content, region, account): # pylint: disable
172158

173159

174160
def update_json(
175-
framework,
176-
existing_content,
177-
short_version,
178-
full_version,
179-
image_type,
180-
scopes,
181-
processors,
182-
py_versions,
183-
tag_prefix,
161+
existing_content, short_version, full_version, scope, processors, py_versions, tag_prefix,
184162
): # pylint: disable=W0621
185163
"""Read framework image uri information from json file to a dictionary, update it with new
186164
framework version information, then write the dictionary back to json file.
187165
188166
Args:
189-
framework (str): Framework name (e.g. tensorflow, pytorch, mxnet).
190167
existing_content (dict): Existing framework/algorithm image uri information read from
191168
json file.
192169
short_version (str): Abbreviated framework version (e.g. 1.0, 1.5).
193170
full_version (str): Complete framework version (e.g. 1.0.0, 1.5.2).
194-
image_type (str): Framework image type, it could be "training", "inference" or "eia".
195-
scopes (str): Framework image type, it could be "training", "inference".
171+
scope (str): Framework image type, it could be "training", "inference".
196172
processors (str): Supported processors (e.g. "cpu,gpu").
197173
py_versions (str): Supported Python versions (e.g. "py3,py37").
198174
tag_prefix (str): Algorithm image's tag prefix.
199175
"""
200176
py_versions = py_versions.split(",")
201177
processors = processors.split(",")
202178
latest_registries, latest_py_versions, latest_repository = _get_latest_values(
203-
framework, existing_content, image_type
179+
existing_content, scope
204180
)
205181
if not py_versions:
206182
py_versions = latest_py_versions
207183

208-
if framework in DLC_FRAMEWORK:
184+
if scope in existing_content:
209185
add_dlc_framework_version(
210186
existing_content,
211187
short_version,
212188
full_version,
213-
image_type,
189+
scope,
214190
processors,
215191
py_versions,
216192
latest_registries,
217193
latest_repository,
218194
)
219195
else:
220-
scopes = scopes.split(",")
196+
scopes = scope.split(",")
221197
add_algo_version(
222198
existing_content,
223199
processors,
@@ -232,15 +208,14 @@ def update_json(
232208

233209
if __name__ == "__main__":
234210
parser = argparse.ArgumentParser(description="Framework upgrade tool.")
235-
parser.add_argument("--framework", help="Name of the framework (e.g. tensorflow, mxnet, etc.)")
211+
parser.add_argument(
212+
"--framework", required=True, help="Name of the framework (e.g. tensorflow, mxnet, etc.)"
213+
)
236214
parser.add_argument("--short-version", help="Abbreviated framework version (e.g. 2.0)")
237215
parser.add_argument("--full-version", help="Full framework version (e.g. 2.0.1)")
238-
parser.add_argument("--image-type", help="Framework image type (e.g. training, inference, eia)")
239216
parser.add_argument("--processors", help="Suppoted processors (e.g. cpu, gpu)")
240217
parser.add_argument("--py-versions", help="Supported Python versions (e.g. py3,py37)")
241-
parser.add_argument(
242-
"--scopes", help="Scopes for the Algorithm image (e.g. inference, training)"
243-
)
218+
parser.add_argument("--scope", help="Scope for the Algorithm image (e.g. inference, training)")
244219
parser.add_argument(
245220
"--tag-prefix", help="Tag prefix of the Algorithm image (e.g. ray-0.8.5-torch)"
246221
)
@@ -251,34 +226,23 @@ def update_json(
251226
framework = args.framework
252227
short_version = args.short_version
253228
full_version = args.full_version
254-
image_type = args.image_type
255229
processors = args.processors
256230
py_versions = args.py_versions
257-
scopes = args.scopes
231+
scope = args.scope
258232
tag_prefix = args.tag_prefix
259233
region = args.region
260234
account = args.account
261235

262-
if not framework:
263-
raise ValueError("Please specify a framework or algorithm name to upgrade")
264-
file = "../image_uri_config/{}.json".format(framework)
265-
content = _read_json_to_dict(file)
236+
content = config_for_framework(framework)
266237

267238
if region or account:
268239
if region and not account or account and not region:
269240
raise ValueError("--region and --account must be used together to expand region.")
270-
add_region(framework, content, region, account)
241+
add_region(content, region, account)
271242
else:
272243
update_json(
273-
framework,
274-
content,
275-
short_version,
276-
full_version,
277-
image_type,
278-
scopes,
279-
processors,
280-
py_versions,
281-
tag_prefix,
244+
content, short_version, full_version, scope, processors, py_versions, tag_prefix,
282245
)
283246

247+
file = os.path.join(IMAGE_URI_CONFIG_DIR, "{}.json".format(framework))
284248
_write_dict_to_json(file, content)

tests/unit/sagemaker/cli/test_framework_version_upgrade_tool.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,20 @@ def dlc_expected_region_content():
141141

142142
@pytest.fixture
143143
def algo_region_content():
144-
region_info = {"versions": {"1.0": {"registries": {"us-west-2": "123456789012"}}}}
144+
region_info = {
145+
"scope": ["training"],
146+
"versions": {"1.0": {"registries": {"us-west-2": "123456789012"}}},
147+
}
145148
return region_info
146149

147150

148151
@pytest.fixture
149152
def algo_expected_region_content():
150153
region_info = {
154+
"scope": ["training"],
151155
"versions": {
152156
"1.0": {"registries": {"us-west-2": "123456789012", "us-east-1": "987654321098"}}
153-
}
157+
},
154158
}
155159
return region_info
156160

@@ -205,14 +209,12 @@ def test_add_coach_tensorflow(algo_content, algo_expected_content):
205209
def test_dlc_add_region(dlc_region_content, dlc_expected_region_content):
206210
region = "us-east-1"
207211
account = "987654321098"
208-
framework = "tensorflow"
209-
framework_upgrade.add_region(framework, dlc_region_content, region, account)
212+
framework_upgrade.add_region(dlc_region_content, region, account)
210213
assert dlc_region_content == dlc_expected_region_content
211214

212215

213216
def test_algo_add_region(algo_region_content, algo_expected_region_content):
214217
region = "us-east-1"
215218
account = "987654321098"
216-
framework = "blazingtext"
217-
framework_upgrade.add_region(framework, algo_region_content, region, account)
219+
framework_upgrade.add_region(algo_region_content, region, account)
218220
assert algo_region_content == algo_expected_region_content

0 commit comments

Comments
 (0)