Trainers
Specialized trainers for end-to-end Datamint workflows.
- class datamint.lightning.trainers.BaseTrainer(dataset=None, project=None, *, dataset_kwargs=None, model=None, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, **kwargs)
Bases: ABC
Abstract base trainer encapsulating an end-to-end training workflow.
Subclasses provide task-specific defaults for model architecture, transforms, loss, and metrics by overriding the _build_* / _default_* hooks. Users typically only need to specify a project (or dataset) and optionally override a few settings.
- Parameters:
  - dataset (DatamintBaseDataset | None) – A pre-built DatamintBaseDataset. Mutually exclusive with project.
  - project (str | Project | None) – Project name or Project object used to auto-build a dataset when dataset is None.
  - model (LightningModule | type[LightningModule] | None) – A user-provided LightningModule. When None, the trainer builds a default one via _build_model().
  - loss_fn (Module | None) – Custom loss function forwarded to the default model. Ignored when model is provided (the user's module owns its own loss).
  - batch_size (int) – Training batch size.
  - num_workers (int) – Number of DataLoader workers.
  - train_transform (BaseCompose | None) – Albumentations transform for training. When None, the trainer uses _train_transform().
  - eval_transform (BaseCompose | None) – Albumentations transform for validation and testing. When None, the trainer uses _eval_transform().
  - split_as_of_timestamp (str | None) – Historical timestamp used to resolve project-scoped dataset splits during training. When omitted, the resolved project split datasets capture the current UTC timestamp, and the training lineage logs it via MLflow.
  - max_epochs (int) – Maximum number of training epochs.
  - early_stopping_patience (int | None) – Number of epochs without improvement before training stops. Set to None to disable early stopping.
  - mlflow_experiment_name (str | None) – MLflow experiment name. Auto-generated from the project name when None.
  - register_model_name (str | None) – Name under which the model is registered in the MLflow Model Registry. Auto-generated when None.
  - auto_deploy_adapter (bool) – When True, auto-generate a DatamintModel adapter after training.
  - trainer_kwargs (dict[str, Any] | None) – Extra keyword arguments forwarded to lightning.Trainer.
  - dataset_kwargs (dict[str, Any] | None)
  - kwargs (Any)
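The early_stopping_patience setting follows the usual patience rule: training stops once the monitored metric has gone that many consecutive validation epochs without improving on its best value. The sketch below models that rule in plain Python for a metric that is maximised (as val/accuracy and val/iou are in the subclasses). It is a simplified illustration, not the Lightning early-stopping callback the trainer presumably relies on: the helper name stopping_epoch is hypothetical, and details such as a minimum improvement delta or validation check intervals are ignored.

```python
from typing import Optional, Sequence


def stopping_epoch(values: Sequence[float], patience: Optional[int]) -> int:
    """Return the 0-based epoch after which training would stop.

    Hypothetical helper: models patience-based early stopping for a metric
    that should be maximised. If patience is None, early stopping is disabled
    and training runs through every epoch in ``values``.
    """
    if not values:
        raise ValueError("values must contain at least one epoch")
    if patience is None:
        return len(values) - 1  # no early stopping: run to max_epochs
    best = float("-inf")
    epochs_without_improvement = 0
    for epoch, value in enumerate(values):
        if value > best:
            best = value  # a strict improvement resets the counter
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return len(values) - 1


# Validation accuracy per epoch: the best value (0.80) is first reached at epoch 2.
history = [0.70, 0.75, 0.80, 0.79, 0.80, 0.78]
print(stopping_epoch(history, patience=2))     # 4: no gain at epochs 3 and 4
print(stopping_epoch(history, patience=10))    # 5: the default patience never triggers here
print(stopping_epoch(history, patience=None))  # 5: early stopping disabled
```

Note that matching the previous best (0.80 at epoch 4) does not count as an improvement in this sketch.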
- property datamodule: DatamintDataModule
- property dataset: DatamintBaseDataset
- property experiment_name: str
- fit()
Run the full training pipeline.
- Return type:
  dict[str, Any]
- Returns:
  Dictionary with keys 'trainer', 'model', 'test_results', and 'adapter' (the last only when auto_deploy_adapter is enabled).
- property model: LightningModule
- test(register_model=True)
Run evaluation on the test split in a fresh run.
- Parameters:
  - register_model (bool) – When True, run a zero-epoch fit first so the checkpoint callback saves the current model to MLflow and registers it after the test metrics are logged.
- Return type:
  list[Mapping[str, float]]
- class datamint.lightning.trainers.ClassificationTrainer(dataset=None, project=None, *, dataset_kwargs=None, model=None, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, **kwargs)
Bases: BaseTrainer
Abstract trainer for classification tasks.
Provides shared defaults:
  - Loss – CrossEntropyLoss.
  - Metrics – multiclass Accuracy and macro F1 (torchmetrics).
  - Monitor – val/accuracy (maximise).
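The two default classification metrics can be illustrated without torchmetrics. The sketch below computes multiclass accuracy and macro F1 (the unweighted mean of per-class F1 scores) from integer labels. It only illustrates the definitions: the function names are hypothetical, and torchmetrics' handling of edge cases (for example, classes that appear in neither predictions nor targets, which this sketch skips) may differ.

```python
from typing import List, Sequence


def accuracy(preds: Sequence[int], targets: Sequence[int]) -> float:
    """Fraction of samples whose predicted class equals the target class."""
    if len(preds) != len(targets) or not targets:
        raise ValueError("preds and targets must be non-empty and of equal length")
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


def macro_f1(preds: Sequence[int], targets: Sequence[int], num_classes: int) -> float:
    """Unweighted mean of per-class F1 over classes that occur in preds or targets."""
    scores: List[float] = []
    for c in range(num_classes):
        tp = sum(p == c and t == c for p, t in zip(preds, targets))
        fp = sum(p == c and t != c for p, t in zip(preds, targets))
        fn = sum(p != c and t == c for p, t in zip(preds, targets))
        if tp + fp + fn == 0:
            continue  # class never appears: skip it rather than scoring it as 0
        scores.append(2 * tp / (2 * tp + fp + fn))  # F1 = 2TP / (2TP + FP + FN)
    return sum(scores) / len(scores) if scores else 0.0


targets = [0, 0, 1, 1, 2, 2]
preds = [0, 1, 1, 1, 2, 0]
print(accuracy(preds, targets))                 # 4 of 6 predictions are correct
print(macro_f1(preds, targets, num_classes=3))  # mean of per-class F1 0.5, 0.8 and 2/3
```

Macro averaging gives every class the same weight regardless of how many samples it has, which is why it is often reported alongside accuracy on imbalanced datasets.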
- Parameters:
  - dataset (DatamintBaseDataset | None)
  - project (str | Project | None)
  - dataset_kwargs (dict[str, Any] | None)
  - model (LightningModule | type[LightningModule] | None)
  - loss_fn (Module | None)
  - batch_size (int)
  - num_workers (int)
  - train_transform (BaseCompose | None)
  - eval_transform (BaseCompose | None)
  - split_as_of_timestamp (str | None)
  - max_epochs (int)
  - early_stopping_patience (int | None)
  - mlflow_experiment_name (str | None)
  - register_model_name (str | None)
  - auto_deploy_adapter (bool)
  - trainer_kwargs (dict[str, Any] | None)
  - kwargs (Any)
- class datamint.lightning.trainers.DeepLabV3PlusTrainer(dataset=None, project=None, *, image_size=None, slice_axis=None, model=None, in_channels=3, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, dataset_kwargs=None, encoder_name='resnet34', decoder_atrous_rates=(12, 24, 36), **kwargs)
Bases: SemanticSegmentation2DTrainer
Convenience trainer pre-configured for DeepLab v3+.
Uses the ASPP-based DeepLab v3+ architecture from segmentation_models_pytorch. The decoder_atrous_rates parameter controls the dilation rates of the Atrous Spatial Pyramid Pooling (ASPP) module, which is DeepLab v3+'s core multi-scale context mechanism.
Example:

    trainer = DeepLabV3PlusTrainer(
        project='BUS_Segmentation',
        encoder_name='resnet50',
    )
    results = trainer.fit()
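As a worked example of what decoder_atrous_rates controls: a k × k convolution with dilation rate d spans k + (k - 1)(d - 1) positions along each axis of its input. The sketch below applies that formula to the default rates (12, 24, 36), assuming the 3 × 3 kernels used by the standard DeepLab v3+ atrous branches. It describes a single branch on the encoder feature map, not the receptive field of the whole network; the helper names are hypothetical.

```python
from typing import Tuple


def atrous_extent(kernel_size: int, dilation: int) -> int:
    """Span (in feature-map pixels) covered by one dilated convolution kernel."""
    return kernel_size + (kernel_size - 1) * (dilation - 1)


def aspp_extents(rates: Tuple[int, ...], kernel_size: int = 3) -> Tuple[int, ...]:
    """Per-branch spans for the atrous branches of an ASPP module."""
    return tuple(atrous_extent(kernel_size, r) for r in rates)


print(aspp_extents((12, 24, 36)))  # (25, 49, 73)
print(atrous_extent(3, 1))         # 3: dilation 1 is an ordinary 3 x 3 convolution
```

Each branch still has only nine weights, so larger rates widen the context a branch sees without adding parameters, at the cost of sampling the feature map more sparsely.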
- Parameters:
  - dataset (DatamintBaseDataset | None)
  - project (str | Project | None)
  - image_size (int | tuple[int, int] | None)
  - slice_axis (Literal['axial', 'sagittal', 'coronal'] | int | None)
  - model (LightningModule | type[LightningModule] | None)
  - in_channels (int)
  - loss_fn (Module | None)
  - batch_size (int)
  - num_workers (int)
  - train_transform (BaseCompose | None)
  - eval_transform (BaseCompose | None)
  - split_as_of_timestamp (str | None)
  - max_epochs (int)
  - early_stopping_patience (int | None)
  - mlflow_experiment_name (str | None)
  - register_model_name (str | None)
  - auto_deploy_adapter (bool)
  - trainer_kwargs (dict[str, Any] | None)
  - dataset_kwargs (dict[str, Any] | None)
  - encoder_name (str)
  - decoder_atrous_rates (tuple[int, int, int])
  - kwargs (Any)
- class datamint.lightning.trainers.ImageClassificationTrainer(*, model_name='resnet34', pretrained=True, image_size=None, **kwargs)
Bases: ClassificationTrainer
Trainer for image classification tasks.
Default model: ResNet-34 (via timm) pretrained on ImageNet.
- Parameters:
  - model_name (str) – timm model name. Defaults to 'resnet34'.
  - pretrained (bool) – Use pretrained weights. Defaults to True.
  - image_size (int | tuple[int, int] | None) – Optional target image size (H, W), or a single int for square images. When omitted, the trainer keeps the original image size instead of forcing a resize.
Example:

    trainer = ImageClassificationTrainer(project='ChestXray')
    results = trainer.fit()
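image_size accepts either a single int (a square image) or an (H, W) tuple, and None means the original size is kept. The sketch below shows one way to normalise such a value to a (height, width) pair; the helper name normalize_image_size is hypothetical and is not part of the datamint API.

```python
from typing import Optional, Tuple, Union

ImageSize = Union[int, Tuple[int, int]]


def normalize_image_size(image_size: Optional[ImageSize]) -> Optional[Tuple[int, int]]:
    """Return (H, W) for an int or a tuple; None means 'keep the original size'."""
    if image_size is None:
        return None
    if isinstance(image_size, int):
        height = width = image_size  # a single int describes a square image
    else:
        height, width = image_size  # raises ValueError if not exactly two values
    if height <= 0 or width <= 0:
        raise ValueError(f"image_size must be positive, got {image_size!r}")
    return (height, width)


print(normalize_image_size(256))         # (256, 256)
print(normalize_image_size((512, 384)))  # (512, 384)
print(normalize_image_size(None))        # None
```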
- Parameters:
  - kwargs (Any)
- class datamint.lightning.trainers.SegmentationTrainer(dataset=None, project=None, *, dataset_kwargs=None, model=None, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, **kwargs)
Bases: BaseTrainer
Abstract trainer for segmentation tasks.
Provides shared defaults:
  - Loss – combined BCE + Dice (_BCEDiceLoss).
  - Metrics – mean IoU and Generalised Dice Score (torchmetrics).
  - Monitor – val/iou (maximise).
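The segmentation defaults pair a pixel-wise term (BCE) with an overlap-based term (Dice), and monitor IoU. The sketch below writes these out for a single binary mask in plain Python to show what each one measures; mean IoU averages the per-class IoU shown here. It is not the library's _BCEDiceLoss: the 1:1 weighting of the two loss terms, the smoothing constant eps, and all helper names are assumptions made for illustration, and the real implementation operates on tensors (possibly on logits rather than probabilities).

```python
import math
from typing import Sequence


def bce(probs: Sequence[float], targets: Sequence[int], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy over pixels; probs are clamped away from 0 and 1."""
    total = 0.0
    for p, t in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1 - t) * math.log(1.0 - p))
    return total / len(targets)


def dice_loss(probs: Sequence[float], targets: Sequence[int], eps: float = 1e-7) -> float:
    """1 - soft Dice coefficient; 0 when prediction and mask overlap perfectly."""
    intersection = sum(p * t for p, t in zip(probs, targets))
    denominator = sum(probs) + sum(targets)
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


def bce_dice_loss(probs: Sequence[float], targets: Sequence[int]) -> float:
    """Assumed 1:1 sum of the two terms (illustrative weighting only)."""
    return bce(probs, targets) + dice_loss(probs, targets)


def iou(pred_mask: Sequence[int], target_mask: Sequence[int]) -> float:
    """Intersection over union of two binary masks (1.0 if both are empty)."""
    intersection = sum(1 for p, t in zip(pred_mask, target_mask) if p and t)
    union = sum(1 for p, t in zip(pred_mask, target_mask) if p or t)
    return intersection / union if union else 1.0


target = [0, 0, 1, 1, 1, 0]
probs = [0.1, 0.2, 0.9, 0.8, 0.4, 0.1]
pred_mask = [int(p >= 0.5) for p in probs]  # threshold at 0.5 -> [0, 0, 1, 1, 0, 0]
print(round(bce_dice_loss(probs, target), 4))
print(iou(pred_mask, target))  # 2 shared pixels / 3 pixels in the union
```

BCE penalises each pixel independently, while the Dice term depends on overlap with the whole foreground, which keeps small structures from being swamped by background pixels.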
- Parameters:
  - dataset (DatamintBaseDataset | None)
  - project (str | Project | None)
  - dataset_kwargs (dict[str, Any] | None)
  - model (LightningModule | type[LightningModule] | None)
  - loss_fn (Module | None)
  - batch_size (int)
  - num_workers (int)
  - train_transform (BaseCompose | None)
  - eval_transform (BaseCompose | None)
  - split_as_of_timestamp (str | None)
  - max_epochs (int)
  - early_stopping_patience (int | None)
  - mlflow_experiment_name (str | None)
  - register_model_name (str | None)
  - auto_deploy_adapter (bool)
  - trainer_kwargs (dict[str, Any] | None)
  - kwargs (Any)
- class datamint.lightning.trainers.SemanticSegmentation2DTrainer(*, image_size=None, slice_axis=None, model=None, in_channels=3, trainer_kwargs=None, **kwargs)
Bases: SegmentationTrainer
Trainer for 2-D semantic segmentation.
Default model: UNet++ (segmentation_models_pytorch) with a resnet34 encoder pretrained on ImageNet.
When pointed at a project made of 3-D volumes, the trainer automatically converts it to a SlicedVolumeDataset and trains on 2-D slices instead.
- Parameters:
  - slice_axis (Literal['axial', 'sagittal', 'coronal'] | int | None) – Slice axis override for 3-D volume projects. When omitted, the trainer tries to infer the most suitable anatomical plane and falls back to 'axial'.
  - image_size (int | tuple[int, int] | None) – Target image size (H, W), or a single int for square images. Forwarded to the default transforms. When None, a sensible default is chosen.
  - in_channels (int) – Number of input image channels. Defaults to 3.
All remaining keyword arguments are forwarded to BaseTrainer.
Example:

    trainer = SemanticSegmentation2DTrainer(project='BUS_Segmentation')
    results = trainer.fit()
- Parameters:
  - model (LightningModule | type[LightningModule] | None)
  - trainer_kwargs (dict[str, Any] | None)
  - kwargs (Any)
- class datamint.lightning.trainers.SemanticSegmentation3DTrainer(*, slice_axis='axial', encoder_name='resnet34', in_channels=3, image_size=None, **kwargs)
Bases: SegmentationTrainer
Trainer for 3-D semantic segmentation via per-slice 2-D training.
Builds a VolumeDataset, slices it along the chosen axis, and trains a 2-D segmentation model on individual slices.
- Parameters:
  - slice_axis (str | int) – Slicing axis: 'axial', 'sagittal', 'coronal', or an integer axis index.
  - encoder_name (str) – segmentation_models_pytorch (SMP) encoder backbone.
  - in_channels (int) – Number of input channels.
  - image_size (int | tuple[int, int] | None) – Optional target image size (H, W), or a single int for square images. When omitted, training keeps the original slice size.
Example:

    trainer = SemanticSegmentation3DTrainer(
        project='CT_Liver',
        slice_axis='axial',
    )
    results = trainer.fit()
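Per-slice training turns each volume into a stack of 2-D images taken perpendicular to one axis. The sketch below slices a small nested-list volume along an integer axis and maps anatomical plane names to axis indices. The mapping is a hypothetical convention (volumes stored as depth × height × width, with depth along the axial direction); the actual axis order depends on how VolumeDataset orients the data, which this page does not specify. Passing an integer axis index, which slice_axis also accepts, bypasses the name lookup. All names in the sketch are illustrative.

```python
from typing import Dict, List, Union

Volume = List[List[List[float]]]  # indexed as volume[d][h][w]
Slice2D = List[List[float]]

# Hypothetical convention: axes (0, 1, 2) = (depth, height, width) = (axial, coronal, sagittal).
PLANE_TO_AXIS: Dict[str, int] = {"axial": 0, "coronal": 1, "sagittal": 2}


def resolve_axis(slice_axis: Union[str, int]) -> int:
    """Accept a plane name or an explicit integer axis index (0, 1 or 2)."""
    axis = PLANE_TO_AXIS[slice_axis] if isinstance(slice_axis, str) else slice_axis
    if axis not in (0, 1, 2):
        raise ValueError(f"axis must be 0, 1 or 2, got {axis}")
    return axis


def slices_along(volume: Volume, slice_axis: Union[str, int]) -> List[Slice2D]:
    """Return every 2-D slice of `volume` perpendicular to the chosen axis."""
    axis = resolve_axis(slice_axis)
    depth, height, width = len(volume), len(volume[0]), len(volume[0][0])
    if axis == 0:  # fixed d -> slice of shape (H, W)
        return [[row[:] for row in volume[d]] for d in range(depth)]
    if axis == 1:  # fixed h -> slice of shape (D, W)
        return [[volume[d][h][:] for d in range(depth)] for h in range(height)]
    # fixed w -> slice of shape (D, H)
    return [[[volume[d][h][w] for h in range(height)] for d in range(depth)] for w in range(width)]


# A toy volume with D=2, H=3, W=4 whose voxel values encode their (d, h, w) position.
volume = [[[float(100 * d + 10 * h + w) for w in range(4)] for h in range(3)] for d in range(2)]
print(len(slices_along(volume, "axial")))    # 2 slices of 3 x 4
print(len(slices_along(volume, "coronal")))  # 3 slices of 2 x 4
print(len(slices_along(volume, 2)))          # 4 slices of 2 x 3
```

The number of training samples produced per volume therefore depends on the chosen axis, since it equals the volume's extent along that axis.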
- Parameters:
  - kwargs (Any)
- class datamint.lightning.trainers.TransUNetTrainer(dataset=None, project=None, *, image_size=None, slice_axis=None, model=None, in_channels=3, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, dataset_kwargs=None, variant='R50-ViT-B_16', pretrained=True, **kwargs)
Bases: SemanticSegmentation2DTrainer
Convenience trainer pre-configured for TransUNet.
Uses the R50-ViT-B/16 hybrid encoder with a Cascaded UPsampler (CUP) decoder from Chen et al. (2021). The backbone is timm's vit_base_r50_s16_224, which is a drop-in match for the architecture described in the paper.
Example:

    trainer = TransUNetTrainer(
        project='BUS_Segmentation',
    )
    results = trainer.fit()
- Parameters:
  - dataset (DatamintBaseDataset | None)
  - project (str | Project | None)
  - image_size (int | tuple[int, int] | None)
  - slice_axis (Literal['axial', 'sagittal', 'coronal'] | int | None)
  - model (LightningModule | type[LightningModule] | None)
  - in_channels (int)
  - loss_fn (Module | None)
  - batch_size (int)
  - num_workers (int)
  - train_transform (BaseCompose | None)
  - eval_transform (BaseCompose | None)
  - split_as_of_timestamp (str | None)
  - max_epochs (int)
  - early_stopping_patience (int | None)
  - mlflow_experiment_name (str | None)
  - register_model_name (str | None)
  - auto_deploy_adapter (bool)
  - trainer_kwargs (dict[str, Any] | None)
  - dataset_kwargs (dict[str, Any] | None)
  - variant (str)
  - pretrained (bool)
  - kwargs (Any)
- REQUIRED_IMAGE_SIZE: tuple[int, int] = (224, 224)
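The fixed REQUIRED_IMAGE_SIZE ties in with the backbone's stride. As we read timm's naming, the "s16" in vit_base_r50_s16_224 means the hybrid ResNet-50 stages reduce the input 16× before the transformer, so a 224 × 224 image becomes a 14 × 14 grid of 196 tokens, and the CUP decoder has to upsample by 16 to return to full resolution. The arithmetic is sketched below, assuming each upsampling block doubles the resolution as described for CUP in the TransUNet paper; the helper names are hypothetical. A plausible reason for fixing the size is that the pretrained position embeddings were learned for exactly this token grid.

```python
import math
from typing import Tuple


def vit_token_grid(image_size: Tuple[int, int], stride: int) -> Tuple[int, int, int]:
    """Return (grid_h, grid_w, num_tokens) for a backbone with the given total stride."""
    height, width = image_size
    if height % stride or width % stride:
        raise ValueError(f"{image_size} is not divisible by stride {stride}")
    grid_h, grid_w = height // stride, width // stride
    return grid_h, grid_w, grid_h * grid_w


def upsampling_stages(stride: int) -> int:
    """Number of 2x upsampling blocks needed to undo a power-of-two stride."""
    stages = int(math.log2(stride))
    if 2 ** stages != stride:
        raise ValueError("stride must be a power of two")
    return stages


print(vit_token_grid((224, 224), stride=16))  # (14, 14, 196)
print(upsampling_stages(16))                  # 4
```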
- class datamint.lightning.trainers.UNetPPTrainer(dataset=None, project=None, *, image_size=None, slice_axis=None, model=None, in_channels=3, loss_fn=None, batch_size=16, num_workers=4, train_transform=None, eval_transform=None, split_as_of_timestamp=None, max_epochs=1, early_stopping_patience=10, mlflow_experiment_name=None, register_model_name=None, auto_deploy_adapter=True, trainer_kwargs=None, dataset_kwargs=None, encoder_name='resnet34', **kwargs)
Bases: SemanticSegmentation2DTrainer
Convenience trainer pre-configured for UNet++ with stronger augmentations.
Adds elastic-transform and grid-distortion augmentations to the default training pipeline; these are particularly effective for medical image segmentation.
Example:

    trainer = UNetPPTrainer(
        project='BUS_Segmentation',
        encoder_name='resnet34',
    )
    results = trainer.fit()
- Parameters:
  - dataset (DatamintBaseDataset | None)
  - project (str | Project | None)
  - image_size (int | tuple[int, int] | None)
  - slice_axis (Literal['axial', 'sagittal', 'coronal'] | int | None)
  - model (LightningModule | type[LightningModule] | None)
  - in_channels (int)
  - loss_fn (Module | None)
  - batch_size (int)
  - num_workers (int)
  - train_transform (BaseCompose | None)
  - eval_transform (BaseCompose | None)
  - split_as_of_timestamp (str | None)
  - max_epochs (int)
  - early_stopping_patience (int | None)
  - mlflow_experiment_name (str | None)
  - register_model_name (str | None)
  - auto_deploy_adapter (bool)
  - trainer_kwargs (dict[str, Any] | None)
  - dataset_kwargs (dict[str, Any] | None)
  - encoder_name (str)
  - kwargs (Any)