|
13 | 13 | import logging
|
14 | 14 | import os
|
15 | 15 | import pickle
|
| 16 | +import tempfile |
16 | 17 | import time
|
17 | 18 | import zipfile
|
18 | 19 | from copy import deepcopy
|
@@ -62,43 +63,110 @@ def get_bundle_config(bundle_path, config_names):
|
62 | 63 | Gets the configuration parser from the specified Torchscript bundle file path.
|
63 | 64 | """
|
64 | 65 |
|
65 |
| - def _read_from_archive(archive, root_name: str, relative_path: str, path_list: List[str]): |
66 |
| - """A helper function for reading a file in an zip archive. |
| 66 | + bundle_suffixes = (".json", ".yaml", "yml") # The only supported file ext(s) |
| 67 | + config_folder = "extra" |
67 | 68 |
|
68 |
| - Tries to read with the full path of # a archive file, if error, then find the relative |
69 |
| - path and then read the file. |
def _read_from_archive(archive, root_name: str, config_name: str, do_search=True):
    """Read the raw content of a named config from the bundle zip archive.

    The config is first looked up at its expected location,
    ``<root_name>/<config_folder>/<config_name><suffix>``, trying each supported suffix
    in turn. If nothing was read and ``do_search`` is true, the archive name list is
    searched for the first entry whose name contains ``<config_name><suffix>``
    (case-insensitive), again in suffix order.

    ``bundle_suffixes`` and ``config_folder`` are taken from the enclosing scope.

    Args:
        archive: The opened zip archive; must provide ``read`` and ``namelist``.
        root_name: Name of the top-level folder inside the archive.
        config_name: Name of the config, with or without a file extension.
        do_search: Whether to fall back to searching the archive name list.

    Returns:
        The content returned by ``archive.read`` for the first non-empty match.

    Raises:
        IOError: If no non-empty content could be read for the config.
    """

    config_name = config_name.split(".")[0]  # In case ext is present

    # A suffix without its leading dot can never match "<name>.<ext>", so normalise defensively.
    suffixes = tuple(s if s.startswith(".") else f".{s}" for s in bundle_suffixes)

    content_text = None

    # Try directly read with constructed and expected path into the archive.
    # Zip member names always use "/" separators, so the path is built as a plain string
    # rather than with an OS-dependent Path (which yields "\\" on Windows).
    for suffix in suffixes:
        path = f"{root_name}/{config_folder}/{config_name}{suffix}"
        logging.debug(f"Trying to read config '{config_name}' content from {path}.")
        try:
            content_text = archive.read(path)
            break
        except Exception as ex:
            # Deliberately best-effort: any failure here falls through to the next suffix/search.
            logging.debug(f"Error reading from {path}. Will try alternative ways. {ex}")

    # Try search for the name in the name list of the archive
    if not content_text and do_search:
        logging.debug(f"Trying to find the file in the archive for config '{config_name}'.")
        name_list = archive.namelist()
        for suffix in suffixes:
            target = f"{config_name}{suffix}".casefold()
            match = next((n for n in name_list if target in n.casefold()), None)
            if match is None:
                continue
            logging.debug(f"Trying to read content of config '{config_name}' from {match}.")
            content_text = archive.read(match)
            if content_text:
                # Stop at the first suffix that yields content so later suffixes cannot overwrite it.
                break

    if not content_text:
        raise IOError(f"Cannot read config {config_name}{bundle_suffixes} or its content in the archive.")

    return content_text
|
84 | 105 |
|
def _extract_from_archive(
    archive, root_name: str, config_names: List[str], dest_folder: Union[str, Path], do_search=True
):
    """Extract the files of the named configs from the archive into the destination folder.

    All configs are first extracted from their expected location,
    ``<root_name>/<config_folder>/<name><suffix>``, using one common suffix at a time.
    If that fails for every suffix and ``do_search`` is true, each config is looked up
    individually in the archive name list (case-insensitive substring match on
    ``<name><suffix>``) and the first match is extracted.

    ``bundle_suffixes`` and ``config_folder`` are taken from the enclosing scope.

    Args:
        archive: The opened zip archive; must provide ``extractall``, ``extract`` and ``namelist``.
        root_name: Name of the top-level folder inside the archive.
        config_names: Names of the configs, with or without a file extension.
        dest_folder: Folder to extract the files into.
        do_search: Whether to fall back to searching the archive name list.

    Returns:
        The archive member names that were extracted; each file is expected at
        ``dest_folder/<member name>``. Empty if nothing was extracted and no search was done.

    Raises:
        IOError: If the search fallback ran and some configs could not be found.
    """

    config_names = [cn.split(".")[0] for cn in config_names]  # In case the extension is present

    # A suffix without its leading dot can never match "<name>.<ext>", so normalise defensively.
    suffixes = tuple(s if s.startswith(".") else f".{s}" for s in bundle_suffixes)

    file_list = []

    # Try directly read first with path into the archive.
    # Zip member names always use "/" separators, so build them as plain strings.
    for suffix in suffixes:
        candidates = [f"{root_name}/{config_folder}/{cn}{suffix}" for cn in config_names]
        logging.debug(f"Trying to extract {config_names} with ext {suffix}.")
        try:
            archive.extractall(members=candidates, path=dest_folder)
        except Exception as ex:
            # Deliberately best-effort: log the attempted members, then try the next suffix.
            logging.debug(f"Will try file search after error on extracting {config_names} with {candidates}: {ex}")
            continue
        file_list = candidates
        break

    # If files not extracted, try search for expected files in the name list of the archive
    if not file_list and do_search:
        logging.debug(f"Trying to find the config files in the archive for {config_names}.")
        name_list = archive.namelist()
        leftovers = []  # to track any that are not found.
        for cn in config_names:
            match = None
            for suffix in suffixes:
                target = f"{cn}{suffix}".casefold()
                match = next((n for n in name_list if target in n.casefold()), None)
                if match is not None:
                    break
            if match is None:
                leftovers.append(cn)
                continue
            archive.extract(member=match, path=dest_folder)
            # Record the member so the caller can locate the extracted file under dest_folder.
            file_list.append(match)

        if leftovers:
            raise IOError(f"Failed to extract content for these config(s): {leftovers}.")

    return file_list
| 151 | + |
| 152 | + # End of helper functions |
| 153 | + |
85 | 154 | if isinstance(config_names, str):
|
86 | 155 | config_names = [config_names]
|
87 | 156 |
|
88 |
| - name, _ = os.path.splitext(os.path.basename(bundle_path)) |
| 157 | + name, _ = os.path.splitext(os.path.basename(bundle_path)) # bundle file name same archive folder name |
89 | 158 | parser = ConfigParser()
|
90 | 159 |
|
91 | 160 | # Parser to read the required metadata and extra config contents from the archive
|
92 |
| - with zipfile.ZipFile(bundle_path, "r") as archive: |
93 |
| - name_list = archive.namelist() |
94 |
| - metadata_relative_path = "extra/metadata.json" |
95 |
| - metadata_text = _read_from_archive(archive, name, metadata_relative_path, name_list) |
96 |
| - parser.read_meta(f=json.loads(metadata_text)) |
| 161 | + with tempfile.TemporaryDirectory() as tmp_dir: |
| 162 | + with zipfile.ZipFile(bundle_path, "r") as archive: |
| 163 | + metadata_config_name = "metadata" |
| 164 | + metadata_text = _read_from_archive(archive, name, metadata_config_name) |
| 165 | + parser.read_meta(f=json.loads(metadata_text)) |
97 | 166 |
|
98 |
| - for cn in config_names: |
99 |
| - config_relative_path = f"extra/{cn}.json" |
100 |
| - config_text = _read_from_archive(archive, name, config_relative_path, name_list) |
101 |
| - parser.read_config(f=json.loads(config_text)) |
| 167 | + # now get the other named configs |
| 168 | + file_list = _extract_from_archive(archive, name, config_names, tmp_dir) |
| 169 | + parser.read_config([Path(tmp_dir, f_path) for f_path in file_list]) |
102 | 170 |
|
103 | 171 | parser.parse()
|
104 | 172 |
|
@@ -261,7 +329,7 @@ def __init__(
|
261 | 329 | Defaults to "".
|
262 | 330 | bundle_path (Optional[str], optional): For completing . Defaults to None.
|
263 | 331 | bundle_config_names (BundleConfigNames, optional): Relevant config item names in the bundle.
|
264 |
| - Defaults to None. |
| 332 | + Defaults to DEFAULT_BundleConfigNames. |
265 | 333 | """
|
266 | 334 |
|
267 | 335 | super().__init__(*args, **kwargs)
|
|
0 commit comments