Skip to content

Commit 1d837cb

Browse files
author
Chuyang Deng
committed
refactor method and update docstring
1 parent 6f8f5bd commit 1d837cb

File tree

2 files changed

+26
-23
lines changed

2 files changed

+26
-23
lines changed

src/sagemaker/cli/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Tools for automating code update"""
13+
"""Tools for automating code updates"""
1414
from __future__ import absolute_import

src/sagemaker/cli/framework_upgrade.py

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -22,35 +22,33 @@
2222
IMAGE_URI_CONFIG_DIR = os.path.join("..", "image_uri_config")
2323

2424

25-
def _get_latest_values(existing_content, scope=None): # pylint: disable=W0621
26-
"""Get the latest "registries", "py_versions", "repository", values
25+
def _get_latest_values(existing_content, scope=None):
26+
"""Get the latest "registries", "py_versions" and "repository" values
2727
2828
Args:
2929
existing_content (dict): Dictionary of complete framework image information.
3030
scope (str): Type of the image, required if the target is DLC
3131
framework (Default: None).
3232
"""
33-
if "scope" not in existing_content:
34-
latest_version = list(existing_content["versions"].keys())[-1]
35-
registries = existing_content["versions"][latest_version]["registries"]
36-
py_versions = existing_content["versions"][latest_version][ # pylint: disable=W0621
37-
"py_versions"
38-
]
39-
repository = existing_content["versions"][latest_version]["repository"]
33+
if scope in existing_content:
34+
existing_content = existing_content[scope]
4035
else:
41-
if scope is None:
36+
if "versions" not in existing_content:
4237
raise ValueError(
43-
"Image type ('training', 'inference', 'eia') is required for DLC framework."
38+
"Invalid image scope: {}. Valid options: {}.".format(
39+
scope, ", ".join(existing_content.key())
40+
)
4441
)
45-
latest_version = list(existing_content[scope]["versions"].keys())[-1]
46-
registries = existing_content[scope]["versions"][latest_version]["registries"]
47-
py_versions = existing_content[scope]["versions"][latest_version]["py_versions"]
48-
repository = existing_content[scope]["versions"][latest_version]["repository"]
42+
43+
latest_version = list(existing_content[scope]["versions"].keys())[-1]
44+
registries = existing_content[scope]["versions"][latest_version]["registries"]
45+
py_versions = existing_content[scope]["versions"][latest_version]["py_versions"]
46+
repository = existing_content[scope]["versions"][latest_version]["repository"]
4947

5048
return registries, py_versions, repository
5149

5250

53-
def _write_dict_to_json(filename, existing_content): # pylint: disable=W0621
51+
def _write_dict_to_json(filename, existing_content):
5452
"""Write a Python dictionary to a json file.
5553
5654
Args:
@@ -70,7 +68,7 @@ def add_dlc_framework_version(
7068
py_versions,
7169
registries,
7270
repository,
73-
): # pylint: disable=W0621
71+
):
7472
"""Update DLC framework image uri json file with new version information.
7573
7674
Args:
@@ -108,7 +106,7 @@ def add_algo_version(
108106
registries,
109107
repository,
110108
tag_prefix,
111-
): # pylint: disable=W0621
109+
):
112110
"""Update Algorithm image uri json file with new version information.
113111
114112
Args:
@@ -139,7 +137,7 @@ def add_algo_version(
139137
existing_content["versions"][full_version] = add_version
140138

141139

142-
def add_region(existing_content, region, account): # pylint: disable=W0621
140+
def add_region(existing_content, region, account):
143141
"""Add region account to framework/algorithm registries.
144142
145143
Args:
@@ -149,7 +147,7 @@ def add_region(existing_content, region, account): # pylint: disable=W0621
149147
account (str): Region registry account number.
150148
"""
151149
if "scope" not in existing_content:
152-
for scope in existing_content: # pylint: disable=W0621
150+
for scope in existing_content:
153151
for version in existing_content[scope]["versions"]:
154152
existing_content[scope]["versions"][version]["registries"][region] = account
155153
else:
@@ -159,7 +157,7 @@ def add_region(existing_content, region, account): # pylint: disable=W0621
159157

160158
def update_json(
161159
existing_content, short_version, full_version, scope, processors, py_versions, tag_prefix,
162-
): # pylint: disable=W0621
160+
):
163161
"""Read framework image uri information from json file to a dictionary, update it with new
164162
framework version information, then write the dictionary back to json file.
165163
@@ -206,7 +204,8 @@ def update_json(
206204
)
207205

208206

209-
if __name__ == "__main__":
207+
def main():
208+
"""Parse command line arguments, call corresponding methods."""
210209
parser = argparse.ArgumentParser(description="Framework upgrade tool.")
211210
parser.add_argument(
212211
"--framework", required=True, help="Name of the framework (e.g. tensorflow, mxnet, etc.)"
@@ -246,3 +245,7 @@ def update_json(
246245

247246
file = os.path.join(IMAGE_URI_CONFIG_DIR, "{}.json".format(framework))
248247
_write_dict_to_json(file, content)
248+
249+
250+
if __name__ == "__main__":
251+
main()

0 commit comments

Comments
 (0)