Skip to content

Commit 5a6672c

Browse files
authored
Update with the changes in sobel gradients (#5503)
Fixes #5500 ### Description Update `HoVerNetLoss` with the improvements and changes in `SobelGradients`. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). Signed-off-by: Behrooz <[email protected]>
1 parent d51edc3 commit 5a6672c

File tree

5 files changed

+36
-30
lines changed

5 files changed

+36
-30
lines changed

monai/apps/pathology/losses/hovernet_loss.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,21 @@ def __init__(
6161

6262
self.dice = DiceLoss(softmax=True, smooth_dr=1e-03, smooth_nr=1e-03, reduction="sum", batch=True)
6363
self.ce = CrossEntropyLoss(reduction="mean")
64-
self.sobel = SobelGradients(kernel_size=5)
64+
self.sobel_v = SobelGradients(kernel_size=5, spatial_axes=0)
65+
self.sobel_h = SobelGradients(kernel_size=5, spatial_axes=1)
6566

6667
def _compute_sobel(self, image: torch.Tensor) -> torch.Tensor:
68+
"""Compute the Sobel gradients of the horizontal vertical map (HoVerMap).
69+
More specifically, it will compute horizontal gradient of the input horizontal gradient map (channel=0) and
70+
vertical gradient of the input vertical gradient map (channel=1).
6771
68-
batch_size = image.shape[0]
69-
result_h = self.sobel(torch.squeeze(image[:, 0], dim=1))[batch_size:]
70-
result_v = self.sobel(torch.squeeze(image[:, 1], dim=1))[:batch_size]
72+
Args:
73+
image: a tensor with the shape of BxCxHxW representing HoVerMap
7174
72-
return torch.cat([result_h[:, None, ...], result_v[:, None, ...]], dim=1)
75+
"""
76+
result_h = self.sobel_h(image[:, 0])
77+
result_v = self.sobel_v(image[:, 1])
78+
return torch.stack([result_h, result_v], dim=1)
7379

7480
def _mse_gradient_loss(self, prediction: torch.Tensor, target: torch.Tensor, focus: torch.Tensor) -> torch.Tensor:
7581
"""Compute the MSE loss of the gradients of the horizontal and vertical centroid distance maps"""

monai/transforms/post/array.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,7 @@ def __call__(self, data):
832832

833833

834834
class SobelGradients(Transform):
835-
"""Calculate Sobel gradients of a grayscale image with the shape of (CxH[xWxDx...]).
835+
"""Calculate Sobel gradients of a grayscale image with the shape of CxH[xWxDx...] or BxH[xWxDx...].
836836
837837
Args:
838838
kernel_size: the size of the Sobel kernel. Defaults to 3.
@@ -922,7 +922,7 @@ def __call__(self, image: NdarrayOrTensor) -> torch.Tensor:
922922
grad_list = []
923923
for ax in spatial_axes:
924924
kernels = [kernel_smooth] * n_spatial_dims
925-
kernels[ax - 1] = kernel_diff
925+
kernels[ax] = kernel_diff
926926
grad = separable_filtering(image_tensor, kernels, mode=self.padding)
927927
if self.normalize_gradients:
928928
grad_min = grad.min()

tests/test_hovernet_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,17 +140,17 @@ def test_shape_generator(num_classes=1, num_objects=3, batch_size=1, height=5, w
140140

141141
TEST_CASE_3 = [ # batch size of 2, 3 classes with minor rotation of nuclear prediction
142142
{"prediction": inputs_test[3].inputs, "target": inputs_test[3].targets},
143-
3.6169,
143+
3.6348,
144144
]
145145

146146
TEST_CASE_4 = [ # batch size of 2, 3 classes with medium rotation of nuclear prediction
147147
{"prediction": inputs_test[4].inputs, "target": inputs_test[4].targets},
148-
4.5079,
148+
4.5312,
149149
]
150150

151151
TEST_CASE_5 = [ # batch size of 2, 3 classes with medium rotation of nuclear prediction
152152
{"prediction": inputs_test[5].inputs, "target": inputs_test[5].targets},
153-
5.4663,
153+
5.4929,
154154
]
155155

156156
CASES = [TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]

tests/test_sobel_gradient.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,17 @@
2222

2323
# Output with reflect padding
2424
OUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)
25-
OUTPUT_3x3[1, 7, :] = 0.5
26-
OUTPUT_3x3[1, 9, :] = -0.5
25+
OUTPUT_3x3[0, 7, :] = 0.5
26+
OUTPUT_3x3[0, 9, :] = -0.5
2727

2828
# Output with zero padding
2929
OUTPUT_3x3_ZERO_PAD = OUTPUT_3x3.clone()
30-
OUTPUT_3x3_ZERO_PAD[0, 7, 0] = OUTPUT_3x3_ZERO_PAD[0, 9, 0] = 0.125
31-
OUTPUT_3x3_ZERO_PAD[0, 8, 0] = 0.25
32-
OUTPUT_3x3_ZERO_PAD[0, 7, -1] = OUTPUT_3x3_ZERO_PAD[0, 9, -1] = -0.125
33-
OUTPUT_3x3_ZERO_PAD[0, 8, -1] = -0.25
34-
OUTPUT_3x3_ZERO_PAD[1, 7, 0] = OUTPUT_3x3_ZERO_PAD[1, 7, -1] = 3.0 / 8.0
35-
OUTPUT_3x3_ZERO_PAD[1, 9, 0] = OUTPUT_3x3_ZERO_PAD[1, 9, -1] = -3.0 / 8.0
30+
OUTPUT_3x3_ZERO_PAD[1, 7, 0] = OUTPUT_3x3_ZERO_PAD[1, 9, 0] = 0.125
31+
OUTPUT_3x3_ZERO_PAD[1, 8, 0] = 0.25
32+
OUTPUT_3x3_ZERO_PAD[1, 7, -1] = OUTPUT_3x3_ZERO_PAD[1, 9, -1] = -0.125
33+
OUTPUT_3x3_ZERO_PAD[1, 8, -1] = -0.25
34+
OUTPUT_3x3_ZERO_PAD[0, 7, 0] = OUTPUT_3x3_ZERO_PAD[0, 7, -1] = 3.0 / 8.0
35+
OUTPUT_3x3_ZERO_PAD[0, 9, 0] = OUTPUT_3x3_ZERO_PAD[0, 9, -1] = -3.0 / 8.0
3636

3737
TEST_CASE_0 = [IMAGE, {"kernel_size": 3, "dtype": torch.float32}, OUTPUT_3x3]
3838
TEST_CASE_1 = [IMAGE, {"kernel_size": 3, "dtype": torch.float64}, OUTPUT_3x3]
@@ -68,7 +68,7 @@
6868
"spatial_axes": (0, 1),
6969
"dtype": torch.float64,
7070
},
71-
torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5]),
71+
torch.cat([OUTPUT_3x3[0:1] + 0.5, OUTPUT_3x3[1:2]]),
7272
]
7373
TEST_CASE_10 = [ # Normalized gradients but non-normalized kernels
7474
IMAGE,
@@ -79,7 +79,7 @@
7979
"spatial_axes": (0, 1),
8080
"dtype": torch.float64,
8181
},
82-
torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5]),
82+
torch.cat([OUTPUT_3x3[0:1] + 0.5, OUTPUT_3x3[1:2]]),
8383
]
8484

8585
TEST_CASE_KERNEL_0 = [

tests/test_sobel_gradientd.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,17 +22,17 @@
2222

2323
# Output with reflect padding
2424
OUTPUT_3x3 = torch.zeros(2, 16, 16, dtype=torch.float32)
25-
OUTPUT_3x3[1, 7, :] = 0.5
26-
OUTPUT_3x3[1, 9, :] = -0.5
25+
OUTPUT_3x3[0, 7, :] = 0.5
26+
OUTPUT_3x3[0, 9, :] = -0.5
2727

2828
# Output with zero padding
2929
OUTPUT_3x3_ZERO_PAD = OUTPUT_3x3.clone()
30-
OUTPUT_3x3_ZERO_PAD[0, 7, 0] = OUTPUT_3x3_ZERO_PAD[0, 9, 0] = 0.125
31-
OUTPUT_3x3_ZERO_PAD[0, 8, 0] = 0.25
32-
OUTPUT_3x3_ZERO_PAD[0, 7, -1] = OUTPUT_3x3_ZERO_PAD[0, 9, -1] = -0.125
33-
OUTPUT_3x3_ZERO_PAD[0, 8, -1] = -0.25
34-
OUTPUT_3x3_ZERO_PAD[1, 7, 0] = OUTPUT_3x3_ZERO_PAD[1, 7, -1] = 3.0 / 8.0
35-
OUTPUT_3x3_ZERO_PAD[1, 9, 0] = OUTPUT_3x3_ZERO_PAD[1, 9, -1] = -3.0 / 8.0
30+
OUTPUT_3x3_ZERO_PAD[1, 7, 0] = OUTPUT_3x3_ZERO_PAD[1, 9, 0] = 0.125
31+
OUTPUT_3x3_ZERO_PAD[1, 8, 0] = 0.25
32+
OUTPUT_3x3_ZERO_PAD[1, 7, -1] = OUTPUT_3x3_ZERO_PAD[1, 9, -1] = -0.125
33+
OUTPUT_3x3_ZERO_PAD[1, 8, -1] = -0.25
34+
OUTPUT_3x3_ZERO_PAD[0, 7, 0] = OUTPUT_3x3_ZERO_PAD[0, 7, -1] = 3.0 / 8.0
35+
OUTPUT_3x3_ZERO_PAD[0, 9, 0] = OUTPUT_3x3_ZERO_PAD[0, 9, -1] = -3.0 / 8.0
3636

3737
TEST_CASE_0 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 3, "dtype": torch.float32}, {"image": OUTPUT_3x3}]
3838
TEST_CASE_1 = [{"image": IMAGE}, {"keys": "image", "kernel_size": 3, "dtype": torch.float64}, {"image": OUTPUT_3x3}]
@@ -86,7 +86,7 @@
8686
"normalize_gradients": True,
8787
"dtype": torch.float32,
8888
},
89-
{"image": torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5])},
89+
{"image": torch.cat([OUTPUT_3x3[0:1] + 0.5, OUTPUT_3x3[1:2]])},
9090
]
9191
TEST_CASE_11 = [ # Normalized gradients but non-normalized kernels
9292
{"image": IMAGE},
@@ -98,7 +98,7 @@
9898
"normalize_gradients": True,
9999
"dtype": torch.float32,
100100
},
101-
{"image": torch.cat([OUTPUT_3x3[0:1], OUTPUT_3x3[1:2] + 0.5])},
101+
{"image": torch.cat([OUTPUT_3x3[0:1] + 0.5, OUTPUT_3x3[1:2]])},
102102
]
103103

104104
TEST_CASE_KERNEL_0 = [

0 commit comments

Comments
 (0)