Skip to content

Enhanced code to extract and parse specific config files from TorchScript archive #378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion examples/apps/ai_spleen_seg_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,11 @@
from monai.deploy.operators.dicom_seg_writer_operator import DICOMSegmentationWriterOperator, SegmentDescription
from monai.deploy.operators.dicom_series_selector_operator import DICOMSeriesSelectorOperator
from monai.deploy.operators.dicom_series_to_volume_operator import DICOMSeriesToVolumeOperator
from monai.deploy.operators.monai_bundle_inference_operator import IOMapping, MonaiBundleInferenceOperator
from monai.deploy.operators.monai_bundle_inference_operator import (
BundleConfigNames,
IOMapping,
MonaiBundleInferenceOperator,
)

# from monai.deploy.operators.stl_conversion_operator import STLConversionOperator # import as needed.

Expand Down Expand Up @@ -62,9 +66,13 @@ def compose(self):
#
# Pertinent MONAI Bundle:
# https://github.com/Project-MONAI/model-zoo/tree/dev/models/spleen_ct_segmentation

config_names = BundleConfigNames(config_names=["inference"]) # Same as the default

bundle_spleen_seg_op = MonaiBundleInferenceOperator(
input_mapping=[IOMapping("image", Image, IOType.IN_MEMORY)],
output_mapping=[IOMapping("pred", Image, IOType.IN_MEMORY)],
bundle_config_names=config_names,
)

# Create DICOM Seg writer providing the required segment description for each segment with
Expand Down
118 changes: 93 additions & 25 deletions monai/deploy/operators/monai_bundle_inference_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
import os
import pickle
import tempfile
import time
import zipfile
from copy import deepcopy
Expand Down Expand Up @@ -62,43 +63,110 @@ def get_bundle_config(bundle_path, config_names):
Gets the configuration parser from the specified Torchscript bundle file path.
"""

def _read_from_archive(archive, root_name: str, relative_path: str, path_list: List[str]):
"""A helper function for reading a file in an zip archive.
bundle_suffixes = (".json", ".yaml", "yml") # The only supported file ext(s)
config_folder = "extra"

Tries to read with the full path of # a archive file, if error, then find the relative
path and then read the file.
def _read_from_archive(archive, root_name: str, config_name: str, do_search=True):
"""A helper function for reading the content of a config in the zip archive.

Tries to read config content at the expected path in the archive, if error occurs,
search and read with alternative paths.
"""

content_text = None
try:
content_text = archive.read(f"{root_name}/{relative_path}")
except KeyError:
logging.debug(f"Trying to find the metadata/config file in the bundle archive: {relative_path}.")
for n in path_list:
if relative_path in n:
content_text = archive.read(n)
break
if content_text is None:
raise
config_name = config_name.split(".")[0] # In case ext is present

# Try directly read with constructed and expected path into the archive
for suffix in bundle_suffixes:
try:
path = Path(root_name, config_folder, config_name).with_suffix(suffix)
logging.debug(f"Trying to read config '{config_name}' content from {path}.")
content_text = archive.read(str(path))
break
except Exception:
logging.debug(f"Error reading from {path}. Will try alternative ways.")
continue

# Try search for the name in the name list of the archive
if not content_text and do_search:
logging.debug(f"Trying to find the file in the archive for config '{config_name}'.")
name_list = archive.namelist()
for suffix in bundle_suffixes:
for n in name_list:
if (f"{config_name}{suffix}").casefold in n.casefold():
logging.debug(f"Trying to read content of config '{config_name}' from {n}.")
content_text = archive.read(n)
break

if not content_text:
raise IOError(f"Cannot read config {config_name}{bundle_suffixes} or its content in the archive.")

return content_text

def _extract_from_archive(
    archive, root_name: str, config_names: List[str], dest_folder: Union[str, Path], do_search=True
):
    """Extract the files of the named configs from the archive to the destination folder.

    The files are first looked up at their expected location,
    ``<root_name>/<config_folder>/<config_name><suffix>``, where all configs must share the
    same supported suffix. If that fails and ``do_search`` is true, the archive's name list is
    searched, per config, for a member whose file name equals ``<config_name><suffix>``,
    ignoring case.

    Args:
        archive: An opened ``zipfile.ZipFile``.
        root_name (str): Name of the top level folder in the archive.
        config_names (List[str]): Names of the configs, with or without file extensions.
        dest_folder (Union[str, Path]): Folder to extract the files into.
        do_search (bool): Whether to search the whole archive if the expected paths are absent.

    Returns:
        List[str]: Archive member names of the extracted files, i.e. their paths relative
        to ``dest_folder``. Empty only if ``config_names`` is empty.

    Raises:
        IOError: If the file of any of the configs cannot be found in the archive.
    """

    config_names = [cn.split(".")[0] for cn in config_names]  # In case the extension is present
    file_list: List[str] = []
    if not config_names:
        return file_list

    name_list = archive.namelist()
    known_names = set(name_list)

    # Try the expected paths first. Zip member names always use "/" as the separator, so the
    # paths are built as plain strings instead of OS dependent Path objects. Membership is
    # checked up front so that nothing gets partially extracted for a non-matching suffix.
    for suffix in bundle_suffixes:
        logging.debug(f"Trying to extract {config_names} with ext {suffix}.")
        candidates = [f"{root_name}/{config_folder}/{cn}{suffix}" for cn in config_names]
        if all(c in known_names for c in candidates):
            archive.extractall(members=candidates, path=dest_folder)
            file_list = candidates
            break
        logging.debug(f"Will try file search after failing to find all of {candidates} in the archive.")

    # If files not extracted, try search for expected files in the name list of the archive
    if not file_list and do_search:
        logging.debug(f"Trying to find the config files in the archive for {config_names}.")
        members = []
        leftovers = []  # to track any that are not found.
        for cn in config_names:
            member = None
            for suffix in bundle_suffixes:
                target = f"{cn}{suffix}".casefold()
                # Compare the final path component only, to avoid substring false positives.
                member = next((n for n in name_list if n.casefold().rsplit("/", 1)[-1] == target), None)
                if member is not None:
                    break
            if member is None:
                leftovers.append(cn)
            else:
                members.append(member)

        if leftovers:
            raise IOError(f"Failed to extract content for these config(s): {leftovers}.")

        for member in members:
            archive.extract(member=member, path=dest_folder)
        # Report what was extracted, so that the caller can actually load these files.
        file_list = members

    if not file_list:
        # Only reachable when searching is disabled and the expected paths are absent.
        raise IOError(f"Failed to extract content for these config(s): {config_names}.")

    return file_list

# End of helper functions

if isinstance(config_names, str):
config_names = [config_names]

name, _ = os.path.splitext(os.path.basename(bundle_path))
name, _ = os.path.splitext(os.path.basename(bundle_path)) # bundle file name same archive folder name
parser = ConfigParser()

# Parser to read the required metadata and extra config contents from the archive
with zipfile.ZipFile(bundle_path, "r") as archive:
name_list = archive.namelist()
metadata_relative_path = "extra/metadata.json"
metadata_text = _read_from_archive(archive, name, metadata_relative_path, name_list)
parser.read_meta(f=json.loads(metadata_text))
with tempfile.TemporaryDirectory() as tmp_dir:
with zipfile.ZipFile(bundle_path, "r") as archive:
metadata_config_name = "metadata"
metadata_text = _read_from_archive(archive, name, metadata_config_name)
parser.read_meta(f=json.loads(metadata_text))

for cn in config_names:
config_relative_path = f"extra/{cn}.json"
config_text = _read_from_archive(archive, name, config_relative_path, name_list)
parser.read_config(f=json.loads(config_text))
# now get the other named configs
file_list = _extract_from_archive(archive, name, config_names, tmp_dir)
parser.read_config([Path(tmp_dir, f_path) for f_path in file_list])

parser.parse()

Expand Down Expand Up @@ -261,7 +329,7 @@ def __init__(
Defaults to "".
bundle_path (Optional[str], optional): For completing . Defaults to None.
bundle_config_names (BundleConfigNames, optional): Relevant config item names in the bundle.
Defaults to None.
Defaults to DEFAULT_BundleConfigNames.
"""

super().__init__(*args, **kwargs)
Expand Down