PyTorch integration

If you haven't set up your API key yet, see the Setup API key section before continuing.
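The examples on this page pass the key explicitly through the api_key parameter. As a quick sketch, you can also keep the key out of your code and read it from an environment variable; the variable name DATAMINT_API_KEY below is an assumption, so follow the Setup API key section for the authoritative instructions.

import os

# Assumption: the API key is stored in an environment variable (the exact name
# may differ; see the Setup API key section).
api_key = os.environ.get('DATAMINT_API_KEY')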

Dataset

Datamint provides a custom PyTorch dataset class that loads data from the server in a PyTorch-friendly way. To use it, import the Dataset class from the datamint package and create an instance, passing the necessary parameters:

from datamint import Dataset

dataset = Dataset('../data',
                  project_name='MyProjectName',  # must exist on the server
                  # return_frame_by_frame=True,  # optional: each item becomes a single frame instead of a whole video/3D image
                  )

Then use it in your PyTorch code as usual.
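For instance, indexing the dataset directly returns the image together with its DICOM metadata and Datamint metadata, the same triple that the subclass example below unpacks (a minimal sketch; the exact item structure may depend on the dataset options you chose above):

# Each item is (image, dicom_metainfo, metainfo), as unpacked in the example below.
image, dicom_metainfo, metainfo = dataset[0]
print(dicom_metainfo.PatientSex, metainfo['labels'])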

Here is a complete example that subclasses datamint.Dataset:

import datamint
import torch
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader


class XrayFractureDataset(datamint.Dataset):
    def __getitem__(self, idx):
        image, dicom_metainfo, metainfo = super().__getitem__(idx)

        # Get all relevant information from the dicom_metainfo object
        patient_sex = dicom_metainfo.PatientSex

        # Get all relevant information from the metainfo object
        has_fracture = 'fracture' in metainfo['labels']
        has_fracture = torch.tensor(has_fracture, dtype=torch.int32)

        return image, patient_sex, has_fracture


# Create an instance of your custom dataset
dataset = XrayFractureDataset(root='data',
                              dataset_name='YOUR_DATASET_NAME',
                              version='latest',
                              api_key='my_api_key',
                              transform=ToTensor())

# Create a DataLoader to handle batching and shuffling of the dataset
dataloader = DataLoader(dataset,
                        batch_size=4,
                        shuffle=True)

# Set the device (needed for .to(device) below)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

for images, patients_sex, labels in dataloader:
    images = images.to(device)
    # labels will already be a tensor of shape (batch_size,) containing 0s and 1s

    # (...) do something with the batch
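From here, a standard PyTorch training step applies. Below is a minimal sketch assuming a binary fracture classifier; the model, loss and optimizer are generic placeholders, not part of the Datamint API:

import torch.nn as nn

model = nn.Sequential(nn.Flatten(), nn.LazyLinear(1)).to(device)  # placeholder model
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for images, patients_sex, labels in dataloader:
    images = images.to(device)
    labels = labels.to(device).float()

    logits = model(images).squeeze(1)  # shape (batch_size,)
    loss = criterion(logits, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()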

Alternatively, if you want to load the images and metadata directly without subclassing, use datamint.Dataset with a custom collate_fn:

import datamint
import torch
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader


# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Create an instance of the datamint.Dataset
dataset = datamint.Dataset(root='data',
                           dataset_name='TestCTdataset',
                           version='latest',
                           api_key='my_api_key',
                           transform=ToTensor())

# This function tells the dataloader how to group the items in a batch
def collate_fn(batch):
    images = [item[0] for item in batch]
    dicom_metainfo = [item[1] for item in batch]
    metainfo = [item[2] for item in batch]

    return torch.stack(images), dicom_metainfo, metainfo


# Create a DataLoader to handle batching and shuffling of the dataset
dataloader = DataLoader(dataset,
                        batch_size=4,
                        collate_fn=collate_fn,
                        shuffle=True)

for images, dicom_metainfo, metainfo in dataloader:
    images = images.to(device)
    # dicom_metainfo and metainfo are plain Python lists, one entry per item in the batch

    # (...) do something with the batch
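Since this variant keeps the raw metadata, you derive the training targets yourself inside the loop. A minimal sketch, reusing the 'fracture' label check from the subclass example above:

for images, dicom_metainfo, metainfo in dataloader:
    images = images.to(device)

    # Build a (batch_size,) tensor of 0/1 targets from the per-item metadata.
    has_fracture = torch.tensor(['fracture' in m['labels'] for m in metainfo],
                                dtype=torch.int32).to(device)

    # (...) do something with images and has_fracture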