Datamint vs Raw PyTorch
This page shows how Datamint removes boilerplate from common medical-imaging training workflows by comparing a raw PyTorch / Lightning setup side by side with the equivalent Datamint code. All examples target 2-D semantic segmentation (e.g. BUSI, skin-dataset, or liver-segmentation projects).
Workflow comparison at a glance
| Responsibility | Raw PyTorch / Lightning | Datamint |
|---|---|---|
| 📂 Data loading | Write a custom `torch.utils.data.Dataset` subclass that reads files, decodes DICOM / NIfTI / PNG, applies augmentations, and returns aligned tensors. | One line: `ImageDataset` handles DICOM series, NIfTI, PNG, JPEG, and annotation parsing automatically. |
| 📊 Train / val / test splits | Manually partition resource IDs, tag them, or write a split function; you must track the split yourself and fix seeds to keep it reproducible. | `split()` resolves project-scoped splits (or falls back to legacy `split:*` tags) and returns a snapshot timestamp you can replay later. |
| 🔌 DataModule wiring | Implement `lightning.pytorch.core.LightningDataModule` with `prepare_data`, `setup`, `train_dataloader`, `val_dataloader`, and `test_dataloader`. | `DatamintDataModule` handles splitting and dataloader creation. |
| 🧱 Model scaffolding | Subclass `lightning.pytorch.LightningModule`; write `forward`, `training_step`, `validation_step`, `test_step`, and `configure_optimizers`. | `SegmentationModule` provides `forward`, `predict`, and `predict_batch` with Datamint-native inference. Loss and metrics are injected automatically when you pass the class (not an instance) to `model=`. |
| ⚙️ Trainer configuration | Instantiate `lightning.pytorch.trainer.Trainer` with `max_epochs`, `accelerator`, `devices`, `precision`, `logger`, and checkpoint callbacks. | The Datamint trainer (e.g. `UNetPPTrainer`) builds the dataset, datamodule, model, MLflow logger, and checkpoint callbacks for you. Extra kwargs are forwarded to `lightning.pytorch.trainer.Trainer`. |
| 📈 Experiment tracking | Configure `lightning.pytorch.loggers.MLFlowLogger` or `lightning.pytorch.loggers.TensorBoardLogger`; log hyper-parameters, metrics, and artifacts manually. | MLflow is auto-configured on first import of `datamint.mlflow`. The trainer logs metrics, hyper-parameters, and checkpoints automatically; `register_model=True` registers the final artifact in MLflow. |
| 🚀 Model deployment | Export the checkpoint, write inference code, wrap it in an MLflow `mlflow.pyfunc.PythonModel`, and deploy to your serving platform. | The trained model is already a `DatamintModel` that can be deployed via `DeployModelApi`. |
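The splits row is where raw pipelines most often lose reproducibility. Shuffling a file listing with a fixed seed still gives a different split whenever files are added or listed in a different order. One raw-side alternative is to derive each item's split from a hash of a stable ID. The following is a minimal sketch that assumes every resource has a stable string ID. `assign_split` is a hypothetical helper, not part of Datamint, PyTorch, or Lightning:

```python
import hashlib
from typing import Dict


def assign_split(resource_id: str, fractions: Dict[str, float], seed: int = 42) -> str:
    """Map a stable resource ID to a split name, independent of listing order.

    Hypothetical helper for illustration; `fractions` must sum to 1.
    """
    total = sum(fractions.values())
    if abs(total - 1.0) > 1e-9:
        raise ValueError(f"split fractions must sum to 1, got {total}")
    # Hash "seed:id" into a number u in [0, 1); the same inputs always give the same u.
    digest = hashlib.sha256(f"{seed}:{resource_id}".encode("utf-8")).digest()
    u = int.from_bytes(digest[:8], "big") / 2**64
    # Walk the cumulative fractions and return the first bucket containing u.
    cumulative = 0.0
    for name, frac in fractions.items():
        cumulative += frac
        if u < cumulative:
            return name
    # Guard against floating-point round-off at the upper edge.
    return list(fractions)[-1]


fractions = {"train": 0.8, "val": 0.1, "test": 0.1}
ids = [f"case_{i:03d}" for i in range(200)]
splits = {rid: assign_split(rid, fractions) for rid in ids}
```

The assignment depends only on the ID and the seed, so adding new cases never moves existing ones into a different split. The trade-off is that split sizes only approximate the requested fractions.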
Example 1 – Dataset and split setup
Raw PyTorch / Lightning

```python
import os
from pathlib import Path
from typing import List, Tuple

import albumentations as A
import nibabel as nib
import numpy as np
import pydicom
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader


class SegmentationDataset(Dataset):
    """Manually wired dataset for 2-D segmentation."""

    def __init__(
        self,
        image_paths: List[Path],
        mask_paths: List[Path],
        transform: A.Compose,
    ):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # -- File-type detection and decoding ---------------------------------
        # Match on the full lower-cased name: Path.suffix is ".gz" for "x.nii.gz".
        name = self.image_paths[idx].name.lower()

        if name.endswith((".dcm", ".dicom")):
            # DICOM: requires pydicom to read, then handle pixel spacing,
            # rescale slope/intercept, windowing, and multi-frame series.
            ds = pydicom.dcmread(str(self.image_paths[idx]))
            img = ds.pixel_array.astype(np.float32)
            # Interpret which shape dimension is the channel, which varies across
            # DICOMs and may require manual handling.
            # (...) This is a common source of bugs

        elif name.endswith((".nii", ".nii.gz")):
            # NIfTI: requires nibabel or nitransforms; handles 3-D volumes,
            # custom affine transforms, and voxel-to-RAS mapping.
            nifti_img = nib.load(str(self.image_paths[idx]))
            img = nifti_img.get_fdata().astype(np.float32)
            # For 3-D volumes you must slice or process the full volume.
            # This is non-trivial for segmentation tasks since you must ensure
            # alignment with the mask and know which axis is the correct one to slice on.
            # (...)

        elif name.endswith((".png", ".jpg", ".jpeg")):
            img = Image.open(self.image_paths[idx]).convert("RGB")
            img = np.array(img).astype(np.float32)
        else:
            raise ValueError(f"Unsupported image format: {name}")

        # -- Mask loading (similar complexity for medical formats) -----------
        mask_name = self.mask_paths[idx].name.lower()
        if mask_name.endswith((".nii", ".nii.gz")):
            nifti_mask = nib.load(str(self.mask_paths[idx]))
            mask = nifti_mask.get_fdata().astype(np.int64)
        elif mask_name.endswith(".png"):
            # Binary PNG masks are usually stored as 0/255; map them to class
            # indices {0, 1} so CrossEntropyLoss receives valid targets.
            mask = np.array(Image.open(self.mask_paths[idx]).convert("L"))
            mask = (mask > 0).astype(np.int64)
        else:
            raise ValueError(f"Unsupported mask format: {mask_name}")

        # -- Augmentation -----------------------------------------------------
        augmented = self.transform(image=img, mask=mask)
        image = torch.from_numpy(augmented["image"]).float().permute(2, 0, 1)
        mask = torch.from_numpy(augmented["mask"]).long()

        return image, mask


# -- Split logic (manual) -------------------------------------------
# Sort first: iterdir() order is filesystem-dependent, so shuffling an
# unsorted listing is not reproducible even with a fixed seed.
all_paths = sorted(Path("data/images").iterdir())
np.random.seed(42)
np.random.shuffle(all_paths)
train_paths = all_paths[:140]
val_paths = all_paths[140:160]
test_paths = all_paths[160:]

# -- Training transform (with augmentations) ------------------------
train_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.Normalize(),
])

train_ds = SegmentationDataset(
    train_paths,
    [Path("data/masks") / p.name for p in train_paths],
    train_transform,
)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True, num_workers=4)

# -- Eval transform (no augmentations) ------------------------------
eval_transform = A.Compose([A.Normalize()])

val_ds = SegmentationDataset(
    val_paths,
    [Path("data/masks") / p.name for p in val_paths],
    eval_transform,
)
val_loader = DataLoader(val_ds, batch_size=16, shuffle=False, num_workers=4)

test_ds = SegmentationDataset(
    test_paths,
    [Path("data/masks") / p.name for p in test_paths],
    eval_transform,
)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False, num_workers=4)
```
Datamint

```python
import albumentations as A
from datamint.dataset import ImageDataset
from datamint.lightning import DatamintDataModule

# 1. Load the project -- Datamint resolves all images and masks.
dataset = ImageDataset(project="BUSI_Segmentation")

# 2. Define transforms per stage.
train_tfm = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    A.Normalize(),
])
eval_tfm = A.Compose([A.Normalize()])

# 3. Wrap in a DataModule -- splitting and dataloaders are handled.
dm = DatamintDataModule(
    dataset,
    batch_size=16,
    split={'train': 0.8, 'val': 0.1, 'test': 0.1},  # fractions sum to 1
    split_seed=42,
    train_transform=train_tfm,
    eval_transform=eval_tfm,
)

# Call prepare_data() / setup() before accessing dataloaders.
dm.prepare_data()
dm.setup("fit")   # train/val splits
dm.setup("test")  # test split

train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()
test_loader = dm.test_dataloader()
```
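| Aspect | Raw PyTorch | Datamint |
|---|---|---|
| Dataset class | Custom `Dataset` subclass with per-format decoding | `ImageDataset` (one line) |
| Train/val/test splits | Manual split logic + 3 instances (20-30 lines) | `split=` and `split_seed=` arguments |
| DataModule | Custom | `DatamintDataModule` |
| Total boilerplate | ~120-180 lines (PNG/JPEG), ~200-300 lines (DICOM/NIfTI) | ~5-10 lines |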
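Much of the per-format boilerplate in the raw dataset starts with detecting the file type. `Path.suffix` is a frequent trap here because it returns only the last extension, so `scan.nii.gz` reports `.gz`. Below is a sketch of a compound-extension check. `detect_format` and its format labels are hypothetical names chosen for illustration:

```python
from pathlib import Path

# Hypothetical extension -> format table; compound extensions are included explicitly.
_FORMATS = {
    ".nii.gz": "nifti",
    ".nii": "nifti",
    ".dcm": "dicom",
    ".dicom": "dicom",
    ".png": "image",
    ".jpg": "image",
    ".jpeg": "image",
}


def detect_format(path: Path) -> str:
    """Return 'nifti', 'dicom', or 'image' from the file name's extension(s)."""
    name = path.name.lower()
    # Check longer extensions first so compound ones like '.nii.gz' are matched whole.
    for ext in sorted(_FORMATS, key=len, reverse=True):
        if name.endswith(ext):
            return _FORMATS[ext]
    raise ValueError(f"Unsupported format: {path.name}")
```

DICOM files are often stored without any extension. A real loader may therefore still need to sniff file contents, for example the `DICM` marker at byte offset 128 of a DICOM Part 10 file.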
Example 2 – Full training loop
Raw PyTorch / Lightning

```python
import lightning as L
import segmentation_models_pytorch as smp
import torch
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import MLFlowLogger


class SegmentationModel(L.LightningModule):
    def __init__(self, num_classes: int = 2):
        super().__init__()
        # UNet++ to match the architecture trained by UNetPPTrainer.
        self.model = smp.UnetPlusPlus(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=3,
            classes=num_classes,
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = self.loss_fn(logits, masks)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = self.loss_fn(logits, masks)
        self.log("val/loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        images, masks = batch
        logits = self(images)
        loss = self.loss_fn(logits, masks)
        self.log("test/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


# -- Trainer setup --------------------------------------------------
model = SegmentationModel(num_classes=2)
trainer = L.Trainer(
    max_epochs=20,
    accelerator="auto",
    devices=1,
    precision="16-mixed",
    callbacks=[
        ModelCheckpoint(
            monitor="val/loss",
            save_top_k=1,
            mode="min",
        )
    ],
    logger=MLFlowLogger(experiment_name="BUSI_Segmentation"),
)

# Reuse the raw DataLoaders built in Example 1.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(model, dataloaders=test_loader)
```
Datamint

```python
from datamint.lightning import UNetPPTrainer

trainer = UNetPPTrainer(
    project="BUSI_Segmentation",
    image_size=256,
    batch_size=16,
    max_epochs=20,
    accelerator="auto",
    # All extra kwargs are forwarded to lightning.Trainer.
    precision="16-mixed",
    devices=1,
)

results = trainer.fit()
trainer.test()  # evaluates on the test split

# The model is already registered in MLflow.
print(results["test_results"])
```
| Aspect | Raw PyTorch | Datamint |
|---|---|---|
| Model definition | Custom `LightningModule` subclass | Built internally by trainer |
| Training step | Manual | Automatic loss injection via `SegmentationModule` |
| Validation step | Manual | Automatic |
| Test step | Manual | Automatic |
| Optimizer | Manual | Automatic |
| Trainer setup | Manual `Trainer` with callbacks and logger | Passed as kwargs to trainer |
Note
Datamint’s trainer handles dataset creation, datamodule wiring, model instantiation, MLflow logger setup, and checkpoint callbacks automatically, all from a single UNetPPTrainer call.
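The raw `LightningModule` in Example 2 only logs the loss, so any segmentation metric must also be written and logged by hand. Here is a minimal NumPy sketch of a per-class Dice score on integer label maps, where Dice = 2|P ∩ T| / (|P| + |T|). The function name is an illustrative choice, not Datamint's implementation. So is the convention of scoring a class that is absent from both maps as 1.0:

```python
import numpy as np


def dice_per_class(pred: np.ndarray, target: np.ndarray, num_classes: int) -> np.ndarray:
    """Per-class Dice for integer label maps of identical shape (illustrative sketch)."""
    if pred.shape != target.shape:
        raise ValueError("pred and target must have the same shape")
    scores = np.empty(num_classes, dtype=np.float64)
    for c in range(num_classes):
        p = pred == c    # predicted pixels of class c
        t = target == c  # ground-truth pixels of class c
        intersection = np.logical_and(p, t).sum()
        denom = p.sum() + t.sum()
        # A class absent from both maps is treated as a perfect match.
        scores[c] = 1.0 if denom == 0 else 2.0 * intersection / denom
    return scores
```

For `pred = [[0, 1], [1, 1]]` and `target = [[0, 1], [0, 1]]` this gives 2/3 for class 0 and 0.8 for class 1.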
Example 3 – Inference & deployment
After training, Datamint models are already registered in MLflow and can be loaded and used for inference without any per-format preprocessing or post-processing code. The built-in PredictionMode system supports image, slice, volume, frame, and interactive prediction modes.
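Without a prediction-mode system, each mode is glue code you write yourself. For slice prediction, that means pulling a 2-D slice out of the volume along the requested plane before preprocessing it. The NumPy sketch below assumes the volume has already been reoriented so that array axes 0, 1, and 2 correspond to the sagittal, coronal, and axial directions, as in a RAS-canonical NIfTI. `AXIS_INDEX` and `extract_slice` are hypothetical names, and real data needs its orientation checked first:

```python
import numpy as np

# Hypothetical plane -> array-axis mapping, assuming an (x, y, z) volume
# in RAS-canonical orientation.
AXIS_INDEX = {"sagittal": 0, "coronal": 1, "axial": 2}


def extract_slice(volume: np.ndarray, slice_index: int, axis: str = "axial") -> np.ndarray:
    """Return the 2-D slice at `slice_index` along the named anatomical plane."""
    if volume.ndim != 3:
        raise ValueError(f"expected a 3-D volume, got shape {volume.shape}")
    ax = AXIS_INDEX[axis]
    if not 0 <= slice_index < volume.shape[ax]:
        raise IndexError(f"slice_index {slice_index} out of range for axis '{axis}'")
    # np.take with a scalar index returns a new array with that axis removed.
    return np.take(volume, slice_index, axis=ax)
```

With nibabel, `nib.as_closest_canonical(img)` is one way to bring a NIfTI volume into that orientation before slicing.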
Raw PyTorch / Lightning

```python
from pathlib import Path

import mlflow
import nibabel as nib
import numpy as np
import pydicom
import torch
from PIL import Image


# 1. Load model from MLflow ----------------------------------------
# URI format: runs:/<run_id>/<artifact_path>
model_uri = "runs:/abc123/model"
model = mlflow.pyfunc.load_model(model_uri)

# ImageNet statistics, i.e. the A.Normalize() defaults used during training.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


# 2. Manual preprocessing per format --------------------------------
def preprocess_image(image_path: Path) -> torch.Tensor:
    img = Image.open(image_path).convert("RGB")
    img = np.array(img).astype(np.float32) / 255.0
    img = (img - IMAGENET_MEAN) / IMAGENET_STD  # must match training normalization
    return torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)


def preprocess_dicom(dicom_path: Path) -> torch.Tensor:
    ds = pydicom.dcmread(str(dicom_path))
    img = ds.pixel_array.astype(np.float32)
    # Handle rescale slope/intercept
    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    img = img * slope + intercept
    # Handle windowing: map [center - width/2, center + width/2] to [-1, 1]
    if hasattr(ds, "WindowCenter") and hasattr(ds, "WindowWidth"):
        wc = ds.WindowCenter[0] if hasattr(ds.WindowCenter, "__iter__") else ds.WindowCenter
        ww = ds.WindowWidth[0] if hasattr(ds.WindowWidth, "__iter__") else ds.WindowWidth
        img = (img - float(wc)) / (float(ww) / 2)
        img = np.clip(img, -1, 1)
    # Single-channel 2-D image: replicate to the 3 channels the model expects.
    img = np.stack([img] * 3, axis=0).astype(np.float32)
    return torch.from_numpy(img).unsqueeze(0)


def preprocess_nifti(nifti_path: Path) -> torch.Tensor:
    nifti_img = nib.load(str(nifti_path))
    img = nifti_img.get_fdata().astype(np.float32)
    # Must manually handle 3-D → 2-D slicing
    img = img[:, :, img.shape[2] // 2]  # center slice
    img = (img - img.mean()) / (img.std() + 1e-8)
    # Replicate to 3 channels -> (1, 3, H, W)
    img = np.stack([img] * 3, axis=0)
    return torch.from_numpy(img).unsqueeze(0)


# 3. Manual post-processing -----------------------------------------
def postprocess_logits(logits) -> np.ndarray:
    # pyfunc models return NumPy output for NumPy input; convert back to a tensor.
    logits = torch.as_tensor(np.asarray(logits))
    probs = torch.softmax(logits, dim=1).cpu().numpy()
    mask = np.argmax(probs, axis=1).squeeze()
    return (mask * 255).astype(np.uint8)


# 4. Run inference ---------------------------------------------------
# The pyfunc wrapper takes NumPy arrays, so convert each tensor first.
# For a PNG image:
tensor = preprocess_image(Path("test_image.png"))
logits = model.predict(tensor.numpy())
result_mask = postprocess_logits(logits)

# For a DICOM file (different preprocessing!):
tensor = preprocess_dicom(Path("test_dicom.dcm"))
logits = model.predict(tensor.numpy())
result_mask = postprocess_logits(logits)

# For a NIfTI volume (yet another preprocessing!):
tensor = preprocess_nifti(Path("test_volume.nii.gz"))
logits = model.predict(tensor.numpy())
result_mask = postprocess_logits(logits)
```
Datamint

```python
import mlflow
from datamint.entities import LocalResource
from datamint.mlflow.flavors import load_model

# 1. Load model from MLflow -------------------
model_uri = 'models:/UNetPP_Segmentation_Tutorial/latest'
model = mlflow.pyfunc.load_model(model_uri)

# 2. Predict on each format through the same API -------------------
# For a PNG image:
result = model.predict([LocalResource("test_image.png")])

# For a DICOM file:
result = model.predict([LocalResource("test_dicom.dcm")])

# For a NIfTI volume (and slice prediction):
model = load_model(model_uri)  # Use Datamint's MLflow flavor to get the extended API
result = model.predict_slice(
    model_input=[LocalResource("test_volume.nii.gz")],
    slice_index=10,
    axis="axial"
)

# 3. Results are Datamint annotations -- ready for platform use ----
for annotation in result[0]:
    print(annotation)
```
Note
The Datamint path eliminates ~70% of the boilerplate code required by the raw PyTorch approach, especially when working with medical imaging formats like DICOM and NIfTI.
Key Observations
| Aspect | Raw PyTorch | Datamint |
|---|---|---|
| Data loading | Custom Dataset, DataModule, split logic | `ImageDataset` + `DatamintDataModule` |
| Model definition | Manual UNetPP implementation | Built-in |
| Training loop | Manually implemented | Automatic via Lightning |
| Experiment tracking | Manual MLflow logging | Automatic callback-based logging |
| Checkpoint management | Manual export & naming | Automatic versioning + MLflow registration |
| Inference wrapper | Custom code per model | Built-in |
| Deployment | Manual API integration | One call via `DeployModelApi` |
At inference time, the key differences are:
| Aspect | Raw PyTorch | Datamint |
|---|---|---|
| Input | Caller handles file I/O, decoding, normalization, and channel ordering. | Pass resource descriptors such as `LocalResource`. |
| Output | Returns raw logits, probabilities, or masks requiring post-processing. | Returns Datamint annotations ready for platform use. |
| Multi-mode inference | Each new mode (slice, volume, frame …) requires custom glue code. | Controlled via the `PredictionMode` system (e.g. `predict_slice`). |
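The Input row hides most of the raw-side work for DICOM. Below is a NumPy sketch of the usual rescale-then-window step, which maps the window [center - width/2, center + width/2] linearly to [-1, 1]. This is a simplified linear window, not the exact DICOM VOI LUT formula. `rescale_and_window` is a hypothetical helper; its arguments would come from the `RescaleSlope`, `RescaleIntercept`, `WindowCenter`, and `WindowWidth` header attributes:

```python
import numpy as np


def rescale_and_window(pixels: np.ndarray, slope: float, intercept: float,
                       center: float, width: float) -> np.ndarray:
    """Stored values -> modality units -> window mapped to [-1, 1] (simplified)."""
    if width <= 0:
        raise ValueError("window width must be positive")
    # Stored pixel values -> modality units (e.g. Hounsfield units for CT).
    values = pixels.astype(np.float32) * slope + intercept
    # center -> 0, center +/- width/2 -> +/-1; values outside the window are clipped.
    windowed = (values - center) / (width / 2.0)
    return np.clip(windowed, -1.0, 1.0).astype(np.float32)
```

Consider a CT slice stored with `RescaleIntercept = -1024` and a soft-tissue window of center 40 and width 80. This sketch maps stored value 1064 (40 HU) to 0 and clips everything at or below 0 HU to -1.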
Further reading
PyTorch & Lightning Integration – Dataset and datamodule details.
Trainer API – Trainer API reference and external-model patterns.
BUSI trainer notebook – Runnable end-to-end segmentation tutorial.
External model deployment tutorial – Deploy a custom model trained with Datamint.