Datamint vs Raw PyTorch

This page shows how Datamint removes boilerplate from common medical-imaging training workflows by comparing a raw PyTorch / Lightning setup side by side with the equivalent Datamint code. All examples target 2-D semantic segmentation (e.g. BUSI, skin, or liver segmentation projects).

Workflow comparison at a glance

| Responsibility | Raw PyTorch / Lightning | Datamint |
|---|---|---|
| 📂 Data loading | Write a custom torch.utils.data.Dataset subclass that reads files, decodes DICOM / NIfTI / PNG, applies augmentations, and returns aligned tensors. | One line: ImageDataset handles DICOM series, NIfTI, PNG, JPEG, and annotation parsing automatically. |
| 📊 Train / val / test splits | Manually partition resource IDs, tag them, or write a split function, then track the split and its seed yourself to keep it reproducible. | split() resolves project-scoped splits (or falls back to legacy split:* tags) and returns a snapshot timestamp you can replay later. |
| 🔌 DataModule wiring | Implement lightning.pytorch.core.LightningDataModule with prepare_data, setup, train_dataloader, val_dataloader, and test_dataloader. | DatamintDataModule wraps any DatamintBaseDataset and provides all dataloaders out of the box. |
| 🧱 Model scaffolding | Subclass lightning.pytorch.LightningModule and write forward, training_step, validation_step, test_step, and configure_optimizers. | SegmentationModule provides forward, predict, and predict_batch with Datamint-native inference. Loss and metrics are injected automatically when you pass the class (not an instance) to model=. |
| ⚙️ Trainer configuration | Instantiate lightning.pytorch.trainer.Trainer with max_epochs, accelerator, devices, precision, logger, and checkpoint callbacks. | Datamint's trainer (e.g. UNetPPTrainer) builds the dataset, datamodule, model, MLflow logger, and checkpoint callbacks for you. Extra kwargs are forwarded to lightning.pytorch.trainer.Trainer. |
| 📈 Experiment tracking | Configure lightning.pytorch.loggers.MLFlowLogger or lightning.pytorch.loggers.TensorBoardLogger, then log hyper-parameters, metrics, and artifacts manually. | MLflow is auto-configured on first import of datamint.mlflow. The trainer logs metrics, hyper-parameters, and checkpoints automatically; register_model=True registers the final artifact in MLflow. |
| 🚀 Model deployment | Export the checkpoint, write inference code, wrap it in an MLflow mlflow.pyfunc.PythonModel, and deploy to your serving platform. | The trained model is already a DatamintModel that can be deployed via DeployModelApi. |
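In the raw path, much of the data-loading row above comes down to routing each file to the right decoder. The stdlib sketch below shows that first step. detect_format and the format labels are hypothetical names, not part of Datamint or PyTorch. The sketch also exposes a common pitfall: Path.suffix returns only the last suffix, so "scan.nii.gz" reports ".gz".

```python
from pathlib import Path

# Hypothetical mapping from file extension to a decoder family
# (illustration only; not a Datamint or PyTorch API).
_FORMATS = {
    ".nii.gz": "nifti",
    ".nii": "nifti",
    ".dcm": "dicom",
    ".dicom": "dicom",
    ".png": "image",
    ".jpg": "image",
    ".jpeg": "image",
}


def detect_format(path: Path) -> str:
    """Hypothetical helper: return a decoder family for `path` or raise ValueError."""
    name = path.name.lower()
    # Check the longest extensions first so ".nii.gz" wins over shorter matches.
    for ext in sorted(_FORMATS, key=len, reverse=True):
        if name.endswith(ext):
            return _FORMATS[ext]
    raise ValueError(f"Unsupported format: {path.name}")


print(Path("scan.nii.gz").suffix)          # -> .gz  (suffix alone is not enough)
print(detect_format(Path("scan.nii.gz")))  # -> nifti
print(detect_format(Path("Case_01.PNG")))  # -> image
```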

Example 1 – Dataset and split setup

Raw PyTorch

```python
from pathlib import Path
from typing import List, Tuple

import albumentations as A
import nibabel as nib
import numpy as np
import pydicom
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader


def file_ext(path: Path) -> str:
    # Path.suffix is only ".gz" for "scan.nii.gz", so check the full name.
    name = path.name.lower()
    return ".nii.gz" if name.endswith(".nii.gz") else path.suffix.lower()


class SegmentationDataset(Dataset):
    """Manually wired dataset for 2-D segmentation."""

    def __init__(
        self,
        image_paths: List[Path],
        mask_paths: List[Path],
        transform: A.Compose,
    ):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # -- File-type detection and decoding ---------------------------------
        ext = file_ext(self.image_paths[idx])

        if ext in (".dcm", ".dicom"):
            # DICOM: requires pydicom to read, then handle pixel spacing,
            # rescale slope/intercept, windowing, and multi-frame series.
            ds = pydicom.dcmread(str(self.image_paths[idx]))
            img = ds.pixel_array.astype(np.float32)
            # Which shape dimension is the channel varies across DICOM files
            # and may require manual handling.
            # (...) This is a common source of bugs.

        elif ext in (".nii", ".nii.gz"):
            # NIfTI: requires nibabel or nitransforms; handles 3-D volumes,
            # custom affine transforms, and voxel-to-RAS mapping.
            nifti_img = nib.load(str(self.image_paths[idx]))
            img = nifti_img.get_fdata().astype(np.float32)
            # For 3-D volumes you must slice or process the full volume.
            # This is non-trivial for segmentation: the slice must stay aligned
            # with the mask, and you must know which axis to slice on.
            # (...)

        elif ext in (".png", ".jpg", ".jpeg"):
            # Keep uint8: albumentations' pixel-level transforms assume uint8 in
            # [0, 255] or float in [0, 1]; A.Normalize() converts to float32.
            img = np.array(Image.open(self.image_paths[idx]).convert("RGB"))
        else:
            raise ValueError(f"Unsupported image format: {ext}")

        # -- Mask loading (similar complexity for medical formats) -----------
        mask_ext = file_ext(self.mask_paths[idx])
        if mask_ext in (".nii", ".nii.gz"):
            nifti_mask = nib.load(str(self.mask_paths[idx]))
            mask = nifti_mask.get_fdata().astype(np.int64)
        elif mask_ext == ".png":
            mask = np.array(Image.open(self.mask_paths[idx]).convert("L"))
            # Binary masks are often stored as 0/255; map them to class
            # indices {0, 1} so CrossEntropyLoss accepts them.
            mask = (mask > 0).astype(np.int64)
        else:
            raise ValueError(f"Unsupported mask format: {mask_ext}")

        # -- Augmentation -----------------------------------------------------
        augmented = self.transform(image=img, mask=mask)
        image = torch.from_numpy(augmented["image"]).float().permute(2, 0, 1)
        mask = torch.from_numpy(augmented["mask"]).long()

        return image, mask


# -- Split logic (manual) -------------------------------------------
# Sort first: iterdir() order depends on the filesystem, so a seeded
# shuffle alone does not guarantee a reproducible split.
all_paths = sorted(Path("data/images").iterdir())
np.random.seed(42)
np.random.shuffle(all_paths)
n_train = int(0.7 * len(all_paths))
n_val = int(0.1 * len(all_paths))
train_paths = all_paths[:n_train]
val_paths = all_paths[n_train:n_train + n_val]
test_paths = all_paths[n_train + n_val:]

# -- Training transform (with augmentations) ------------------------
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.Normalize(),
])

train_ds = SegmentationDataset(
    train_paths,
    [Path("data/masks") / p.name for p in train_paths],
    train_transform,
)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4)

# -- Eval transform (no augmentations) ------------------------------
eval_transform = A.Compose([A.Normalize()])

val_ds = SegmentationDataset(
    val_paths,
    [Path("data/masks") / p.name for p in val_paths],
    eval_transform,
)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4)

test_ds = SegmentationDataset(
    test_paths,
    [Path("data/masks") / p.name for p in test_paths],
    eval_transform,
)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=4)
```

Datamint

```python
import albumentations as A
from datamint.dataset import ImageDataset
from datamint.lightning import DatamintDataModule

# 1. Load the project -- Datamint resolves all images and masks.
dataset = ImageDataset(project="BUSI_Segmentation")

# 2. Define transforms per stage.
train_tfm = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.Normalize(),
])
eval_tfm = A.Compose([A.Normalize()])

# 3. Wrap in a DataModule -- splitting and dataloaders are handled.
dm = DatamintDataModule(
    dataset,
    batch_size=16,
    split={'train': 0.7, 'val': 0.1, 'test': 0.2},  # fractions sum to 1.0
    split_seed=42,
    train_transform=train_tfm,
    eval_transform=eval_tfm,
)

# Call prepare_data() / setup() before accessing dataloaders.
dm.prepare_data()

train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
test_loader = dm.test_dataloader()
```

| Aspect | Raw PyTorch | Datamint |
|---|---|---|
| Dataset class | SegmentationDataset (30-50 lines for PNG/JPEG; 100-150 lines for DICOM/NIfTI) | ImageDataset(project="...") (1 line) |
| Train/val/test splits | Manual split logic + 3 dataset instances (20-30 lines) | split={'train': 0.7, ...} (inline) |
| DataModule | Custom LightningDataModule (40-60 lines) | DatamintDataModule(dataset, ...) (1 line) |
| Total boilerplate | ~120-180 lines (PNG/JPEG); ~200-300 lines (DICOM/NIfTI) | ~5-10 lines |
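The split={'train': ..., 'val': ..., 'test': ...} plus split_seed pattern comes down to two steps: a seeded shuffle of a stable ID order, followed by fractional cut points. The minimal stdlib sketch below assumes the fractions must sum to 1. split_ids is a hypothetical helper, not Datamint's split() implementation.

```python
import random
from typing import Dict, List, Sequence


def split_ids(ids: Sequence[str], fractions: Dict[str, float], seed: int) -> Dict[str, List[str]]:
    """Hypothetical helper: deterministically partition `ids` by named fractions."""
    if abs(sum(fractions.values()) - 1.0) > 1e-9:
        raise ValueError("split fractions must sum to 1.0")
    # Sort first so the result does not depend on input (e.g. directory) order.
    shuffled = sorted(ids)
    random.Random(seed).shuffle(shuffled)

    result: Dict[str, List[str]] = {}
    names = list(fractions)
    start = 0
    for i, name in enumerate(names):
        # The last subset takes the remainder so rounding never drops an ID.
        if i == len(names) - 1:
            end = len(shuffled)
        else:
            end = start + int(fractions[name] * len(shuffled))
        result[name] = shuffled[start:end]
        start = end
    return result


ids = [f"img_{i:03d}" for i in range(200)]
splits = split_ids(ids, {"train": 0.7, "val": 0.1, "test": 0.2}, seed=42)
print({name: len(subset) for name, subset in splits.items()})
# {'train': 140, 'val': 20, 'test': 40}
```

Because the input is sorted before the seeded shuffle, the same seed yields the same partition even if the IDs arrive in a different order.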

Example 2 – Full training loop

Raw PyTorch

```python
import lightning as L
import segmentation_models_pytorch as smp
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger


class SegmentationModel(L.LightningModule):
    def __init__(self, num_classes: int = 2):
        super().__init__()
        # U-Net++, to mirror the architecture UNetPPTrainer trains.
        self.model = smp.UnetPlusPlus(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=3,
            classes=num_classes,
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = self.loss_fn(logits, masks)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = self.loss_fn(logits, masks)
        self.log("val/loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = self.loss_fn(logits, masks)
        self.log("test/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


# -- Trainer setup --------------------------------------------------
model = SegmentationModel(num_classes=2)
trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
    precision="16-mixed",
    callbacks=[
        ModelCheckpoint(monitor="val/loss", save_top_k=1, mode="min"),
    ],
    logger=MLFlowLogger(experiment_name="BUSI_Segmentation"),
)

# Reuse the DataLoaders built in the raw Example 1 code.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(model, dataloaders=test_loader)
```

Datamint

```python
from datamint.lightning import UNetPPTrainer

trainer = UNetPPTrainer(
    project="BUSI_Segmentation",
    image_size=256,
    batch_size=16,
    max_epochs=20,
    accelerator="auto",
    # All extra kwargs are forwarded to lightning.Trainer.
    precision="16-mixed",
    devices=1,
)

results = trainer.fit()
trainer.test()  # evaluates on the test split

# The model is already registered in MLflow.
print(results["test_results"])
```

| Aspect | Raw PyTorch | Datamint |
|---|---|---|
| Model definition | SegmentationModel (50-80 lines) | Built internally by trainer |
| Training step | Manual training_step with loss logging | Automatic loss injection via SegmentationModule |
| Validation step | Manual validation_step with loss | Automatic |
| Test step | Manual test_step with loss | Automatic |
| Optimizer | Manual configure_optimizers | Automatic |
| Trainer setup | Trainer + logger + callbacks (15-25 lines) | Passed as kwargs to trainer |

Note

Datamint’s trainer handles dataset creation, datamodule wiring, model instantiation, MLflow logger setup, and checkpoint callbacks automatically — all from a single UNetPPTrainer call.
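The workflow table at the top of this page says two things about the trainer. Extra keyword arguments are forwarded to lightning.pytorch.trainer.Trainer, and loss and metrics are injected only when model= receives a class rather than an instance. The stdlib sketch below shows one way such a facade can be structured. build_run, DummyModule, FACADE_KWARGS, and the "dice+ce" loss name are all hypothetical and do not reflect Datamint's internals.

```python
import inspect
from typing import Any, Dict, Tuple

# Hypothetical: keyword arguments the facade consumes itself; all others are forwarded.
FACADE_KWARGS = {"project", "image_size", "batch_size"}


class DummyModule:
    """Hypothetical stand-in for a LightningModule such as SegmentationModule."""

    def __init__(self, loss: str = "none", metrics: Tuple[str, ...] = ()):
        self.loss = loss
        self.metrics = metrics


def build_run(model: Any = DummyModule, **kwargs: Any) -> Tuple[Any, Dict[str, Any], Dict[str, Any]]:
    """Hypothetical facade: split kwargs and build the model if given a class."""
    facade_args = {k: v for k, v in kwargs.items() if k in FACADE_KWARGS}
    trainer_args = {k: v for k, v in kwargs.items() if k not in FACADE_KWARGS}
    if inspect.isclass(model):
        # A class: the facade instantiates it and injects task defaults.
        model = model(loss="dice+ce", metrics=("dice", "iou"))
    # An instance is used as-is, so the caller keeps full control.
    return model, facade_args, trainer_args


model, facade_args, trainer_args = build_run(
    project="BUSI_Segmentation", batch_size=16, max_epochs=20, precision="16-mixed"
)
print(model.loss, trainer_args)
# dice+ce {'max_epochs': 20, 'precision': '16-mixed'}
```

Accepting either a class or an instance lets the common case stay a one-liner while still allowing a fully customised model.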

Example 3 – Inference & deployment

After training, Datamint models are already registered in MLflow and can be loaded and used for inference without format-specific preprocessing or post-processing code. The built-in PredictionMode system supports image, slice, volume, frame, and interactive prediction modes.
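The paragraph above describes a mode-based prediction API. The stdlib sketch below shows the general shape of such a dispatcher for three modes: IMAGE, SLICE, and VOLUME. Mode, dispatch_prediction, and max_pixel are hypothetical stand-ins, not Datamint's PredictionMode implementation. A volume is modelled as a list of 2-D slices, and the model as a plain function.

```python
from enum import Enum
from typing import Any, Callable, List, Optional


class Mode(Enum):
    """Hypothetical subset of prediction modes (not Datamint's PredictionMode)."""
    IMAGE = "image"    # input is one 2-D image
    SLICE = "slice"    # input is a volume; predict a single slice
    VOLUME = "volume"  # input is a volume; predict every slice


def dispatch_prediction(data: Any, predict_2d: Callable[[Any], Any], mode: Mode,
                        slice_index: Optional[int] = None) -> List[Any]:
    """Hypothetical dispatcher: route `data` to a 2-D predictor; always return a list."""
    if mode is Mode.IMAGE:
        return [predict_2d(data)]
    if mode is Mode.SLICE:
        if slice_index is None:
            raise ValueError("SLICE mode requires slice_index")
        return [predict_2d(data[slice_index])]
    if mode is Mode.VOLUME:
        return [predict_2d(s) for s in data]
    raise ValueError(f"Unsupported mode: {mode}")


def max_pixel(slice_2d: List[List[int]]) -> int:
    """Toy stand-in for a model: the brightest pixel of a 2-D slice."""
    return max(max(row) for row in slice_2d)


volume = [[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]]  # 3 slices of 2x2
print(dispatch_prediction(volume, max_pixel, Mode.VOLUME))                # [3, 7, 11]
print(dispatch_prediction(volume, max_pixel, Mode.SLICE, slice_index=1))  # [7]
print(dispatch_prediction(volume[0], max_pixel, Mode.IMAGE))              # [3]
```

Returning a list in every mode lets callers handle results uniformly, whichever mode produced them.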

Raw PyTorch

```python
from pathlib import Path

import mlflow
import nibabel as nib
import numpy as np
import pydicom
import torch
from PIL import Image


# 1. Load model from MLflow ----------------------------------------
# Assumes the model was logged with the mlflow.pytorch flavor under the
# artifact path "model". mlflow.pytorch.load_model returns the torch.nn.Module,
# which accepts tensors (the generic pyfunc wrapper works with NumPy/pandas data).
model_uri = "runs:/abc123/model"  # replace abc123 with your MLflow run ID
model = mlflow.pytorch.load_model(model_uri)
model.eval()

# Training used A.Normalize(), whose defaults are the ImageNet statistics.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


# 2. Manual preprocessing per format --------------------------------
def to_three_channels(img: np.ndarray) -> torch.Tensor:
    # Grayscale (H, W) -> (1, 3, H, W) to match the 3-channel encoder.
    return torch.from_numpy(np.stack([img] * 3, axis=0)).float().unsqueeze(0)


def preprocess_image(image_path: Path) -> torch.Tensor:
    img = Image.open(image_path).convert("RGB")
    img = np.array(img).astype(np.float32) / 255.0
    img = (img - IMAGENET_MEAN) / IMAGENET_STD  # manual normalization, must match training
    return torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)


def preprocess_dicom(dicom_path: Path) -> torch.Tensor:
    ds = pydicom.dcmread(str(dicom_path))
    img = ds.pixel_array.astype(np.float32)
    # Handle rescale slope/intercept (either tag may be absent)
    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    img = img * slope + intercept
    # Handle windowing: the tags may be single values or multi-valued
    if hasattr(ds, "WindowCenter") and hasattr(ds, "WindowWidth"):
        wc = ds.WindowCenter[0] if hasattr(ds.WindowCenter, "__iter__") else ds.WindowCenter
        ww = ds.WindowWidth[0] if hasattr(ds.WindowWidth, "__iter__") else ds.WindowWidth
        # Map [wc - ww/2, wc + ww/2] onto [-1, 1] and clip everything outside it.
        img = np.clip((img - float(wc)) / (float(ww) / 2), -1, 1)
    else:
        img = (img - img.mean()) / (img.std() + 1e-8)
    return to_three_channels(img)


def preprocess_nifti(nifti_path: Path) -> torch.Tensor:
    nifti_img = nib.load(str(nifti_path))
    img = nifti_img.get_fdata().astype(np.float32)
    # Must manually handle 3-D → 2-D slicing (and know which axis is which)
    img = img[:, :, img.shape[2] // 2]  # center slice along the third axis
    img = (img - img.mean()) / (img.std() + 1e-8)
    return to_three_channels(img)


# 3. Manual post-processing -----------------------------------------
def postprocess_logits(logits: torch.Tensor) -> np.ndarray:
    probs = torch.softmax(logits, dim=1).cpu().numpy()
    mask = np.argmax(probs, axis=1).squeeze()
    return (mask * 255).astype(np.uint8)


# 4. Run inference ---------------------------------------------------
# (smp U-Net models also require H and W divisible by 32; resizing is omitted.)
with torch.no_grad():
    # For a PNG image:
    tensor = preprocess_image(Path("test_image.png"))
    result_mask = postprocess_logits(model(tensor))

    # For a DICOM file (different preprocessing!):
    tensor = preprocess_dicom(Path("test_dicom.dcm"))
    result_mask = postprocess_logits(model(tensor))

    # For a NIfTI volume (yet another preprocessing!):
    tensor = preprocess_nifti(Path("test_volume.nii.gz"))
    result_mask = postprocess_logits(model(tensor))
```

Datamint

```python
import mlflow
from datamint.entities import LocalResource
from datamint.mlflow.flavors import load_model

# 1. Load model from MLflow -------------------
model_uri = 'models:/UNetPP_Segmentation_Tutorial/latest'
model = mlflow.pyfunc.load_model(model_uri)

# 2. Predict on each format through the same API -------------------
# For a PNG image:
result = model.predict([LocalResource("test_image.png")])

# For a DICOM file:
result = model.predict([LocalResource("test_dicom.dcm")])

# For a NIfTI volume (and slice prediction):
model = load_model(model_uri)  # Use Datamint's MLflow flavor to get the extended API
result = model.predict_slice(
    model_input=[LocalResource("test_volume.nii.gz")],
    slice_index=10,
    axis="axial",
)

# 3. Results are Datamint annotations -- ready for platform use ----
for annotation in result[0]:
    print(annotation)
```

Note

The Datamint path eliminates ~70% of the boilerplate code required by the raw PyTorch approach, especially when working with medical imaging formats like DICOM and NIfTI.
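Much of the DICOM overhead this note refers to is intensity handling. The raw path applies rescale slope/intercept, then a display window given by WindowCenter (c) and WindowWidth (w). The minimal stdlib sketch below uses a simplified linear window that maps [c - w/2, c + w/2] onto [-1, 1] and clips values outside it. The exact LINEAR function in the DICOM standard (PS3.3, VOI LUT Module) uses c - 0.5 and w - 1 instead, so treat this as an approximation. apply_window is a hypothetical helper.

```python
def apply_window(value: float, slope: float, intercept: float,
                 center: float, width: float) -> float:
    """Hypothetical helper: simplified windowing, not the exact DICOM PS3.3 formula."""
    rescaled = value * slope + intercept           # e.g. Hounsfield units for CT
    scaled = (rescaled - center) / (width / 2)     # window edges land on -1 and +1
    return max(-1.0, min(1.0, scaled))             # clip values outside the window


# Example: center 40, width 400, with a typical CT intercept of -1024.
for stored in (624, 1064, 1264, 2000):
    print(stored, apply_window(stored, slope=1.0, intercept=-1024.0, center=40.0, width=400.0))
# 624 -1.0
# 1064 0.0
# 1264 1.0
# 2000 1.0
```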

Key Observations

| Aspect | Raw PyTorch | Datamint |
|---|---|---|
| Data loading | Custom Dataset, DataModule, and split logic | ImageDataset + DatamintDataModule |
| Model definition | Manual UNetPP implementation | Built-in UNetPPTrainer with defaults |
| Training loop | Manually implemented | Automatic via Lightning Trainer |
| Experiment tracking | Manual MLflow logging | Automatic callback-based logging |
| Checkpoint management | Manual export & naming | Automatic versioning + MLflow registration |
| Inference wrapper | Custom code per model | Built-in PredictionMode system |
| Deployment | Manual API integration | One-call api.deploy() |
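The checkpoint-management row contrasts manual bookkeeping with automatic versioning. In the raw Lightning setup, ModelCheckpoint(monitor="val/loss", save_top_k=1, mode="min") decides which checkpoints survive. The minimal stdlib sketch below implements that retention rule: keep the k best checkpoints by a monitored metric, where lower is better for mode="min". TopKTracker is hypothetical and ignores Lightning details such as file naming, saving, and tie handling.

```python
import heapq
from typing import List, Tuple


class TopKTracker:
    """Hypothetical helper: keep the k best (metric, name) pairs.

    Not Lightning's ModelCheckpoint code; file I/O and naming are omitted.
    """

    def __init__(self, k: int, mode: str = "min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.k = k
        # Store sign * metric so that a larger key always means "better".
        self.sign = -1.0 if mode == "min" else 1.0
        # Min-heap on the key: the root is the worst checkpoint still kept.
        self._heap: List[Tuple[float, str]] = []

    def update(self, metric: float, name: str) -> bool:
        """Record a checkpoint's metric; return True if it is kept."""
        if self.k <= 0:
            return False
        key = self.sign * metric
        if len(self._heap) < self.k:
            heapq.heappush(self._heap, (key, name))
            return True
        if key > self._heap[0][0]:
            heapq.heapreplace(self._heap, (key, name))  # evict the current worst
            return True
        return False

    def best(self) -> List[str]:
        """Names of kept checkpoints, best first."""
        return [name for _, name in sorted(self._heap, reverse=True)]


# Mirrors monitor="val/loss", save_top_k=1, mode="min".
tracker = TopKTracker(k=1, mode="min")
for epoch, val_loss in enumerate([0.52, 0.41, 0.47, 0.38]):
    tracker.update(val_loss, f"epoch={epoch}.ckpt")
print(tracker.best())  # ['epoch=3.ckpt']
```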

At inference time, the key differences are:

Inference-time comparison

| Aspect | Raw PyTorch | Datamint |
|---|---|---|
| Input | Caller handles file I/O, decoding, normalization, and channel ordering. | Pass resource descriptors such as {"path": "..."}; the wrapper resolves and preprocesses them automatically. |
| Output | Returns raw logits, probabilities, or masks that require post-processing. | Returns list[list[Annotation]] ready for platform workflows. |
| Multi-mode inference | Each new mode (slice, volume, frame …) requires custom glue code. | Controlled via PredictionMode: IMAGE, SLICE, VOLUME, FRAME, FRAME_RANGE, ALL_FRAMES, TEMPORAL_SEQUENCE. |
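On the raw side, turning logits into a mask takes two steps: a per-pixel softmax, then an argmax over the class axis. Softmax is monotonic, so the argmax of the raw logits already gives the same labels; the softmax only matters when you need probabilities. The stdlib sketch below works on nested lists and assumes logits laid out as [class][row][col] for a single image, the C×H×W layout of one batch item. softmax_over_classes and logits_to_mask are hypothetical helpers.

```python
import math
from typing import List

Grid = List[List[float]]


def softmax_over_classes(logits: List[Grid]) -> List[Grid]:
    """Hypothetical helper: per-pixel softmax across the class axis of [C][H][W]."""
    num_classes, h, w = len(logits), len(logits[0]), len(logits[0][0])
    probs = [[[0.0] * w for _ in range(h)] for _ in range(num_classes)]
    for y in range(h):
        for x in range(w):
            # Subtract the per-pixel max for numerical stability.
            m = max(logits[c][y][x] for c in range(num_classes))
            exps = [math.exp(logits[c][y][x] - m) for c in range(num_classes)]
            total = sum(exps)
            for c in range(num_classes):
                probs[c][y][x] = exps[c] / total
    return probs


def logits_to_mask(logits: List[Grid], foreground_value: int = 255) -> List[List[int]]:
    """Hypothetical helper: argmax over classes, label 1 scaled to 255 (binary case)."""
    num_classes, h, w = len(logits), len(logits[0]), len(logits[0][0])
    return [
        [max(range(num_classes), key=lambda c: logits[c][y][x]) * foreground_value
         for x in range(w)]
        for y in range(h)
    ]


logits = [
    [[2.0, -1.0], [0.5, 0.0]],  # class 0 (background)
    [[0.0, 3.0], [0.4, 1.0]],   # class 1 (foreground)
]
print(logits_to_mask(logits))                           # [[0, 255], [0, 255]]
print(round(softmax_over_classes(logits)[1][0][1], 3))  # 0.982
```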

Further reading