PyTorch & Lightning Integration

The Datamint Python API integrates with PyTorch and PyTorch Lightning: its datasets plug directly into standard data loaders and training loops for medical imaging tasks.

Overview

Key integration features:

  • ImageDataset / VolumeDataset / VideoDataset: Modular, PyTorch-compatible datasets for 2D images, 3D volumes, and video sequences

  • DatamintDataModule: Lightning-compatible data module

  • MLFlowModelCheckpoint: Model checkpointing integrated with MLflow

  • Automatic Experiment Tracking: Automatic logging and model registration

  • Medical Image Optimizations: Specialized handling for medical data formats

PyTorch Dataset Integration

Basic PyTorch Usage

import torch
from torch.utils.data import DataLoader
from datamint.dataset import ImageDataset, VolumeDataset

# 2D images (X-rays, single-frame DICOM, PNG, JPEG, …)
dataset = ImageDataset(
    project="liver-classification",
    include_unannotated=False,
)

# 3D volumes (NIfTI, DICOM series, …)
# dataset = VolumeDataset(project="ct-liver-segmentation")

# Create a standard PyTorch DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

# Training loop – each batch is a dict
for batch in dataloader:
    images = batch['image']                 # Shape: [B, C, H, W]
    segmentations = batch['segmentations']  # Shape varies by mode
    metadata = batch['metainfo']            # List of dicts
    # (...)
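
To make the dict-shaped batches in the loop above concrete, here is a minimal plain-Python sketch of the batching step. The helper collate_as_lists is hypothetical (it is not part of Datamint or PyTorch), and the sample dicts are stand-ins for real dataset items. PyTorch's default collation groups values per key in the same way and additionally stacks tensors along a new batch dimension.

```python
from typing import Any


def collate_as_lists(samples: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Turn a list of per-sample dicts into one dict of per-key lists.

    Hypothetical helper for illustration only.
    """
    if not samples:
        return {}
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}


# Stand-in samples; real items would hold tensors under 'image'.
samples = [
    {"image": [[0.0, 1.0]], "metainfo": {"id": "a"}},
    {"image": [[1.0, 0.0]], "metainfo": {"id": "b"}},
]

batch = collate_as_lists(samples)
print(batch["metainfo"])  # one metadata dict per sample, in order
```

A function of this shape can be passed as collate_fn to DataLoader when some fields (for example, segmentations of differing shapes) cannot be stacked into a single tensor.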

Dataset Transforms

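Pass an Albumentations pipeline via the alb_transform argument for augmentation and preprocessing. Subclassing the dataset, as below, additionally lets you control what each item returns: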

import albumentations as A
import torch
from torch.utils.data import DataLoader
from datamint.dataset import ImageDataset


class XrayFractureDataset(ImageDataset):
    def __getitem__(self, idx):
        item = super().__getitem__(idx)

        # 'image' is a tensor of shape (C, H, W)
        image = item['image']

        has_fracture = 'fracture' in item.get('image_labels', [])
        # Class-index targets for losses such as CrossEntropyLoss must be int64
        label = torch.tensor(has_fracture, dtype=torch.long)

        return image, label


# Create an instance of your custom dataset
dataset = XrayFractureDataset(
    project='MY_PROJECT_NAME',
    alb_transform=A.Compose([A.Resize(224, 224)]),
)

# Create a DataLoader to handle batching and shuffling of the dataset
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

for images, labels in dataloader:
    # images: (batch_size, C, 224, 224)
    # labels: (batch_size,)
    pass  # (...) do something with the batch
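
The binary fracture label above can be generalized to several labels per image. The following is a minimal sketch in plain Python, assuming image_labels is a list of label-name strings as in the example above. The helper encode_multi_hot and the class list CLASS_NAMES are hypothetical and chosen only for illustration.

```python
def encode_multi_hot(labels: list[str], class_names: list[str]) -> list[int]:
    """Return one 0/1 entry per class name, in class_names order."""
    present = set(labels)
    return [1 if name in present else 0 for name in class_names]


# Hypothetical class list; replace it with the labels used in your project.
CLASS_NAMES = ["fracture", "dislocation", "implant"]

# Stand-in for a dataset item; only the 'image_labels' field matters here.
item = {"image_labels": ["implant", "fracture"]}

vector = encode_multi_hot(item.get("image_labels", []), CLASS_NAMES)
print(vector)  # [1, 0, 1]
```

Inside __getitem__, the resulting list could be wrapped with torch.tensor(vector, dtype=torch.float32) so that it works as a target for a multi-label loss such as BCEWithLogitsLoss.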

Loading all data and metadata with the dataset's built-in get_dataloader helper:

from datamint.dataset import ImageDataset

# Create an instance of ImageDataset
dataset = ImageDataset(
    project='MY_PROJECT_NAME'
)
# Create a DataLoader (accepts the same parameters as the PyTorch DataLoader)
dataloader = dataset.get_dataloader(batch_size=4, shuffle=False)

for batch in dataloader:
    images = batch['image']  # shape: (batch_size, C, H, W)
    segmentations = batch['segmentations']
    image_labels = batch['image_labels']
    image_categories = batch['image_categories']

    # (... do something with the batch)
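
One use for such a full pass over the data is checking class balance before training. The sketch below assumes that each batch['image_labels'] is a list holding one list of label names per sample. The source does not state the exact structure, so this is an assumption to verify against your data. The helper count_labels is hypothetical, and the batches are stand-ins for what the dataloader would yield.

```python
from collections import Counter
from typing import Iterable


def count_labels(batches: Iterable[dict]) -> Counter:
    """Count how often each label name occurs across all batches."""
    counts: Counter = Counter()
    for batch in batches:
        # Assumed structure: one list of label names per sample.
        for sample_labels in batch["image_labels"]:
            counts.update(sample_labels)
    return counts


# Stand-in batches for illustration only.
batches = [
    {"image_labels": [["fracture"], []]},
    {"image_labels": [["fracture", "implant"], ["implant"]]},
]

label_counts = count_labels(batches)
print(dict(label_counts))
```

With a real dataloader, the stand-in list would be replaced by the dataloader itself, since the function only iterates over batches and reads one key.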