Implementing End-to-End 3D Spleen Segmentation in MONAI with a UNet on Medical CT Volumes

In this tutorial, we build an end-to-end 3D medical image segmentation pipeline with MONAI to segment the spleen in the Medical Segmentation Decathlon Task09 dataset. We work with volumetric CT scans, apply medical image transforms such as orientation alignment, voxel-spacing normalization, intensity windowing, foreground cropping, and patch-based sampling, and train a 3D UNet for binary organ segmentation. We also use mixed precision training, DiceCE loss, sliding window inference, Dice-based validation, and qualitative visualization to understand how the model learns and how its predictions compare with ground-truth masks. Along the way, we move from raw CT volumes to a complete split, train, validate, and visualize workflow.
```python
# Notebook cell: the leading "!" runs pip as a shell command in Jupyter/Colab.
!pip install -q "monai[nibabel,tqdm,matplotlib]==1.5.2" 2>/dev/null
import os, time, glob, tempfile, warnings
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler
from monai.apps import DecathlonDataset
from monai.data import DataLoader, decollate_batch
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
    Spacingd, ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld,
    RandFlipd, RandRotate90d, RandShiftIntensityd, AsDiscrete,
)
warnings.filterwarnings("ignore")
```
We start by installing MONAI together with the optional dependencies it needs for medical image I/O, progress bars, and visualization. We then import PyTorch, NumPy, Matplotlib, and the core MONAI modules for datasets, transforms, model training, metrics, and inference. We also suppress warnings to keep the notebook output focused on the segmentation workflow.
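```python
QUICK_RUN = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root_dir = tempfile.mkdtemp()
roi_size = (96, 96, 96)
num_samples = 4
batch_size = 2
max_epochs = 15 if QUICK_RUN else 200
val_every = 3
train_cache = 8 if QUICK_RUN else 24
val_cache = 2 if QUICK_RUN else 6
set_determinism(seed=0)
print(f"Device: {device} | epochs: {max_epochs} | data dir: {root_dir}")

# Shared deterministic preprocessing. This list was lost from the extracted post; the
# parameter values below are assumptions taken from MONAI's standard spleen example.
common = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
]
train_transforms = Compose(common + [
    # Draw num_samples patches per volume, centered on foreground/background at a 1:1 ratio
    # (the start of this call was truncated in the post; reconstructed from its surviving tail)
    RandCropByPosNegLabeld(keys=["image", "label"], label_key="label", spatial_size=roi_size,
                           pos=1, neg=1, num_samples=num_samples,
                           image_key="image", image_threshold=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2),
    RandRotate90d(keys=["image", "label"], prob=0.2, max_k=3),
    RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.5),
    EnsureTyped(keys=["image", "label"]),
])
val_transforms = Compose(common + [EnsureTyped(keys=["image", "label"])])
```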
We define the main configuration of the tutorial, including the device, dataset directory, patch size, batch size, number of epochs, and cache settings. We then build a preprocessing pipeline for the CT volumes that loads each image, aligns its orientation, resamples it to a common voxel spacing, scales its intensities, and crops it to the foreground. Finally, we define the training and validation transforms: the training pipeline adds random foreground/background patch crops, flips, 90-degree rotations, and intensity shifts, while the validation pipeline keeps only the deterministic preprocessing.
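To make the two numeric preprocessing steps concrete, here is a small NumPy sketch of what intensity windowing and spacing normalization do to a volume. `window_ct` mirrors the clip-and-rescale behavior of `ScaleIntensityRanged(..., clip=True)`, and `resampled_shape` estimates the voxel grid after `Spacingd`. Both are hypothetical helper names, not MONAI functions; the window (-57 to 164 HU) and target spacing (1.5, 1.5, 2.0 mm) are assumptions borrowed from MONAI's standard spleen example, and the input scan geometry is made up for illustration.

```python
import numpy as np

def window_ct(hu, a_min=-57.0, a_max=164.0, b_min=0.0, b_max=1.0):
    """Map HU values in [a_min, a_max] linearly to [b_min, b_max] and clip the rest.

    Hypothetical helper; the default bounds are an assumed soft-tissue window.
    """
    scaled = (hu - a_min) / (a_max - a_min)      # 0 at a_min, 1 at a_max
    scaled = scaled * (b_max - b_min) + b_min    # stretch to the target range
    return np.clip(scaled, b_min, b_max)         # equivalent of clip=True

def resampled_shape(shape, spacing, new_spacing):
    """Approximate voxel grid size after resampling to new_spacing (mm per voxel).

    The physical extent (voxels * spacing) is preserved, so each axis scales by
    old_spacing / new_spacing. MONAI's exact rounding may differ slightly.
    """
    return tuple(int(round(n * s / t)) for n, s, t in zip(shape, spacing, new_spacing))

hu = np.array([-1000.0, -57.0, 53.5, 164.0, 1000.0])  # air, window edges, window midpoint, dense bone
print(window_ct(hu))
# Hypothetical scan: 512 x 512 x 90 voxels at 0.8 x 0.8 x 5.0 mm
print(resampled_shape((512, 512, 90), (0.8, 0.8, 5.0), (1.5, 1.5, 2.0)))
```

Everything outside the window saturates at 0 or 1, which is why the choice of window matters: it decides which tissues keep contrast. The resampling example also shows why thick-slice scans grow along the z-axis (5.0 mm to 2.0 mm gives 2.5 times more slices), which affects memory use and how many 96-voxel patches fit along that axis.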
```python
train_ds = DecathlonDataset(
    root_dir=root_dir, task="Task09_Spleen", section="training",
    transform=train_transforms, download=True, val_frac=0.2,
    cache_num=train_cache, num_workers=2, seed=0)
val_ds = DecathlonDataset(
    root_dir=root_dir, task="Task09_Spleen", section="validation",
    transform=val_transforms, download=False, val_frac=0.2,
    cache_num=val_cache, num_workers=2, seed=0)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=torch.cuda.is_available())
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False,
                        num_workers=1, pin_memory=torch.cuda.is_available())
print(f"Train volumes: {len(train_ds)} | Val volumes: {len(val_ds)}")

# The model definition is missing from the extracted post. These UNet hyperparameters are an
# assumption (MONAI's standard spleen example): 1 CT input channel, 2 classes (background, spleen).
model = UNet(
    spatial_dims=3, in_channels=1, out_channels=2,
    channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
    num_res_units=2, norm=Norm.BATCH,
).to(device)
loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
scaler = GradScaler("cuda", enabled=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=False, reduction="mean")
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])   # logits -> one-hot prediction
post_label = Compose([AsDiscrete(to_onehot=2)])               # label map -> one-hot label
```
We load the Medical Segmentation Decathlon Task09 Spleen dataset using MONAI’s DecathlonDataset, split it into training and validation subsets (val_frac=0.2 holds out 20% of the volumes), apply the appropriate transforms, and wrap both datasets in data loaders. We then create a 3D UNet model, define the DiceCE loss, and set up the AdamW optimizer, a cosine-annealing learning-rate scheduler, the mixed precision gradient scaler, the Dice metric, and the post-processing steps that turn logits and labels into one-hot masks.
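For intuition about what `DiceCELoss(to_onehot_y=True, softmax=True)` and `DiceMetric(include_background=False)` measure, the NumPy sketch below computes a soft Dice loss plus cross-entropy on softmax probabilities, and a hard Dice score on the foreground class after argmax (the job of `post_pred` and `post_label`). It is a simplified illustration, not MONAI's implementation: it assumes equal weighting of the two loss terms, a single smoothing constant, hypothetical function names, and toy data.

```python
import numpy as np

def softmax(logits, axis=0):
    z = logits - logits.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def one_hot(labels, num_classes):
    """(D, H, W) integer label map -> (C, D, H, W) float one-hot array."""
    return np.stack([(labels == c).astype(np.float64) for c in range(num_classes)])

def dice_ce_loss(logits, labels, eps=1e-5):
    """Simplified DiceCE: mean soft Dice loss over channels + mean voxel-wise cross-entropy."""
    probs = softmax(logits, axis=0)                    # softmax over the class axis
    target = one_hot(labels, logits.shape[0])          # analogous to to_onehot_y=True
    spatial = tuple(range(1, probs.ndim))
    inter = (probs * target).sum(axis=spatial)
    denom = probs.sum(axis=spatial) + target.sum(axis=spatial)
    dice_loss = np.mean(1.0 - (2.0 * inter + eps) / (denom + eps))
    ce = -np.mean(np.sum(target * np.log(probs + 1e-12), axis=0))
    return float(dice_loss + ce)

def foreground_dice(logits, labels):
    """Hard Dice on class 1 after argmax, ignoring the background channel."""
    pred = logits.argmax(axis=0)
    p, t = pred == 1, labels == 1
    if t.sum() == 0:
        return float("nan")  # empty ground truth: Dice undefined (a convention chosen for this sketch)
    return float(2.0 * np.logical_and(p, t).sum() / (p.sum() + t.sum()))

# Toy 8x8x8 volume with a cubic "organ" (hypothetical data).
labels = np.zeros((8, 8, 8), dtype=int)
labels[2:6, 2:6, 2:6] = 1
confident = np.stack([1.0 - labels, labels.astype(float)]) * 10.0  # correct, confident logits
noisy = np.random.default_rng(0).normal(size=(2, 8, 8, 8))         # random logits
print(foreground_dice(confident, labels), dice_ce_loss(confident, labels))
print(foreground_dice(noisy, labels), dice_ce_loss(noisy, labels))
```

The Dice term rewards overlap regardless of how small the organ is relative to the background, while the cross-entropy term gives dense per-voxel gradients; combining them is a common way to train on heavily imbalanced masks like the spleen.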
```python
best_dice, best_epoch = -1.0, -1
loss_hist, dice_hist, dice_epochs = [], [], []
best_path = os.path.join(root_dir, "best_spleen_unet.pth")
for epoch in range(1, max_epochs + 1):
    model.train(); epoch_loss, t0 = 0.0, time.time()
    for batch in train_loader:
        x, y = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad(set_to_none=True)
        # Mixed-precision forward pass (autocast is disabled, i.e. a no-op, on CPU)
        with autocast("cuda", enabled=torch.cuda.is_available()):
            logits = model(x)
            loss = loss_fn(logits, y)
        # Scale the loss to avoid fp16 gradient underflow; step() unscales the gradients
        # and skips the update if they contain inf/NaN
        scaler.scale(loss).backward()
        scaler.step(optimizer); scaler.update()
        epoch_loss += loss.item()
    scheduler.step()  # the cosine schedule advances once per epoch
    epoch_loss /= len(train_loader); loss_hist.append(epoch_loss)
    print(f"[{epoch:3d}/{max_epochs}] loss={epoch_loss:.4f} "
          f"lr={scheduler.get_last_lr()[0]:.2e} ({time.time()-t0:.0f}s)")
    if epoch % val_every == 0 or epoch == max_epochs:
        model.eval(); dice_metric.reset()
        with torch.no_grad():
            for vb in val_loader:
                vx, vy = vb["image"].to(device), vb["label"].to(device)
                # Full-volume inference by tiling roi_size windows, 4 windows per forward pass
                with autocast("cuda", enabled=torch.cuda.is_available()):
                    vout = sliding_window_inference(vx, roi_size, 4, model,
                                                    overlap=0.5)
                # Per-volume post-processing: argmax + one-hot for predictions, one-hot for labels
                vout = [post_pred(o) for o in decollate_batch(vout)]
                vlab = [post_label(o) for o in decollate_batch(vy)]
                dice_metric(y_pred=vout, y=vlab)
        d = dice_metric.aggregate().item()
        dice_hist.append(d); dice_epochs.append(epoch)
        if d > best_dice:  # keep the checkpoint with the best validation Dice
            best_dice, best_epoch = d, epoch
            torch.save(model.state_dict(), best_path)
        print(f" >> val Dice={d:.4f} (best={best_dice:.4f} @ {best_epoch})")
print(f"\nDone. Best mean Dice {best_dice:.4f} at epoch {best_epoch}.")
```
We then run the full training loop, where each epoch trains the 3D UNet on cropped patches from the spleen volumes. We use automatic mixed precision to reduce memory usage and speed up training when a GPU is available. Every val_every epochs (and at the final epoch), we validate the model with sliding window inference over full volumes, track the mean Dice score, and save the best-performing checkpoint.
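Validation calls `sliding_window_inference(vx, roi_size, 4, model, overlap=0.5)`, which tiles each full volume with 96x96x96 windows, runs the model on groups of 4 windows (the `sw_batch_size` argument), and stitches the outputs back together. The pure-Python sketch below illustrates the idea along one axis: with 50% overlap the step is 48 voxels, and the last window is shifted back so it ends at the volume edge. It is a simplified illustration, not MONAI's implementation; the helper names and the volume shape are hypothetical, and overlaps are averaged with equal weights, which we take to correspond to the function's default constant blending.

```python
import math

def window_starts(size, roi, overlap):
    """Start offsets of length-`roi` windows covering an axis of length `size`.

    Simplified sketch: step = roi * (1 - overlap); the last window is shifted
    back so that it ends exactly at the edge of the axis.
    """
    if roi >= size:
        return [0]  # a single (possibly padded) window covers the whole axis
    step = max(int(roi * (1 - overlap)), 1)
    starts, s = [], 0
    while True:
        starts.append(min(s, size - roi))
        if s + roi >= size:
            return starts
        s += step

def sliding_average(signal, roi, overlap, predict):
    """Run `predict` on each window and average overlapping outputs with equal weights."""
    total = [0.0] * len(signal)
    count = [0] * len(signal)
    for s in window_starts(len(signal), roi, overlap):
        for i, v in enumerate(predict(signal[s:s + roi])):
            total[s + i] += v
            count[s + i] += 1
    return [t / c for t, c in zip(total, count)]

# Hypothetical preprocessed volume shape, tiled with 96-voxel windows at 50% overlap.
shape, roi = (226, 157, 113), 96
per_axis = [len(window_starts(n, roi, 0.5)) for n in shape]
n_windows = math.prod(per_axis)
forward_passes = math.ceil(n_windows / 4)  # sw_batch_size=4 windows per model call
print(per_axis, n_windows, forward_passes)
```

Averaging overlapping windows smooths out the boundary artifacts a patch-trained network tends to produce near window edges, at the cost of more forward passes: 50% overlap roughly doubles the number of windows per axis compared with non-overlapping tiles.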
```python
fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(range(1, len(loss_hist)+1), loss_hist, "-o", ms=3)
ax[0].set(title="Training loss", xlabel="epoch", ylabel="DiceCE loss")
ax[1].plot(dice_epochs, dice_hist, "-o", color="seagreen", ms=4)
ax[1].set(title="Validation mean Dice", xlabel="epoch", ylabel="Dice"); ax[1].set_ylim(0, 1)
plt.tight_layout(); plt.show()

# Reload the best checkpoint and segment one validation volume
model.load_state_dict(torch.load(best_path, map_location=device)); model.eval()
with torch.no_grad():
    sample = next(iter(val_loader))
    img = sample["image"].to(device)
    with autocast("cuda", enabled=torch.cuda.is_available()):
        pred = sliding_window_inference(img, roi_size, 4, model, overlap=0.5)
    pred = torch.argmax(pred, dim=1).cpu().numpy()[0]
img_np, lab_np = img.cpu().numpy()[0, 0], sample["label"].numpy()[0, 0]
z = int(np.argmax(lab_np.sum(axis=(0, 1))))  # axial slice with the largest spleen area
fig, ax = plt.subplots(1, 3, figsize=(13, 5))
ax[0].imshow(img_np[:, :, z], cmap="gray"); ax[0].set_title("CT slice")
ax[1].imshow(lab_np[:, :, z], cmap="viridis"); ax[1].set_title("Ground truth")
ax[2].imshow(pred[:, :, z], cmap="viridis"); ax[2].set_title("Prediction")
for a in ax: a.axis("off")
plt.tight_layout(); plt.show()
```
We first plot the training loss and the validation Dice scores to see how the model improves over time. We then reload the best saved checkpoint and run sliding window inference on one validation volume. Finally, we display a CT slice, its ground-truth mask, and the predicted segmentation side by side to assess the model's output qualitatively.
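The visualization picks the axial slice with the largest ground-truth area via `np.argmax(lab_np.sum(axis=(0, 1)))`. The NumPy sketch below reproduces that selection and adds a per-slice Dice, so the displayed slice can be annotated with a number rather than judged by eye alone. `best_slice` and `slice_dice` are hypothetical helper names, and the toy volumes merely stand in for `lab_np` and `pred`.

```python
import numpy as np

def best_slice(label):
    """Index along the last axis with the most foreground voxels (same rule as the plot code)."""
    return int(np.argmax(label.sum(axis=(0, 1))))

def slice_dice(pred, label, z):
    """Dice between binary prediction and label on axial slice z (NaN if both are empty)."""
    p, t = pred[:, :, z] > 0, label[:, :, z] > 0
    denom = p.sum() + t.sum()
    return float("nan") if denom == 0 else float(2.0 * np.logical_and(p, t).sum() / denom)

# Toy volumes standing in for lab_np and pred (hypothetical data).
label = np.zeros((16, 16, 10), dtype=np.uint8)
label[4:12, 4:12, 3:7] = 1                # "organ" spanning slices 3 to 6
label[2:14, 2:14, 5] = 1                  # slice 5 has the largest cross-section
pred = np.roll(label, shift=1, axis=0)    # a prediction shifted by one voxel
z = best_slice(label)
print(z, round(slice_dice(pred, label, z), 3))
```

A one-voxel shift already costs several Dice points on a single slice, which is a useful reminder that a visually "close" prediction and a high Dice score are not the same thing.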
In conclusion, we built a MONAI-based workflow for 3D spleen segmentation with a 3D UNet. We prepared the Medical Segmentation Decathlon dataset, preprocessed and augmented the CT volumes, trained the model with DiceCE loss, validated it with sliding window inference, and tracked both the loss and the Dice score over time. We also inspected the final prediction by comparing a CT slice, its ground-truth label, and the model output side by side. This gives a clear picture of how MONAI supports medical image segmentation, from data loading and preprocessing to model training, validation, and qualitative analysis.
The post Implementation of Coding in MONAI for End-to-End 3D Spleen Using UNet in Medical CT Volumes appeared first on MarkTechPost.



