Skip to content

MAISI Quality check #1789

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
66ae578
add quality check
Can-Zhao Aug 17, 2024
28f5c96
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2024
7c0bd11
add quality check
Can-Zhao Aug 17, 2024
91ad068
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 17, 2024
494ed36
refactor
Can-Zhao Aug 17, 2024
ad5afba
add docstring
Can-Zhao Aug 18, 2024
2d87cd5
Merge branch 'main' into maisi_quality
Can-Zhao Aug 20, 2024
74d4e87
rm unused import, correct typo, add statistics file
Can-Zhao Aug 20, 2024
38daf8a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2024
38aacc3
np.nanmedian
Can-Zhao Aug 20, 2024
b10550e
Merge branch 'maisi_quality' of https://github.com/Can-Zhao/tutorials…
Can-Zhao Aug 20, 2024
d1e096c
add logging
Can-Zhao Aug 20, 2024
0db8d90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2024
4e0c77e
add description on num_label_acceleration_thresh
Can-Zhao Aug 20, 2024
dd57eb3
Merge branch 'maisi_quality' of https://github.com/Can-Zhao/tutorials…
Can-Zhao Aug 20, 2024
36b3f0e
add description on quality check
Can-Zhao Aug 20, 2024
83d8811
add description on input FOV
Can-Zhao Aug 20, 2024
4e80b50
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2024
2fdedb0
add description on input FOV
Can-Zhao Aug 20, 2024
e1d0821
add description on input FOV
Can-Zhao Aug 20, 2024
88d84cd
Merge branch 'maisi_quality' of https://github.com/Can-Zhao/tutorials…
Can-Zhao Aug 20, 2024
c38039d
typo
Can-Zhao Aug 20, 2024
828c4e3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 20, 2024
b4b2159
typo
Can-Zhao Aug 20, 2024
6caf7ad
add description on input FOV
Can-Zhao Aug 20, 2024
7622ad6
add description on input FOV
Can-Zhao Aug 20, 2024
8266e9d
update checking on input FOV
Can-Zhao Aug 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions generation/maisi/scripts/quality_check.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nibabel as nib
import numpy as np


def get_masked_data(label_data, image_data, labels):
"""
Extracts and returns the image data corresponding to specified labels within a 3D volume.

This function efficiently masks the `image_data` array based on the provided `labels` in the `label_data` array.
The function handles cases with both a large and small number of labels, optimizing performance accordingly.

Args:
label_data (np.ndarray): A NumPy array containing label data, representing different anatomical
regions or classes in a 3D medical image.
image_data (np.ndarray): A NumPy array containing the image data from which the relevant regions
will be extracted.
labels (list of int): A list of integers representing the label values to be used for masking.

Returns:
np.ndarray: A NumPy array containing the elements of `image_data` that correspond to the specified
labels in `label_data`. If no labels are provided, an empty array is returned.

Raises:
ValueError: If `image_data` and `label_data` do not have the same shape.

Example:
label_int_dict = {"liver": [1], "kidney": [5, 14]}
masked_data = get_masked_data(label_data, image_data, label_int_dict["kidney"])
"""

# Check if the shapes of image_data and label_data match
if image_data.shape != label_data.shape:
raise ValueError(
f"Shape mismatch: image_data has shape {image_data.shape}, "
f"but label_data has shape {label_data.shape}. They must be the same."
)

if not labels:
return np.array([]) # Return an empty array if no labels are provided

# Optimize performance based on the number of labels
if len(labels) >= 3:
label_set = set(labels) # Convert labels to a set for faster membership testing
mask = np.isin(label_data, list(label_set))
else:
# Use logical OR to combine masks if the number of labels is small
mask = np.zeros_like(label_data, dtype=bool)
for label in labels:
mask = np.logical_or(mask, label_data == label)

# Retrieve the masked data
masked_data = image_data[mask.astype(bool)]

return masked_data


def is_outlier(statistics, image_data, label_data, label_int_dict):
"""
Perform a quality check on the generated image by comparing its statistics with precomputed thresholds.

Args:
statistics (dict): Dictionary containing precomputed statistics including mean +/- 3sigma ranges.
image_data (np.ndarray): The image data to be checked, typically a 3D NumPy array.
label_data (np.ndarray): The label data corresponding to the image, used for masking regions of interest.
label_int_dict (dict): Dictionary mapping label names to their corresponding integer lists.
e.g., label_int_dict = {"liver": [1], "kidney": [5, 14]}

Returns:
dict: A dictionary with labels as keys, each containing the quality check result,
including whether it's an outlier, the median value, and the thresholds used.
If no data is found for a label, the median value will be `None` and `is_outlier` will be `False`.

Example:
# Example input data
statistics = {
"liver": {
"sigma_6_low": -21.596463547885904,
"sigma_6_high": 156.27881534763367
},
"kidney": {
"sigma_6_low": -15.0,
"sigma_6_high": 120.0
}
}
label_int_dict = {
"liver": [1],
"kidney": [5, 14]
}
image_data = np.random.rand(100, 100, 100) # Replace with actual image data
label_data = np.zeros((100, 100, 100)) # Replace with actual label data
label_data[40:60, 40:60, 40:60] = 1 # Example region for liver
label_data[70:90, 70:90, 70:90] = 5 # Example region for kidney
result = is_outlier(statistics, image_data, label_data, label_int_dict)
"""
outlier_results = {}

for label_name, stats in statistics.items():
# Get the thresholds from the statistics
low_thresh = stats["sigma_6_low"] # or "sigma_12_low" depending on your needs
high_thresh = stats["sigma_6_high"] # or "sigma_12_high" depending on your needs

# Retrieve the corresponding label integers
labels = label_int_dict.get(label_name, [])
masked_data = get_masked_data(label_data, image_data, labels)
masked_data = masked_data[~np.isnan(masked_data)]

if len(masked_data) == 0 or masked_data.size == 0:
outlier_results[label_name] = {
"is_outlier": False,
"median_value": None,
"low_thresh": low_thresh,
"high_thresh": high_thresh,
}
continue

# Compute the median of the masked region
median_value = np.median(masked_data)

if np.isnan(median_value):
median_value = None
is_outlier = False
else:
# Determine if the median value is an outlier
is_outlier = median_value < low_thresh or median_value > high_thresh

outlier_results[label_name] = {
"is_outlier": is_outlier,
"median_value": median_value,
"low_thresh": low_thresh,
"high_thresh": high_thresh,
}

return outlier_results
44 changes: 33 additions & 11 deletions generation/maisi/scripts/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .augmentation import augmentation
from .find_masks import find_masks
from .utils import binarize_labels, general_mask_generation_post_process, get_body_region_index_from_mask, remap_labels
from .quality_check import is_outlier


class ReconModel(torch.nn.Module):
Expand Down Expand Up @@ -497,7 +498,7 @@ def __init__(
controllable_anatomy_size,
image_output_ext=".nii.gz",
label_output_ext=".nii.gz",
quality_check_args=None,
real_img_median_statistics="./configs/image_median_statistics.json",
spacing=[1, 1, 1],
num_inference_steps=None,
mask_generation_num_inference_steps=None,
Expand Down Expand Up @@ -563,9 +564,26 @@ def __init__(
self.autoencoder_sliding_window_infer_size = autoencoder_sliding_window_infer_size
self.autoencoder_sliding_window_infer_overlap = autoencoder_sliding_window_infer_overlap

# quality check disabled for this version
self.quality_check_args = quality_check_args
# quality check args
self.max_try_time = 5 # if not pass quality check, will try self.max_try_time times
with open(real_img_median_statistics, "r") as json_file:
self.median_statistics = json.load(json_file)
self.label_int_dict = {
"liver": [1],
"spleen": [3],
"pancreas": [4],
"kidney": [5, 14],
"lung": [28, 29, 30, 31, 31],
"brain": [22],
"hepatic tumor": [26],
"bone lesion": [128],
"lung tumor": [23],
"colon cancer primaries": [27],
"pancreatic tumor": [24],
"bone": list(range(33, 57)) + list(range(63, 98)) + [120, 122, 127],
}

# networks
self.autoencoder.eval()
self.diffusion_unet.eval()
self.controlnet.eval()
Expand Down Expand Up @@ -669,8 +687,10 @@ def sample_multiple_images(self, num_img):
spacing_tensor,
)
# current quality always return True
pass_quality_check = self.quality_check(synthetic_images)
if pass_quality_check or try_time > 3:
pass_quality_check = self.quality_check(
synthetic_images.cpu().detach().numpy(), comebine_label_or.cpu().detach().numpy()
)
if pass_quality_check or try_time > self.max_try_time:
# save image/label pairs
output_postfix = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
synthetic_labels.meta["filename_or_obj"] = "sample.nii.gz"
Expand Down Expand Up @@ -1006,15 +1026,17 @@ def find_closest_masks(self, num_img):
raise ValueError("Cannot find body region with given organ list.")
return final_candidates

def quality_check(self, image):
def quality_check(self, image_data, label_data):
"""
Perform a quality check on the generated image. This version disabled quality check and always return True.

Perform a quality check on the generated image.
Args:
image (torch.Tensor): The generated image.

image_data (np.ndarray): The generated image.
label_data (np.ndarray): The corresponding whole body mask.
Returns:
bool: True if the image passes the quality check, False otherwise.
"""
# This version disabled quality check
outlier_results = is_outlier(self.median_statistics, image_data, label_data, self.label_int_dict)
for label, result in outlier_results.items():
if result.get("is_outlier", False):
return False
return True
Loading