Trainer API

The trainer layer in datamint.lightning packages the usual Lightning workflow into a small number of task-focused entry points. A trainer can:

  • build the dataset and datamodule for a Datamint project,

  • choose task-specific default transforms, loss functions, and metrics,

  • create the Lightning trainer, MLflow logger, and checkpoint callbacks,

  • train and test the model, and

  • optionally register the resulting model in MLflow.

Available Trainers

UNetPPTrainer

The fastest way to train a 2-D semantic segmentation model with a sensible UNet++ configuration. For pure 3-D volume projects, it automatically slices the volumes into 2-D samples before training.

SemanticSegmentation2DTrainer

A more explicit 2-D segmentation trainer when you want to control the model, transforms, or loss. It auto-detects whether the project contains 2-D images or 3-D volumes and accepts slice_axis= to override the inferred plane for volume projects.

SemanticSegmentation3DTrainer

Semantic segmentation for projects of 3-D volumes.

ImageClassificationTrainer

Image classification using a timm backbone.

Quick Start

from datamint.lightning import UNetPPTrainer

trainer = UNetPPTrainer(
    project="BUSI_Segmentation",
    image_size=256,
    batch_size=16,
    max_epochs=20,
    accelerator="auto",
)

results = trainer.fit()
print(results["test_results"])

The built-in trainer configures the dataset, datamodule, model, MLflow logger, checkpointing, and evaluation loop for you. After fit(), the resolved objects are also available as trainer.dataset, trainer.datamodule, and trainer.model.

Inputs, Splits, and Outputs

Each trainer accepts exactly one of:

  • project=... to let Datamint build the dataset automatically, or

  • dataset=... to reuse a dataset you already configured yourself.

For SemanticSegmentation2DTrainer and UNetPPTrainer, project-backed dataset resolution is automatic: pure 2-D image projects use ImageDataset, while pure 3-D volume projects are converted to SlicedVolumeDataset. The slice plane is inferred from volume spacing and shape when possible, and falls back to 'axial'. To force a plane, pass slice_axis='coronal' or another supported axis when constructing the trainer.
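The actual inference logic is internal to Datamint, but the idea can be sketched with a hypothetical heuristic: slice perpendicular to the most coarsely sampled axis, falling back to 'axial' when spacing is unavailable. Everything in this snippet (the axis order, the function name) is illustrative, not Datamint's implementation.

```python
# Illustrative only: the real inference logic lives inside Datamint.
AXES = ("sagittal", "coronal", "axial")  # hypothetical axis order (x, y, z)

def infer_slice_axis(spacing=None):
    """Guess a slice plane from voxel spacing (mm per axis), else 'axial'."""
    if spacing is None or len(spacing) != 3:
        return "axial"
    # Slice perpendicular to the most coarsely sampled axis.
    coarsest = max(range(3), key=lambda i: spacing[i])
    return AXES[coarsest]

print(infer_slice_axis((0.7, 0.7, 5.0)))  # typical CT spacing -> "axial"
print(infer_slice_axis(None))             # unknown spacing    -> "axial"
```

Passing slice_axis= to the trainer bypasses any such inference entirely.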

When you train from a project, the trainer expects train/val/test split assignments to exist for that project. If you need strict split reproducibility across runs, pass the historical split snapshot timestamp through split_as_of_timestamp.

trainer = UNetPPTrainer(
    project="BUSI_Segmentation",
    split_as_of_timestamp="2026-04-21T12:34:56Z",
)

fit() returns a dictionary with these keys:

model

The trained model instance.

test_results

The metrics returned by Lightning test().

If you only want evaluation, use test() instead:

test_metrics = trainer.test(register_model=False)

With register_model=True, the trainer logs and registers the current model.

Passing Lightning Trainer Options

Any extra keyword arguments that are not consumed by Datamint are forwarded to lightning.Trainer. You can also pass Lightning options explicitly through trainer_kwargs=; the two styles can be combined, as below.

trainer = UNetPPTrainer(
    project="BUSI_Segmentation",
    max_epochs=12,
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    log_every_n_steps=10,
    trainer_kwargs={"enable_progress_bar": True},
)
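The forwarding behaviour can be pictured with a small sketch. This is not Datamint's actual code; the key set and function name are made up for illustration. Known Datamint-level options are consumed, and everything else, plus anything in trainer_kwargs, reaches the Lightning trainer.

```python
# Illustrative sketch of kwargs forwarding; names are hypothetical.
DATAMINT_KEYS = {"project", "dataset", "image_size", "batch_size"}

def split_kwargs(trainer_kwargs=None, **kwargs):
    """Separate Datamint-level options from Lightning trainer options."""
    datamint_opts = {k: v for k, v in kwargs.items() if k in DATAMINT_KEYS}
    # Everything else is forwarded to lightning.Trainer, with explicit
    # trainer_kwargs merged on top.
    lightning_opts = {k: v for k, v in kwargs.items() if k not in DATAMINT_KEYS}
    lightning_opts.update(trainer_kwargs or {})
    return datamint_opts, lightning_opts

dm, lt = split_kwargs(
    project="BUSI_Segmentation",
    max_epochs=12,
    precision="16-mixed",
    trainer_kwargs={"enable_progress_bar": True},
)
print(dm)  # {'project': 'BUSI_Segmentation'}
print(lt)  # {'max_epochs': 12, 'precision': '16-mixed', 'enable_progress_bar': True}
```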

Using an External Model Inside a Datamint Trainer

There are two supported patterns, and they are not equivalent.

Preferred: Subclass a Datamint Lightning Module

If you want to swap the network architecture but keep Datamint’s loss wiring, metrics, MLflow model behaviour, and deployment-friendly prediction methods, subclass SegmentationModule or ClassificationModule and pass the class object to model=.

import segmentation_models_pytorch as smp
from datamint.lightning import SemanticSegmentation2DTrainer
from datamint.lightning.trainers.lightning_modules import SegmentationModule


class DeepLabV3PlusModule(SegmentationModule):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, class_names=["benign", "malignant"], **kwargs)
        self.model = smp.DeepLabV3Plus(
            encoder_name="resnet50",
            encoder_weights="imagenet",
            in_channels=3,
            classes=2,
        )

    def forward(self, x):
        return self.model(x)


trainer = SemanticSegmentation2DTrainer(
    project="BUSI_Segmentation",
    image_size=256,
    model=DeepLabV3PlusModule,
)

results = trainer.fit()

When you pass the class object instead of an instance, the trainer instantiates it and injects task defaults through loss_fn= and metrics_factories=. This is the easiest way to plug an external architecture into the Datamint trainer workflow while keeping the resulting model Datamint-compatible.
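The class-versus-instance dispatch can be sketched as follows. This is a simplified illustration under assumed names (resolve_model, DemoModule are hypothetical), not Datamint's actual internals:

```python
import inspect

def resolve_model(model, default_loss, default_metrics):
    """Instantiate a module class with task defaults, or use an instance as-is."""
    if inspect.isclass(model):
        # Class object: the trainer injects task defaults.
        return model(loss_fn=default_loss, metrics_factories=default_metrics)
    # Instance: assumed fully configured; used unchanged.
    return model

class DemoModule:
    def __init__(self, loss_fn=None, metrics_factories=None):
        self.loss_fn = loss_fn
        self.metrics_factories = metrics_factories

resolved = resolve_model(DemoModule, default_loss="dice", default_metrics=["iou"])
print(resolved.loss_fn)  # "dice": defaults were injected

preconfigured = DemoModule(loss_fn="focal")
print(resolve_model(preconfigured, "dice", ["iou"]).loss_fn)  # "focal": untouched
```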

Fully Custom LightningModule

If you already have a plain lightning.LightningModule, pass the instance through model= and implement the full training logic yourself.

import lightning as L
import segmentation_models_pytorch as smp
import torch
from datamint.lightning import SemanticSegmentation2DTrainer


class ExternalSegmentationModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = smp.Unet(
            encoder_name="resnet34",
            encoder_weights="imagenet",
            in_channels=3,
            classes=2,
        )
        self.loss_fn = torch.nn.BCEWithLogitsLoss()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images = batch["image"]
        masks = batch["segmentations"][:, 1:].float()  # drop background channel (index 0)
        loss = self.loss_fn(self(images), masks)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        images = batch["image"]
        masks = batch["segmentations"][:, 1:].float()
        loss = self.loss_fn(self(images), masks)
        self.log("val/loss", loss)
        return loss

    def test_step(self, batch, batch_idx):
        images = batch["image"]
        masks = batch["segmentations"][:, 1:].float()
        loss = self.loss_fn(self(images), masks)
        self.log("test/loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)


trainer = SemanticSegmentation2DTrainer(
    project="BUSI_Segmentation",
    model=ExternalSegmentationModule(),
    max_epochs=5,
)

This gives you the Datamint dataset, split handling, MLflow logger, and checkpointing, but the model itself remains a plain Lightning module. That means Datamint-native inference and deployment behaviour is not added automatically. If you want the trained artifact to behave like a Datamint model, prefer the SegmentationModule / ClassificationModule route, or wrap the final model in a DatamintModel afterwards.

Note

Class vs. Instance

model=MyModule and model=MyModule() are different:

  • Pass the class when you want the trainer to inject loss_fn and metrics_factories.

  • Pass an instance when the module is already fully configured and owns its entire training logic.

Caution

Segmentation batches expose masks in batch["segmentations"] and include the background channel at index 0.
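To make the layout concrete, here is a small sketch (the shapes are illustrative) of dropping the background channel before computing a loss, as the custom module above does with batch["segmentations"][:, 1:]:

```python
import torch

# Illustrative shapes: batch of 4, background + 2 foreground classes, 64x64.
batch = {"segmentations": torch.zeros(4, 3, 64, 64)}

masks = batch["segmentations"]     # channel 0 is background
foreground = masks[:, 1:].float()  # drop background -> foreground channels only

print(masks.shape)       # torch.Size([4, 3, 64, 64])
print(foreground.shape)  # torch.Size([4, 2, 64, 64])
```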