Trainers

Specialized trainers for end-to-end Datamint workflows.

class datamint.lightning.trainers.BaseTrainer(dataset=None, project=None, *, dataset_kwargs=None, model=None, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, **kwargs)

Bases: ABC

Abstract base trainer encapsulating an end-to-end training workflow.

Subclasses provide task-specific defaults for model architecture, transforms, loss, and metrics by overriding the _build_* / _default_* hooks. Users typically only need to specify a project (or dataset) and optionally override a few settings.
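The override mechanism can be pictured with the following self-contained sketch, which does not use Datamint. SketchTrainer, ToyTrainer and resolve() are hypothetical names. Only the hook names _build_model() and _train_transform() come from this page, and their real signatures may differ. The sketch shows the precedence rule described in the parameters below: an object the user supplies wins, and otherwise the subclass hook provides the default.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class SketchTrainer(ABC):
    """Hypothetical stand-in illustrating the _build_* / _default_* hook pattern."""

    def __init__(self, model: Optional[Any] = None, train_transform: Optional[Any] = None):
        self._user_model = model
        self._user_train_transform = train_transform

    @abstractmethod
    def _build_model(self) -> Any:
        """Subclasses return their task-specific default model."""

    @abstractmethod
    def _train_transform(self) -> Any:
        """Subclasses return their task-specific default training transform."""

    def resolve(self) -> tuple[Any, Any]:
        # User-supplied objects take precedence; hooks only fill in what is missing.
        model = self._user_model if self._user_model is not None else self._build_model()
        transform = (
            self._user_train_transform
            if self._user_train_transform is not None
            else self._train_transform()
        )
        return model, transform


class ToyTrainer(SketchTrainer):
    """Hypothetical task-specific subclass that only overrides the hooks."""

    def _build_model(self) -> str:
        return "default-model"

    def _train_transform(self) -> str:
        return "default-transform"


print(ToyTrainer().resolve())                  # ('default-model', 'default-transform')
print(ToyTrainer(model="my-model").resolve())  # ('my-model', 'default-transform')
```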

Parameters:
  • dataset (DatamintBaseDataset | None) – A pre-built DatamintBaseDataset. Mutually exclusive with project.

  • project (str | Project | None) – Project name or Project object used to auto-build a dataset when dataset is None.

  • model (LightningModule | type[LightningModule] | None) – A user-provided LightningModule instance or class. When None, the trainer builds a default one via _build_model().

  • loss_fn (Module | None) – Custom loss function forwarded to the default model. Ignored when model is provided (the user’s module owns its own loss).

  • batch_size (int) – Training batch size.

  • num_workers (int) – DataLoader workers.

  • train_transform (BaseCompose | None) – Albumentations transform for training. When None, the trainer uses _train_transform().

  • eval_transform (BaseCompose | None) – Albumentations transform for validation and testing. When None, the trainer uses _eval_transform().

  • split_as_of_timestamp (str | None) – Historical timestamp used to resolve project-scoped dataset splits during training. When omitted, the current UTC timestamp is captured when the project's split datasets are resolved, and it is logged to MLflow as part of the training lineage.

  • max_epochs (int) – Maximum number of training epochs.

  • early_stopping_patience (int | None) – Epochs without improvement before stopping. Set to None to disable early stopping.

  • mlflow_experiment_name (str | None) – MLflow experiment name. Auto-generated from the project name when None.

  • register_model_name (str | None) – Name for MLflow Model Registry. Auto-generated when None.

  • auto_deploy_adapter (bool) – When True, auto-generate a DatamintModel adapter after training.

  • trainer_kwargs (dict[str, Any] | None) – Extra keyword arguments forwarded to lightning.Trainer.

  • dataset_kwargs (dict[str, Any] | None)

  • kwargs (Any)
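early_stopping_patience counts consecutive epochs in which the monitored metric does not improve. Both monitors on this page are maximised: val/accuracy for classification and val/iou for segmentation. The sketch below replays that counting rule in plain Python over a list of per-epoch scores. It is a simplified model of Lightning's EarlyStopping callback, which the trainer presumably configures from this value (an assumption). epochs_until_stop and its min_delta argument are hypothetical and not part of Datamint's API.

```python
from typing import Optional


def epochs_until_stop(
    scores: list[float], patience: Optional[int], min_delta: float = 0.0
) -> Optional[int]:
    """Return the 1-based epoch at which a maximised metric triggers a stop, or None."""
    best: Optional[float] = None
    bad_epochs = 0
    for epoch, score in enumerate(scores, start=1):
        if best is None or score > best + min_delta:
            # Improvement: remember the new best and reset the counter.
            best = score
            bad_epochs = 0
        else:
            bad_epochs += 1
            # patience=None mirrors "Set to None to disable early stopping".
            if patience is not None and bad_epochs >= patience:
                return epoch
    return None


val_scores = [0.50, 0.60, 0.60, 0.59, 0.58]
print(epochs_until_stop(val_scores, patience=3))     # 5
print(epochs_until_stop(val_scores, patience=2))     # 4
print(epochs_until_stop(val_scores, patience=None))  # None
```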

property datamodule: DatamintDataModule
property dataset: DatamintBaseDataset
property experiment_name: str
fit()

Run the full training pipeline.

Return type:

dict[str, Any]

Returns:

Dictionary with keys 'trainer', 'model', 'test_results', and 'adapter' (when auto_deploy_adapter is enabled).

property model: LightningModule
test(register_model=True)

Run evaluation on the test split in a fresh run.

Parameters:

register_model (bool) – When True, run a zero-epoch fit first so the checkpoint callback saves the current model to MLflow and registers it after test metrics are logged.

Return type:

list[Mapping[str, float]]

class datamint.lightning.trainers.ClassificationTrainer(dataset=None, project=None, *, dataset_kwargs=None, model=None, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, **kwargs)

Bases: BaseTrainer

Abstract trainer for classification tasks.

Provides shared defaults:

  • Loss – CrossEntropyLoss.

  • Metrics – Multiclass Accuracy and macro F1 (torchmetrics).

  • Monitor – val/accuracy (maximise).
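To show what these metrics measure, the sketch below recomputes both of them by hand over integer class labels. It is plain Python, not the torchmetrics implementation. Accuracy here is micro-averaged, meaning the fraction of correct predictions. torchmetrics' MulticlassAccuracy also supports per-class averaging (average='macro'), and this page does not say which averaging the trainer configures. Macro F1 is the unweighted mean of per-class F1 scores. This sketch scores a class with no true positives, false positives or false negatives as 0, and torchmetrics may handle that edge case differently.

```python
def accuracy(preds: list[int], targets: list[int]) -> float:
    """Fraction of samples whose predicted class equals the target class."""
    correct = sum(p == t for p, t in zip(preds, targets))
    return correct / len(targets)


def macro_f1(preds: list[int], targets: list[int], num_classes: int) -> float:
    """Unweighted mean of per-class F1 scores."""
    scores = []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(preds, targets))
        fp = sum(p == c and t != c for p, t in zip(preds, targets))
        fn = sum(p != c and t == c for p, t in zip(preds, targets))
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); defined as 0 here when the class never appears.
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / num_classes


preds = [0, 1, 1, 2]
targets = [0, 1, 2, 2]
print(accuracy(preds, targets))     # 0.75
print(macro_f1(preds, targets, 3))  # about 0.778 (= 7/9)
```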

Parameters:
  • dataset (DatamintBaseDataset | None)

  • project (str | Project | None)

  • dataset_kwargs (dict[str, Any] | None)

  • model (LightningModule | type[LightningModule] | None)

  • loss_fn (Module | None)

  • batch_size (int)

  • num_workers (int)

  • train_transform (BaseCompose | None)

  • eval_transform (BaseCompose | None)

  • split_as_of_timestamp (str | None)

  • max_epochs (int)

  • early_stopping_patience (int | None)

  • mlflow_experiment_name (str | None)

  • register_model_name (str | None)

  • auto_deploy_adapter (bool)

  • trainer_kwargs (dict[str, Any] | None)

  • kwargs (Any)

class datamint.lightning.trainers.DeepLabV3PlusTrainer(dataset=None, project=None, *, image_size=None, slice_axis=None, model=None, in_channels=3, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, dataset_kwargs=None, encoder_name='resnet34', decoder_atrous_rates=(12, 24, 36), **kwargs)

Bases: SemanticSegmentation2DTrainer

Convenience trainer pre-configured for DeepLab v3+.

Uses the ASPP-based DeepLab v3+ architecture from segmentation_models_pytorch. The decoder_atrous_rates parameter controls the dilation rates of the Atrous Spatial Pyramid Pooling module, which is DeepLab v3+’s core multi-scale context mechanism.
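Rates such as (12, 24, 36) capture multi-scale context because a k × k convolution with dilation rate r spans k + (k − 1)(r − 1) pixels per side. The sketch below computes those spans for the default rates. It assumes ASPP's usual 3 × 3 atrous branches, which is an assumption about the segmentation_models_pytorch implementation. The spans are measured on the encoder's feature map, not on the input image.

```python
def effective_kernel_size(kernel_size: int, dilation: int) -> int:
    """Side length covered by a dilated (atrous) convolution kernel."""
    # A dilation of r inserts (r - 1) gaps between each pair of adjacent kernel taps.
    return kernel_size + (kernel_size - 1) * (dilation - 1)


for rate in (12, 24, 36):
    print(rate, effective_kernel_size(3, rate))
# 12 25
# 24 49
# 36 73
```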

Example:

from datamint.lightning.trainers import DeepLabV3PlusTrainer

trainer = DeepLabV3PlusTrainer(
    project='BUS_Segmentation',
    encoder_name='resnet50',
)
results = trainer.fit()

Parameters:
  • dataset (DatamintBaseDataset | None)

  • project (str | Project | None)

  • image_size (int | tuple[int, int] | None)

  • slice_axis (Literal['axial', 'sagittal', 'coronal'] | int | None)

  • model (LightningModule | type[LightningModule] | None)

  • in_channels (int)

  • loss_fn (Module | None)

  • batch_size (int)

  • num_workers (int)

  • train_transform (BaseCompose | None)

  • eval_transform (BaseCompose | None)

  • split_as_of_timestamp (str | None)

  • max_epochs (int)

  • early_stopping_patience (int | None)

  • mlflow_experiment_name (str | None)

  • register_model_name (str | None)

  • auto_deploy_adapter (bool)

  • trainer_kwargs (dict[str, Any] | None)

  • dataset_kwargs (dict[str, Any] | None)

  • encoder_name (str)

  • decoder_atrous_rates (tuple[int, int, int])

  • kwargs (Any)

class datamint.lightning.trainers.ImageClassificationTrainer(*, model_name='resnet34', pretrained=True, image_size=None, **kwargs)

Bases: ClassificationTrainer

Trainer for image classification tasks.

Default model: ResNet-34 (via timm) pretrained on ImageNet.

Parameters:
  • model_name (str) – timm model name. Defaults to 'resnet34'.

  • pretrained (bool) – Use pretrained weights. Defaults to True.

  • image_size (int | tuple[int, int] | None) – Optional target image size (H, W) or a single int for square images. When omitted, the trainer keeps the original image size instead of forcing a resize.
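image_size accepts either a single int or an (H, W) pair. The sketch below shows one way to normalise both forms to (H, W). normalize_image_size is a hypothetical helper and not part of Datamint's API. It returns None for None, which corresponds to the "keep the original image size" behaviour described above.

```python
from typing import Optional, Union

ImageSize = Union[int, tuple[int, int], None]


def normalize_image_size(image_size: ImageSize) -> Optional[tuple[int, int]]:
    """Hypothetical helper: return (H, W), or None to keep the original size."""
    if image_size is None:
        return None
    if isinstance(image_size, int):
        # A single int means a square target.
        return (image_size, image_size)
    height, width = image_size
    return (int(height), int(width))


print(normalize_image_size(224))         # (224, 224)
print(normalize_image_size((256, 128)))  # (256, 128)
print(normalize_image_size(None))        # None
```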

Example:

from datamint.lightning.trainers import ImageClassificationTrainer

trainer = ImageClassificationTrainer(project='ChestXray')
results = trainer.fit()

Parameters:

kwargs (Any)

class datamint.lightning.trainers.SegmentationTrainer(dataset=None, project=None, *, dataset_kwargs=None, model=None, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, **kwargs)

Bases: BaseTrainer

Abstract trainer for segmentation tasks.

Provides shared defaults:

  • Loss – combined BCE + Dice (_BCEDiceLoss).

  • Metrics – Mean IoU and Generalised Dice Score (torchmetrics).

  • Monitor – val/iou (maximise).
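To show how these defaults behave, the plain-Python sketch below computes a BCE + Dice loss on flattened binary masks and a mean IoU over label maps. It is neither _BCEDiceLoss nor the torchmetrics implementations, and it rests on several assumptions:

- inputs are per-pixel probabilities rather than logits;
- the two loss terms are weighted equally;
- the smoothing constant is 1;
- classes absent from both masks are skipped when averaging IoU.

```python
import math


def bce(probs: list[float], targets: list[int], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy over pixels; probabilities are clamped for log safety."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(targets)


def soft_dice_loss(probs: list[float], targets: list[int], smooth: float = 1.0) -> float:
    """1 - soft Dice coefficient; the smoothing term keeps empty masks well-defined."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    return 1 - (2 * intersection + smooth) / (sum(probs) + sum(targets) + smooth)


def bce_dice_loss(probs: list[float], targets: list[int]) -> float:
    # Assumed equal weighting of the two terms.
    return bce(probs, targets) + soft_dice_loss(probs, targets)


def mean_iou(pred_labels: list[int], target_labels: list[int], num_classes: int) -> float:
    """Average IoU over classes present in the prediction or the target."""
    scores = []
    for c in range(num_classes):
        inter = sum(p == c and t == c for p, t in zip(pred_labels, target_labels))
        union = sum(p == c or t == c for p, t in zip(pred_labels, target_labels))
        if union:
            scores.append(inter / union)
    return sum(scores) / len(scores)


target = [1, 0, 1, 0]
print(bce_dice_loss([0.9, 0.1, 0.8, 0.2], target))  # small: good prediction
print(bce_dice_loss([0.1, 0.9, 0.2, 0.8], target))  # large: inverted prediction
print(mean_iou([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2))  # about 0.583 (= 7/12)
```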

Parameters:
  • dataset (DatamintBaseDataset | None)

  • project (str | Project | None)

  • dataset_kwargs (dict[str, Any] | None)

  • model (LightningModule | type[LightningModule] | None)

  • loss_fn (Module | None)

  • batch_size (int)

  • num_workers (int)

  • train_transform (BaseCompose | None)

  • eval_transform (BaseCompose | None)

  • split_as_of_timestamp (str | None)

  • max_epochs (int)

  • early_stopping_patience (int | None)

  • mlflow_experiment_name (str | None)

  • register_model_name (str | None)

  • auto_deploy_adapter (bool)

  • trainer_kwargs (dict[str, Any] | None)

  • kwargs (Any)

class datamint.lightning.trainers.SemanticSegmentation2DTrainer(*, image_size=None, slice_axis=None, model=None, in_channels=3, trainer_kwargs=None, **kwargs)

Bases: SegmentationTrainer

Trainer for 2-D semantic segmentation.

Default model: UNet++ (segmentation_models_pytorch) with a resnet34 encoder pretrained on ImageNet.

When pointed at a project made of 3-D volumes, the trainer automatically converts it to a SlicedVolumeDataset and trains on 2-D slices instead.
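Conceptually, slicing a volume into 2-D training samples means indexing along one array axis. The sketch below uses NumPy and a hypothetical iter_slices helper. It is not SlicedVolumeDataset. The mapping from anatomical plane names to array axes is an assumption, because the correct axis depends on how the volume is stored and oriented.

```python
import numpy as np

# Hypothetical plane-to-axis mapping, assuming volumes stored as (z, y, x),
# i.e. (depth, height, width). Other layouts (e.g. NIfTI-style (x, y, z)) need a different map.
PLANE_TO_AXIS = {"axial": 0, "coronal": 1, "sagittal": 2}


def iter_slices(volume: np.ndarray, slice_axis="axial"):
    """Yield the 2-D slices of a 3-D volume along a named plane or an integer axis."""
    axis = PLANE_TO_AXIS[slice_axis] if isinstance(slice_axis, str) else int(slice_axis)
    for index in range(volume.shape[axis]):
        # np.take with a scalar index drops the sliced axis, leaving a 2-D array.
        yield np.take(volume, index, axis=axis)


volume = np.zeros((4, 5, 6))
axial = list(iter_slices(volume, "axial"))
print(len(axial), axial[0].shape)        # 4 (5, 6)
sagittal = list(iter_slices(volume, "sagittal"))
print(len(sagittal), sagittal[0].shape)  # 6 (4, 5)
```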

Parameters:
  • slice_axis (Literal['axial', 'sagittal', 'coronal'] | int | None) – Slice axis override for 3-D volume projects. When omitted, the trainer tries to infer the most sensible anatomical plane and falls back to 'axial'.

  • image_size (int | tuple[int, int] | None) – Target image size (H, W) or a single int for square images. Forwarded to the default transforms. When None, a sensible default is chosen.

  • in_channels (int) – Number of input image channels. Defaults to 3.

  • model (LightningModule | type[LightningModule] | None)

  • trainer_kwargs (dict[str, Any] | None)

  • kwargs (Any) – All remaining keyword arguments are forwarded to BaseTrainer.

Example:

from datamint.lightning.trainers import SemanticSegmentation2DTrainer

trainer = SemanticSegmentation2DTrainer(project='BUS_Segmentation')
results = trainer.fit()

class datamint.lightning.trainers.SemanticSegmentation3DTrainer(*, slice_axis='axial', encoder_name='resnet34', in_channels=3, image_size=None, **kwargs)

Bases: SegmentationTrainer

Trainer for 3-D semantic segmentation via per-slice 2-D training.

Builds a VolumeDataset, slices it along the chosen axis, and trains a 2-D segmentation model on individual slices.

Parameters:
  • slice_axis (str | int) – Slicing axis — 'axial', 'sagittal', 'coronal', or an integer axis index.

  • encoder_name (str) – SMP encoder backbone.

  • in_channels (int) – Number of input channels.

  • image_size (int | tuple[int, int] | None) – Optional target image size (H, W) or a single int for square images. When omitted, training keeps the original slice size.

Example:

from datamint.lightning.trainers import SemanticSegmentation3DTrainer

trainer = SemanticSegmentation3DTrainer(
    project='CT_Liver',
    slice_axis='axial',
)
results = trainer.fit()

Parameters:

kwargs (Any)

class datamint.lightning.trainers.TransUNetTrainer(dataset=None, project=None, *, image_size=None, slice_axis=None, model=None, in_channels=3, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, dataset_kwargs=None, variant='R50-ViT-B_16', pretrained=True, **kwargs)

Bases: SemanticSegmentation2DTrainer

Convenience trainer pre-configured for TransUNet.

Uses the R50-ViT-B/16 hybrid encoder with a Cascaded UPsampler (CUP) decoder from Chen et al. (2021). The backbone is timm’s vit_base_r50_s16_224, which is a drop-in match for the architecture described in the paper.

Example:

from datamint.lightning.trainers import TransUNetTrainer

trainer = TransUNetTrainer(
    project='BUS_Segmentation',
)
results = trainer.fit()

Parameters:
  • dataset (DatamintBaseDataset | None)

  • project (str | Project | None)

  • image_size (int | tuple[int, int] | None)

  • slice_axis (Literal['axial', 'sagittal', 'coronal'] | int | None)

  • model (LightningModule | type[LightningModule] | None)

  • in_channels (int)

  • loss_fn (Module | None)

  • batch_size (int)

  • num_workers (int)

  • train_transform (BaseCompose | None)

  • eval_transform (BaseCompose | None)

  • split_as_of_timestamp (str | None)

  • max_epochs (int)

  • early_stopping_patience (int | None)

  • mlflow_experiment_name (str | None)

  • register_model_name (str | None)

  • auto_deploy_adapter (bool)

  • trainer_kwargs (dict[str, Any] | None)

  • dataset_kwargs (dict[str, Any] | None)

  • variant (str)

  • pretrained (bool)

  • kwargs (Any)

REQUIRED_IMAGE_SIZE: tuple[int, int] = (224, 224)
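A fixed 224 × 224 input is consistent with the vit_base_r50_s16_224 backbone named above. Its total stride of 16 (the "s16" in the name) turns a 224 × 224 image into a 14 × 14 grid of 196 tokens. The sketch below shows this arithmetic using a hypothetical helper, token_grid. One plausible reason for the fixed size is that the pretrained position embeddings match this grid, but this page does not state that.

```python
def token_grid(image_size: tuple[int, int] = (224, 224), total_stride: int = 16) -> tuple[int, int]:
    """Rows and columns of the token grid produced by a backbone with the given total stride."""
    height, width = image_size
    if height % total_stride or width % total_stride:
        raise ValueError("image sides must be multiples of the total stride")
    return height // total_stride, width // total_stride


rows, cols = token_grid()
print(rows, cols, rows * cols)  # 14 14 196
```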
class datamint.lightning.trainers.UNetPPTrainer(dataset=None, project=None, *, image_size=None, slice_axis=None, model=None, in_channels=3, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, dataset_kwargs=None, encoder_name='resnet34', **kwargs)

Bases: SemanticSegmentation2DTrainer

Convenience trainer pre-configured for UNet++ with stronger augmentations.

Adds elastic transform and grid distortion to the default training pipeline — augmentations that are particularly effective for medical image segmentation.

Example:

from datamint.lightning.trainers import UNetPPTrainer

trainer = UNetPPTrainer(
    project='BUS_Segmentation',
    encoder_name='resnet34',
)
results = trainer.fit()

Parameters:
  • dataset (DatamintBaseDataset | None)

  • project (str | Project | None)

  • image_size (int | tuple[int, int] | None)

  • slice_axis (Literal['axial', 'sagittal', 'coronal'] | int | None)

  • model (LightningModule | type[LightningModule] | None)

  • in_channels (int)

  • loss_fn (Module | None)

  • batch_size (int)

  • num_workers (int)

  • train_transform (BaseCompose | None)

  • eval_transform (BaseCompose | None)

  • split_as_of_timestamp (str | None)

  • max_epochs (int)

  • early_stopping_patience (int | None)

  • mlflow_experiment_name (str | None)

  • register_model_name (str | None)

  • auto_deploy_adapter (bool)

  • trainer_kwargs (dict[str, Any] | None)

  • dataset_kwargs (dict[str, Any] | None)

  • encoder_name (str)

  • kwargs (Any)