PyTorch & Lightning Integration

The Datamint Python API provides seamless integration with PyTorch and PyTorch Lightning, enabling efficient machine learning workflows for medical imaging tasks.

Overview

Key integration features (a usage sketch follows the list):

  • DatamintDataModule: Lightning-compatible data module

  • MLFlowModelCheckpoint: Advanced model checkpointing with MLflow integration

  • Automatic Experiment Tracking: Seamless logging and model registration

  • Medical Image Optimizations: Specialized handling for medical data formats
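
The sketch below shows how these pieces typically fit together in a Lightning training script. The import path and the constructor arguments used for DatamintDataModule and MLFlowModelCheckpoint are illustrative assumptions, not the definitive API; consult the API reference for the exact signatures.

import pytorch_lightning as pl
from datamint import DatamintDataModule, MLFlowModelCheckpoint  # assumed import path

# NOTE: the arguments below are assumptions for illustration only.
datamodule = DatamintDataModule(project_name="liver-segmentation",  # hypothetical argument
                                batch_size=16)                      # hypothetical argument
checkpoint_cb = MLFlowModelCheckpoint(monitor="val_loss")           # hypothetical argument

trainer = pl.Trainer(max_epochs=10, callbacks=[checkpoint_cb])
trainer.fit(model, datamodule=datamodule)  # 'model' is your LightningModule instance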

PyTorch Dataset Integration

Basic PyTorch Usage

import torch
from torch.utils.data import DataLoader
from datamint import Dataset

# Load dataset. This is a PyTorch-compatible dataset that can be used directly.
dataset = Dataset(
    project_name="liver-segmentation",
    return_annotations=True,      # include annotations (e.g. segmentation masks) in each item
    return_frame_by_frame=True,   # yield individual frames/slices rather than whole volumes
    include_unannotated=False     # skip resources that have no annotations
)

# Create PyTorch DataLoader
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    collate_fn=dataset.get_collate_fn()
)

# Training loop
for batch in dataloader:
    images = batch['image']      # Shape: [B, C, H, W]
    masks = batch['segmentation'] # Shape: [B, H, W]
    metadata = batch['metainfo']  # List of dicts
    # (...)
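
As a point of reference, a single optimization step over such a batch could look like the sketch below. The model, loss, and optimizer are generic PyTorch placeholders rather than part of the Datamint API, and the single input channel and two output classes are assumptions about the data.

import torch
import torch.nn as nn

# Placeholder segmentation model: maps [B, 1, H, W] images to per-pixel class logits.
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 2, kernel_size=3, padding=1),
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for batch in dataloader:
    images = batch['image'].float()       # [B, C, H, W]
    masks = batch['segmentation'].long()  # [B, H, W], integer class ids

    logits = model(images)                # [B, num_classes, H, W]
    loss = criterion(logits, masks)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()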

Dataset Transforms

Apply transforms for data augmentation and preprocessing:

import datamint
import torch
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader


class XrayFractureDataset(datamint.Dataset):
    def __getitem__(self, idx):
        image, dicom_metainfo, metainfo = super().__getitem__(idx)

        # Get all relevant information from the dicom_metainfo object
        patient_sex = dicom_metainfo.PatientSex

        # Get all relevant information from the metainfo object
        has_fracture = 'fracture' in metainfo['labels']
        has_fracture = torch.tensor(has_fracture, dtype=torch.int32)

        return image, patient_sex, has_fracture


# Create an instance of your custom dataset
dataset = XrayFractureDataset(root='data',
                              dataset_name='YOUR_DATASET_NAME',
                              version='latest',
                              api_key='my_api_key',
                              transform=ToTensor())

# Create a DataLoader to handle batching and shuffling of the dataset
dataloader = DataLoader(dataset,
                        batch_size=4,
                        shuffle=True)

# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for images, patients_sex, labels in dataloader:
    images = images.to(device)
    # labels is already a tensor of shape (batch_size,) containing 0s and 1s

    # (...) do something with the batch

Alternatively, if you prefer not to subclass the dataset, you can load the full data and metadata with a custom collate function:

import datamint
import torch
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader


# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Create an instance of the datamint.Dataset
dataset = datamint.Dataset(root='data',
                           dataset_name='TestCTdataset',
                           version='latest',
                           api_key='my_api_key',
                           transform=ToTensor())

# This function tells the dataloader how to group the items in a batch
def collate_fn(batch):
    images = [item[0] for item in batch]
    dicom_metainfo = [item[1] for item in batch]
    metainfo = [item[2] for item in batch]

    return torch.stack(images), dicom_metainfo, metainfo


# Create a DataLoader to handle batching and shuffling of the dataset
dataloader = DataLoader(dataset,
                        batch_size=4,
                        collate_fn=collate_fn,
                        shuffle=True)

for images, dicom_metainfo, metainfo in dataloader:
    images = images.to(device)
    # dicom_metainfo and metainfo are lists with one entry per item in the batch

    # (... do something with the batch)
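
With this approach the per-item labels still live inside the metainfo dictionaries, so they have to be assembled into a tensor inside the loop. A minimal sketch, reusing the 'fracture' label from the earlier example and assuming each metainfo dict carries the same 'labels' key:

# Inside the loop above: build a batch label tensor from the metainfo dicts.
labels = torch.tensor([int('fracture' in m['labels']) for m in metainfo],
                      dtype=torch.int32).to(device)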