PyTorch & Lightning Integration
The Datamint Python API provides seamless integration with PyTorch and PyTorch Lightning, enabling efficient machine learning workflows for medical imaging tasks.
Overview
Key integration features:
ImageDataset / VolumeDataset / VideoDataset: Modular, PyTorch-compatible datasets for 2D images, 3D volumes, and video sequences
DatamintDataModule: Lightning-compatible data module
Trainer API: Task-focused trainers such as UNetPPTrainer and SemanticSegmentation2DTrainer
MLFlowModelCheckpoint: Advanced model checkpointing with MLflow integration
Automatic Experiment Tracking: Seamless logging and model registration
Medical Image Optimizations: Specialized handling for medical data formats
PyTorch Dataset Integration
Basic PyTorch Usage
import torch
from torch.utils.data import DataLoader
from datamint.dataset import ImageDataset, VolumeDataset
# 2D images (X-rays, single-frame DICOM, PNG, JPEG, …)
dataset = ImageDataset(
    project="liver-classification",
    include_unannotated=False,
)
# 3D volumes (NIfTI, DICOM series, …)
# dataset = VolumeDataset(project="ct-liver-segmentation")
# Create a standard PyTorch DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)
# Training loop: each batch is a dict
for batch in dataloader:
    images = batch['image']  # Shape: [B, C, H, W]
    segmentations = batch['segmentations']  # Shape varies by mode
    metadata = batch['metainfo']  # List of dicts
    # (...)
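Each sample is a dict, and a batch groups those dicts key by key. The pure-Python sketch below (no Datamint or PyTorch calls; collate_dict_samples is a hypothetical name) shows why a key such as metainfo naturally arrives as a list with one dict per sample. In the real pipeline the image tensors are stacked into [B, C, H, W], as the shape comment above states; how Datamint collates the remaining keys may differ from this sketch.

```python
from typing import Any


def collate_dict_samples(samples: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Group per-sample dicts into one dict of per-key lists (hypothetical helper).

    Assumes a non-empty batch in which every sample has the same keys.
    """
    keys = samples[0].keys()
    # Every key becomes a list with one entry per sample, in batch order.
    return {key: [sample[key] for sample in samples] for key in keys}


samples = [
    {"image": [[0.1, 0.2]], "metainfo": {"id": "a"}},
    {"image": [[0.3, 0.4]], "metainfo": {"id": "b"}},
]
batch = collate_dict_samples(samples)
print(batch["metainfo"])  # [{'id': 'a'}, {'id': 'b'}]
```

Keeping metadata as a list avoids forcing per-sample dicts, whose fields may differ, into a single merged structure.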
Dataset Transforms
Apply transforms for data augmentation and preprocessing:
import albumentations as A
import torch
from torch.utils.data import DataLoader
from datamint.dataset import ImageDataset
class XrayFractureDataset(ImageDataset):
    def __getitem__(self, idx):
        item = super().__getitem__(idx)
        # 'image' is a tensor of shape (C, H, W)
        image = item['image']
        has_fracture = 'fracture' in item.get('image_labels', [])
        label = torch.tensor(has_fracture, dtype=torch.int32)
        return image, label
# Create an instance of your custom dataset
dataset = XrayFractureDataset(
    project='MY_PROJECT_NAME',
    alb_transform=A.Compose([A.Resize(224, 224)]),
)
# Create a DataLoader to handle batching and shuffling of the dataset
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
for images, labels in dataloader:
    # images: (batch_size, C, 224, 224)
    # labels: (batch_size,)
    pass  # (...) do something with the batch
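XrayFractureDataset follows a general pattern: call the parent __getitem__, then reduce the returned dict to the (input, target) pair your loss expects. The stdlib-only sketch below isolates that pattern so it runs without a Datamint project. ToyImageDataset and ToyFractureDataset are hypothetical stand-ins for ImageDataset and the subclass above, and placeholder strings replace the image tensors.

```python
from typing import Any


class ToyImageDataset:
    """Stand-in for a dataset whose items are dicts (hypothetical, not Datamint)."""

    def __init__(self, items: list[dict[str, Any]]):
        self.items = items

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> dict[str, Any]:
        return self.items[idx]


class ToyFractureDataset(ToyImageDataset):
    """Same override pattern as XrayFractureDataset: dict item -> (image, label)."""

    def __getitem__(self, idx: int) -> tuple[Any, int]:
        item = super().__getitem__(idx)
        # A missing 'image_labels' key is treated as "no labels".
        has_fracture = "fracture" in item.get("image_labels", [])
        return item["image"], int(has_fracture)


ds = ToyFractureDataset([
    {"image": "img0", "image_labels": ["fracture"]},
    {"image": "img1", "image_labels": []},
    {"image": "img2"},
])
print([ds[i][1] for i in range(len(ds))])  # [1, 0, 0]
```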
Loading all data and metadata:
from datamint.dataset import ImageDataset
# Create an instance of ImageDataset
dataset = ImageDataset(
    project='MY_PROJECT_NAME'
)
# Create a DataLoader
dataloader = dataset.get_dataloader(batch_size=4, shuffle=False)  # Same parameters as PyTorch DataLoader
for batch in dataloader:
    images = batch['image']  # shape: (batch_size, C, H, W)
    segmentations = batch['segmentations']
    image_labels = batch['image_labels']
    image_categories = batch['image_categories']
    # (... do something with the batch)
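For multi-label classification, image_labels usually has to become a fixed-size target. Assuming it holds one list of label names per sample (consistent with the 'fracture' in item.get('image_labels', []) check in the transforms example), the sketch below builds multi-hot vectors over a class list you choose. The helper multi_hot and the class names are hypothetical.

```python
def multi_hot(batch_labels: list[list[str]], classes: list[str]) -> list[list[int]]:
    """Encode each sample's label names as a 0/1 vector over `classes` (hypothetical helper)."""
    index = {name: i for i, name in enumerate(classes)}
    encoded = []
    for labels in batch_labels:
        row = [0] * len(classes)
        for name in labels:
            if name in index:  # labels outside `classes` are ignored
                row[index[name]] = 1
        encoded.append(row)
    return encoded


classes = ["fracture", "effusion", "nodule"]
batch_labels = [["fracture"], [], ["nodule", "effusion"]]
print(multi_hot(batch_labels, classes))  # [[1, 0, 0], [0, 0, 0], [0, 1, 1]]
```

Wrapping the result with torch.tensor(..., dtype=torch.float32) gives a [B, num_classes] target of the kind torch.nn.BCEWithLogitsLoss expects.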
Split Reproducibility
Use split() to partition a dataset according to the project's split assignments.
Each resolved split dataset stores the effective historical snapshot timestamp in
split_as_of_timestamp; pass that value to future training runs to reuse exactly
the same assignment state.
from datamint.dataset import ImageDataset
from datamint.lightning import DatamintDataModule
dataset = ImageDataset(project="my-project", include_unannotated=True)
parts = dataset.split()
snapshot = parts["train"].split_as_of_timestamp
datamodule = DatamintDataModule(
    dataset,
    split=True,
    split_as_of_timestamp=snapshot,
)
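Conceptually, an as-of timestamp makes split assignment replayable: if assignment changes are kept as a time-stamped history, the split of each resource at time T is its most recent change at or before T. The sketch below illustrates that idea with a hypothetical in-memory history. It is not Datamint's implementation, and the (resource_id, split_name, changed_at) event format is an assumption.

```python
from datetime import datetime, timezone

# Hypothetical assignment history: (resource_id, split_name, changed_at).
HISTORY = [
    ("img-1", "train", datetime(2024, 1, 1, tzinfo=timezone.utc)),
    ("img-2", "train", datetime(2024, 1, 1, tzinfo=timezone.utc)),
    ("img-2", "test", datetime(2024, 3, 1, tzinfo=timezone.utc)),
    ("img-3", "val", datetime(2024, 4, 1, tzinfo=timezone.utc)),
]


def splits_as_of(history: list[tuple[str, str, datetime]], as_of: datetime) -> dict[str, str]:
    """Return {resource_id: split} using only events at or before `as_of`."""
    latest: dict[str, tuple[str, datetime]] = {}
    for resource_id, split, changed_at in history:
        if changed_at > as_of:
            continue  # ignore changes made after the snapshot
        previous = latest.get(resource_id)
        if previous is None or changed_at >= previous[1]:
            latest[resource_id] = (split, changed_at)
    return {rid: split for rid, (split, _) in latest.items()}


snapshot = datetime(2024, 2, 1, tzinfo=timezone.utc)
print(splits_as_of(HISTORY, snapshot))  # {'img-1': 'train', 'img-2': 'train'}
```

Resolving against the same snapshot later returns the same mapping even though img-2 has since moved to test and img-3 was added, which is the kind of stability a stored snapshot timestamp is meant to give.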
DatamintDataModule and the built-in trainers propagate the resolved
split_source and split_as_of_timestamp values into MLflow lineage, so
you can trace which project split snapshot was used during training and replay
it later.
Trainer API
Use the Trainer API when you want Datamint to build the dataset, datamodule, default model, MLflow logger, and checkpoint callbacks for you.
from datamint.lightning import UNetPPTrainer
trainer = UNetPPTrainer(
    project="BUSI_Segmentation",
    image_size=256,
    batch_size=16,
    max_epochs=20,
    accelerator="auto",
)
results = trainer.fit()
The trainer layer is also the recommended way to integrate an external model architecture while still reusing Datamint’s dataset handling and MLflow workflow.
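One common way a trainer layer supports external architectures is the template-method pattern: the base class owns the shared pipeline (data, logging, training) and exposes a model-building hook that subclasses override. The sketch below shows only that general pattern. BaseTaskTrainer, build_data, build_model, and ExternalModelTrainer are hypothetical names, not Datamint's API, and strings stand in for real datamodule and model objects; consult the Trainer API documentation for the actual extension points.

```python
class BaseTaskTrainer:
    """Hypothetical trainer: shared pipeline plus an overridable model hook."""

    def __init__(self, project: str, max_epochs: int = 1):
        self.project = project
        self.max_epochs = max_epochs

    def build_data(self) -> str:
        # Placeholder for dataset/datamodule construction.
        return f"datamodule({self.project})"

    def build_model(self) -> str:
        # Default architecture; subclasses replace only this step.
        return "default-unet++"

    def fit(self) -> dict:
        data = self.build_data()
        model = self.build_model()
        # A real trainer would run the training loop and log to MLflow here.
        return {"data": data, "model": model, "epochs": self.max_epochs}


class ExternalModelTrainer(BaseTaskTrainer):
    def build_model(self) -> str:
        return "my-external-architecture"


print(ExternalModelTrainer("BUSI_Segmentation", max_epochs=20).fit())
# {'data': 'datamodule(BUSI_Segmentation)', 'model': 'my-external-architecture', 'epochs': 20}
```

Because the subclass overrides a single hook, data handling and experiment tracking stay identical across architectures, which keeps runs comparable.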
For a side-by-side comparison of raw PyTorch / Lightning code versus Datamint, see datamint_vs_raw_pytorch.
See Trainer API for more details.