22
22
IMAGE_URI_CONFIG_DIR = os .path .join (".." , "image_uri_config" )
23
23
24
24
25
- def _get_latest_values (existing_content , scope = None ): # pylint: disable=W0621
26
- """Get the latest "registries", "py_versions", "repository", values
25
+ def _get_latest_values (existing_content , scope = None ):
26
+ """Get the latest "registries", "py_versions" and "repository" values
27
27
28
28
Args:
29
29
existing_content (dict): Dictionary of complete framework image information.
30
30
scope (str): Type of the image, required if the target is DLC
31
31
framework (Default: None).
32
32
"""
33
- if "scope" not in existing_content :
34
- latest_version = list (existing_content ["versions" ].keys ())[- 1 ]
35
- registries = existing_content ["versions" ][latest_version ]["registries" ]
36
- py_versions = existing_content ["versions" ][latest_version ][ # pylint: disable=W0621
37
- "py_versions"
38
- ]
39
- repository = existing_content ["versions" ][latest_version ]["repository" ]
33
+ if scope in existing_content :
34
+ existing_content = existing_content [scope ]
40
35
else :
41
- if scope is None :
36
+ if "versions" not in existing_content :
42
37
raise ValueError (
43
- "Image type ('training', 'inference', 'eia') is required for DLC framework."
38
+ "Invalid image scope: {}. Valid options: {}." .format (
39
+ scope , ", " .join (existing_content .key ())
40
+ )
44
41
)
45
- latest_version = list (existing_content [scope ]["versions" ].keys ())[- 1 ]
46
- registries = existing_content [scope ]["versions" ][latest_version ]["registries" ]
47
- py_versions = existing_content [scope ]["versions" ][latest_version ]["py_versions" ]
48
- repository = existing_content [scope ]["versions" ][latest_version ]["repository" ]
42
+
43
+ latest_version = list (existing_content [scope ]["versions" ].keys ())[- 1 ]
44
+ registries = existing_content [scope ]["versions" ][latest_version ]["registries" ]
45
+ py_versions = existing_content [scope ]["versions" ][latest_version ]["py_versions" ]
46
+ repository = existing_content [scope ]["versions" ][latest_version ]["repository" ]
49
47
50
48
return registries , py_versions , repository
51
49
52
50
53
- def _write_dict_to_json (filename , existing_content ): # pylint: disable=W0621
51
+ def _write_dict_to_json (filename , existing_content ):
54
52
"""Write a Python dictionary to a json file.
55
53
56
54
Args:
@@ -70,7 +68,7 @@ def add_dlc_framework_version(
70
68
py_versions ,
71
69
registries ,
72
70
repository ,
73
- ): # pylint: disable=W0621
71
+ ):
74
72
"""Update DLC framework image uri json file with new version information.
75
73
76
74
Args:
@@ -108,7 +106,7 @@ def add_algo_version(
108
106
registries ,
109
107
repository ,
110
108
tag_prefix ,
111
- ): # pylint: disable=W0621
109
+ ):
112
110
"""Update Algorithm image uri json file with new version information.
113
111
114
112
Args:
@@ -139,7 +137,7 @@ def add_algo_version(
139
137
existing_content ["versions" ][full_version ] = add_version
140
138
141
139
142
- def add_region (existing_content , region , account ): # pylint: disable=W0621
140
+ def add_region (existing_content , region , account ):
143
141
"""Add region account to framework/algorithm registries.
144
142
145
143
Args:
@@ -149,7 +147,7 @@ def add_region(existing_content, region, account): # pylint: disable=W0621
149
147
account (str): Region registry account number.
150
148
"""
151
149
if "scope" not in existing_content :
152
- for scope in existing_content : # pylint: disable=W0621
150
+ for scope in existing_content :
153
151
for version in existing_content [scope ]["versions" ]:
154
152
existing_content [scope ]["versions" ][version ]["registries" ][region ] = account
155
153
else :
@@ -159,7 +157,7 @@ def add_region(existing_content, region, account): # pylint: disable=W0621
159
157
160
158
def update_json (
161
159
existing_content , short_version , full_version , scope , processors , py_versions , tag_prefix ,
162
- ): # pylint: disable=W0621
160
+ ):
163
161
"""Read framework image uri information from json file to a dictionary, update it with new
164
162
framework version information, then write the dictionary back to json file.
165
163
@@ -206,7 +204,8 @@ def update_json(
206
204
)
207
205
208
206
209
- if __name__ == "__main__" :
207
+ def main ():
208
+ """Parse command line arguments, call corresponding methods."""
210
209
parser = argparse .ArgumentParser (description = "Framework upgrade tool." )
211
210
parser .add_argument (
212
211
"--framework" , required = True , help = "Name of the framework (e.g. tensorflow, mxnet, etc.)"
@@ -246,3 +245,7 @@ def update_json(
246
245
247
246
file = os .path .join (IMAGE_URI_CONFIG_DIR , "{}.json" .format (framework ))
248
247
_write_dict_to_json (file , content )
248
+
249
+
250
+ if __name__ == "__main__" :
251
+ main ()
0 commit comments