I would like to perform binary segmentation. How can I use the DiceFocalLoss function? #1568
Unanswered
BelieferQAQ
asked this question in
Q&A
Replies: 1 comment 1 reply
-
Hi @BelieferQAQ, for the imbalance issue, you might consider adding a weight in the loss function. (Converted to a discussion since it's not an issue; feel free to create another one if you run into any issues.)
Beta Was this translation helpful? Give feedback.
1 reply
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
-
Hello, I am currently working on lymph node segmentation. However, the lymph nodes are very small, and they are labeled with voxel value 1. Therefore, I want to use the DiceFocalLoss function. The following code snippet is part of my current implementation, but I am not getting satisfactory training results. I would appreciate your advice.
# Number of random patches cropped from each training volume.
num_samples = 2
# GPU used for the cached tensors and for the network.
device = torch.device("cuda:1")

# Training pipeline: deterministic preprocessing, then random patch sampling
# and light augmentation.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        # Map the intensity window [-175, 250] onto [0, 1], clipping outliers.
        ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # One interpolation mode per key: bilinear for the image, nearest for the label.
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=False),
        # Sample `num_samples` 96^3 patches per volume with a 1:1 pos/neg ratio.
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=num_samples,
            image_key="image",
            image_threshold=0,
        ),
        # One independent random flip per spatial axis, in axis order 0, 1, 2.
        *[RandFlipd(keys=["image", "label"], spatial_axis=[axis], prob=0.10) for axis in (0, 1, 2)],
        RandRotate90d(keys=["image", "label"], prob=0.10, max_k=3),
        RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.50),
    ]
)
# Validation pipeline: the same deterministic preprocessing as training, but
# with no random cropping or augmentation, and with metadata tracking kept on.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"], ensure_channel_first=True),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-175,
            a_max=250,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        EnsureTyped(keys=["image", "label"], device=device, track_meta=True),
    ]
)
# Let cuDNN auto-tune convolution algorithms.
torch.backends.cudnn.benchmark = True

# Single-channel input, two output channels (one per label value: 0 and 1).
model = SwinUNETR(
    img_size=(96, 96, 96),
    in_channels=1,
    out_channels=2,
    feature_size=48,
    use_checkpoint=True,
)
model = model.to(device)
# Combined Dice + focal loss for the 2-channel output.
#
# The network returns raw logits, so `softmax=True` is needed: with both
# `sigmoid` and `softmax` disabled the Dice term would be computed directly on
# unbounded logits instead of class probabilities, which makes the loss
# ill-behaved and is a likely cause of poor training results.
#
# The earlier `DiceCELoss(...)` assignment was removed: it was overwritten
# immediately and never used.
#
# `focal_weight` is left at its default (None). For a heavily imbalanced
# foreground, passing per-class weights to the focal term is one option to try.
loss_function = DiceFocalLoss(
    include_background=True,
    to_onehot_y=True,  # labels are single-channel class indices (0 / 1)
    sigmoid=False,
    softmax=True,  # convert the two logits into class probabilities
    other_act=None,
    squared_pred=False,
    jaccard=False,
    reduction="mean",
    smooth_nr=1e-5,
    smooth_dr=1e-5,
    batch=False,
    gamma=2.0,
    lambda_dice=1.0,
    lambda_focal=1.0,
)
# AdamW over all network parameters.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)
# Gradient scaler for the mixed-precision (autocast) training loop.
scaler = torch.cuda.amp.GradScaler()
def validation(epoch_iterator_val):
    """Evaluate the global ``model`` on the validation iterator.

    Runs sliding-window inference on every batch, feeds the post-processed
    predictions and labels to the global ``dice_metric``, and returns the
    aggregated mean Dice as a float. The metric is reset before returning.

    Note: the model is left in eval mode when this function returns.
    """
    model.eval()
    with torch.no_grad():
        for val_batch in epoch_iterator_val:
            images = val_batch["image"].to(device)
            targets = val_batch["label"].to(device)
            with torch.cuda.amp.autocast():
                logits = sliding_window_inference(images, (96, 96, 96), 4, model)
            # Split the batch into per-sample tensors before post-processing.
            targets_onehot = [post_label(sample) for sample in decollate_batch(targets)]
            preds_onehot = [post_pred(sample) for sample in decollate_batch(logits)]
            dice_metric(y_pred=preds_onehot, y=targets_onehot)
            epoch_iterator_val.set_description("Validate (%d / %d Steps)" % (global_step, 10.0))
        mean_dice = dice_metric.aggregate().item()
        dice_metric.reset()
    return mean_dice
def train(global_step, train_loader, dice_val_best, global_step_best):
    """Train for one pass over ``train_loader``, validating every ``eval_num`` steps.

    Args:
        global_step: Number of optimisation steps completed so far.
        train_loader: Iterable of batches with ``"image"`` and ``"label"`` entries.
        dice_val_best: Best validation Dice observed so far.
        global_step_best: Global step at which ``dice_val_best`` was reached.

    Returns:
        The updated ``(global_step, dice_val_best, global_step_best)`` tuple.

    Side effects:
        Appends to the global ``epoch_loss_values`` / ``metric_values`` lists and
        saves the model weights under ``root_dir`` whenever validation improves.
    """
    model.train()
    epoch_loss = 0
    epoch_iterator = tqdm(train_loader, desc="Training (X / X Steps) (loss=X.X)", dynamic_ncols=True)
    for step, batch in enumerate(epoch_iterator, start=1):
        x, y = (batch["image"].to(device), batch["label"].to(device))
        with torch.cuda.amp.autocast():
            logit_map = model(x)
            loss = loss_function(logit_map, y)
        scaler.scale(loss).backward()
        epoch_loss += loss.item()
        scaler.unscale_(optimizer)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()
        epoch_iterator.set_description(f"Training ({global_step} / {max_iterations} Steps) (loss={loss:2.5f})")
        if (global_step % eval_num == 0 and global_step != 0) or global_step == max_iterations:
            epoch_iterator_val = tqdm(val_loader, desc="Validate (X / X Steps) (dice=X.X)", dynamic_ncols=True)
            dice_val = validation(epoch_iterator_val)
            # validation() puts the model in eval mode; switch back so the
            # remaining batches of this epoch are trained in training mode.
            model.train()
            # Record the running mean without overwriting the accumulator;
            # dividing `epoch_loss` in place would corrupt later averages.
            epoch_loss_values.append(epoch_loss / step)
            metric_values.append(dice_val)
            if dice_val > dice_val_best:
                dice_val_best = dice_val
                global_step_best = global_step
                torch.save(model.state_dict(), os.path.join(root_dir, "best_metric_S11107_focal_dice_model.pth"))
                print(f"Model Was Saved ! Current Best Avg. Dice: {dice_val_best} Current Avg. Dice: {dice_val}")
            else:
                print(f"Model Was Not Saved ! Current Best Avg. Dice: {dice_val_best} Current Avg. Dice: {dice_val}")
        global_step += 1
    return global_step, dice_val_best, global_step_best
# Directory that receives the best-model checkpoint.
root_dir = "/data/jupyter/wd/swinUnetr-lymphNode/save_models_swin_1107"
# Create the checkpoint directory up front, so a missing path fails here and
# not at the first checkpoint save after `eval_num` training steps.
os.makedirs(root_dir, exist_ok=True)

max_iterations = 60000  # total optimisation steps to run
eval_num = 200  # validate every `eval_num` steps

# Post-processing applied to labels / predictions before the Dice metric.
post_label = AsDiscrete(to_onehot=2)
post_pred = AsDiscrete(argmax=True, to_onehot=2)
# NOTE(review): include_background=True averages the background channel into
# the reported Dice; for very small foreground structures this may mask poor
# foreground performance - confirm this is the intended selection metric.
dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

global_step = 0
dice_val_best = 0.0
global_step_best = 0
epoch_loss_values = []
metric_values = []

# Each train() call makes one full pass over `train_loader`.
while global_step < max_iterations:
    global_step, dice_val_best, global_step_best = train(global_step, train_loader, dice_val_best, global_step_best)
Beta Was this translation helpful? Give feedback.
All reactions