MLflow Integration

The datamint.mlflow module connects MLflow to Datamint for experiment tracking, model logging, and model deployment.

Overview

Key features:

  • Automatic Experiment Tracking: Experiments are automatically associated with Datamint projects

  • Model Registration: Trained models are registered in MLflow and can be deployed to Datamint

  • Dataset Versioning: Datasets are logged as MLflow artifacts with metadata

  • Custom Checkpointing: MLflowModelCheckpoint integrates with Lightning callbacks

  • Flavor Support: Custom Datamint flavors for segmentation and classification models

Automatic Configuration

The MLflow module auto-configures itself on first import:

# This automatically sets up MLflow environment
import datamint.mlflow

# Or explicitly:
from datamint.mlflow import ensure_mlflow_configured
ensure_mlflow_configured()

Environment Setup

Utility functions for automatically configuring MLflow environment variables based on Datamint configuration.

datamint.mlflow.env_utils.ensure_mlflow_configured()

Ensure MLflow environment is properly configured. Raises ValueError if configuration is incomplete.

Return type:

None

datamint.mlflow.env_utils.setup_mlflow_environment(overwrite=False, set_mlflow=True)

Set up MLflow environment variables based on Datamint configuration.

Parameters:
  • overwrite (bool) – If True, overwrite existing MLflow environment variables.

  • set_mlflow (bool) – If True, set the MLflow tracking URI using mlflow.set_tracking_uri().

Returns:

True on success, False otherwise.

Return type:

bool
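The effect of the overwrite flag can be pictured with a small standard-library sketch. This is not Datamint's implementation: apply_env_defaults and the example URIs are hypothetical, and only the MLFLOW_TRACKING_URI variable name comes from MLflow itself.

```python
import os


def apply_env_defaults(values: dict[str, str], overwrite: bool = False) -> bool:
    """Hypothetical helper: copy `values` into os.environ.

    Existing variables are kept unless `overwrite` is True.
    Returns True if at least one variable was written.
    """
    changed = False
    for key, value in values.items():
        if overwrite or key not in os.environ:
            os.environ[key] = value
            changed = True
    return changed


# Placeholder URIs, used only to show the two behaviours.
os.environ["MLFLOW_TRACKING_URI"] = "http://user-defined.example"

apply_env_defaults({"MLFLOW_TRACKING_URI": "http://from-config.example"})
kept = os.environ["MLFLOW_TRACKING_URI"]  # the pre-existing value survives

apply_env_defaults({"MLFLOW_TRACKING_URI": "http://from-config.example"}, overwrite=True)
replaced = os.environ["MLFLOW_TRACKING_URI"]  # now taken from the config
```

With the default overwrite=False, a tracking URI exported in the shell would win over the Datamint configuration under this reading of the flag.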

class datamint.mlflow.env_vars.EnvVars(*values)
DATAMINT_PROJECT_ID = 'DATAMINT_PROJECT_ID'
DATAMINT_PROJECT_NAME = 'DATAMINT_PROJECT_NAME'

MLflow Dataset

MLflow Dataset adapter for Datamint project splits.

class datamint.mlflow.data.datamint_dataset.DatamintMLflowDataset(project_id, project_name, split, resources, extra_params=None)

Bases: Dataset

MLflow Dataset wrapping a Datamint project split for lineage tracking.

Parameters:
  • project_id (str)

  • project_name (str)

  • split (str | None)

  • resources (Sequence[str] | Sequence[Resource])

  • extra_params (dict[str, Any] | None)

property profile: Any | None

Optional summary statistics for the dataset, such as the number of rows in a table, the mean / median / std of each table column, etc.

property schema

Optional dataset schema, such as an instance of mlflow.types.Schema representing the features and targets of the dataset.

to_dict()

Create config dictionary for the dataset.

Subclasses should override this method to provide additional fields in the config dict, e.g., schema, profile, etc.

Returns a string dictionary containing the following fields: name, digest, source, source type.

Return type:

dict[str, str]
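As a rough illustration of a string-only config dictionary with these four fields, the sketch below builds one from a list of resource IDs. Everything in it is an assumption made for illustration: the helper names, the SHA-256 based digest, the "source_type" key spelling, and the "datamint" label are not taken from the Datamint source.

```python
import hashlib
import json


def split_digest(resource_ids: list[str]) -> str:
    """Hypothetical digest: short hash of the sorted resource IDs."""
    joined = "\n".join(sorted(resource_ids))
    return hashlib.sha256(joined.encode("utf-8")).hexdigest()[:8]


def to_config_dict(name: str, resource_ids: list[str], project_id: str) -> dict[str, str]:
    """Build a dict whose values are all strings."""
    return {
        "name": name,
        "digest": split_digest(resource_ids),
        # The source is serialized to JSON so that the value stays a str.
        "source": json.dumps({"project_id": project_id}),
        "source_type": "datamint",  # placeholder label, not confirmed by the docs
    }


config = to_config_dict("liver-train", ["res-2", "res-1"], project_id="proj-123")
same = to_config_dict("liver-train", ["res-1", "res-2"], project_id="proj-123")
```

Sorting the IDs before hashing makes the digest independent of resource order, which is one way to keep lineage records stable across runs on the same split.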

The DatamintMLflowDataset wraps a DatamintBaseDataset for MLflow artifact logging.

Model Flavors

class datamint.mlflow.flavors.model.BaseDatamintModel(settings=None)

Bases: PythonModel, ABC

Core prediction gateway that any MLflow PythonModel can build on.

Owns:

Use DatamintModel when you need to load linked models at serve time.

Subclasses only need to implement predict_default() (and optionally other predict_* hooks registered with @prediction_mode).

Parameters:

settings (ModelSettings | dict[str, Any] | None)

get_supported_modes()

Return the list of prediction modes supported by this model.

Return type:

list[str]

property inference_device: str

The device that will be used for inference.

Returns _inference_device if already set, then falls back to the MLFLOW_DEFAULT_PREDICTION_DEVICE environment variable, then 'cpu'.
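The fallback order described above can be written out as a standalone function. This is a minimal sketch, not the property's actual code: resolve_inference_device is a hypothetical name, and the explicit argument stands in for the _inference_device attribute.

```python
import os
from typing import Optional


def resolve_inference_device(explicit: Optional[str] = None) -> str:
    """Follow the documented order: explicit value, env var, then 'cpu'."""
    if explicit is not None:
        return explicit
    return os.environ.get("MLFLOW_DEFAULT_PREDICTION_DEVICE", "cpu")


os.environ.pop("MLFLOW_DEFAULT_PREDICTION_DEVICE", None)
default_device = resolve_inference_device()  # nothing set anywhere

os.environ["MLFLOW_DEFAULT_PREDICTION_DEVICE"] = "cuda:0"
env_device = resolve_inference_device()  # taken from the environment

explicit_device = resolve_inference_device("cuda:1")  # explicit value wins
```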

load_context(context)

Detect the inference device and move any attached torch modules to it.

Override in subclasses to perform additional loading (e.g. linked models), but always call super().load_context(context) first so that inference_device is set before any model loading.

Parameters:

context (PythonModelContext)

Return type:

None

predict(model_input, params=None)

Main prediction entry point.

Routes to the appropriate handler based on params['mode']. Do not override this method; implement predict_default() (or other predict_* hooks) instead.

Parameters:
  • model_input (list[BaseResource])

  • params (dict[str, Any] | None)

Return type:

list[list[Annotation]]
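The routing idea can be shown with a simplified stand-in class. This sketch assumes that a missing or unknown mode falls back to the default handler and that the remaining params are forwarded as keyword arguments; SketchModel and its return values are hypothetical and do not reproduce Datamint's dispatcher.

```python
from typing import Any, Optional


class SketchModel:
    """Stand-in for a model class; not part of Datamint."""

    def predict(self, model_input: list, params: Optional[dict[str, Any]] = None) -> list:
        params = dict(params or {})
        mode = params.pop("mode", "default")
        # Look up predict_<mode>; fall back to predict_default if it is missing.
        handler = getattr(self, f"predict_{mode}", None)
        if handler is None:
            handler = self.predict_default
        return handler(model_input, **params)

    def predict_default(self, model_input: list, **kwargs: Any) -> list:
        return [["default"] for _ in model_input]

    def predict_slice(self, model_input: list, slice_index: int,
                      axis: str = "axial", **kwargs: Any) -> list:
        return [[f"{axis}:{slice_index}"] for _ in model_input]


model = SketchModel()
whole = model.predict(["scan-1"])
one_slice = model.predict(
    ["scan-1"], {"mode": "slice", "slice_index": 12, "axis": "coronal"}
)
unknown = model.predict(["scan-1"], {"mode": "volume"})
```

Keeping predict() fixed and adding handlers per mode means callers always use one entry point, while subclasses only write the modes they support.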

predict_default(model_input, **kwargs)

Default prediction on entire resources.

Override this in your subclass.

Parameters:
  • model_input (list[BaseResource])

  • kwargs (Any)

Return type:

list[list[Annotation]]

predict_image(model_input, **kwargs)
Parameters:
  • model_input (list[BaseResource])

  • kwargs (Any)

Return type:

Sequence[Sequence[ImageSegmentation | ImageClassification]]

predict_slice(model_input, slice_index, axis, **kwargs)
Parameters:
  • model_input (list[BaseResource])

  • slice_index (int)

  • axis (Literal['axial', 'sagittal', 'coronal'])

  • kwargs (Any)

Return type:

Sequence[Sequence[ImageSegmentation | ImageClassification]]

predict_volume(model_input, **kwargs)
Parameters:
  • model_input (list[BaseResource])

  • kwargs (Any)

Return type:

list[list[Annotation]]

property settings: ModelSettings
task_type: ClassVar[TaskType | None] = None

Semantic task category for this model class. Subclasses should override at the class body level (e.g. task_type = TaskType.IMAGE_SEGMENTATION).

class datamint.mlflow.flavors.model.DatamintModel(settings=None, mlflow_torch_models_uri=None, mlflow_models_uri=None, torch_model=None)

Bases: BaseDatamintModel

Abstract adapter for wrapping ML models to produce Datamint annotations.

Extends BaseDatamintModel with support for loading external (“linked”) MLflow models at serve time via LinkedModelLoader. Subclasses only need to override predict_default (and optionally other predict_* hooks).

Quick Start:

class MyModel(DatamintModel):
    def __init__(self):
        super().__init__(
            mlflow_models_uri={'model': 'models:/MyModel/latest'},
            settings=ModelSettings(need_gpu=True),
        )

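    def predict_default(self, model_input, **kwargs):
        device = self.inference_device
        model = self.get_mlflow_models()['model'].get_raw_model().to(device)
        # Placeholder: run `model` on `model_input` and convert its outputs
        # into one list of annotations per input resource.
        predictions = [[] for _ in model_input]
        return predictions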

You can also pass pre-instantiated torch.nn.Module objects directly:

class MyModel(DatamintModel):
    def __init__(self):
        net = MyTorchNet()
        super().__init__(torch_model=net)

    def predict_default(self, model_input, **kwargs):
        net = self.get_mlflow_torch_models()['net']
        return net(preprocess(model_input))
Parameters:
  • settings (ModelSettings | dict[str, Any] | None)

  • mlflow_torch_models_uri (dict[str, str] | None)

  • mlflow_models_uri (dict[str, str] | None)

  • torch_model (Module | None)

LINKED_MODELS_DIR: ClassVar[str] = 'linked_models'
get_mlflow_models()

Access loaded MLflow pyfunc models.

Return type:

dict[str, PyFuncModel]

get_mlflow_torch_models()

Access loaded MLflow PyTorch models.

Return type:

dict[str, Any]

get_pytorch_model()
Return type:

Module | None

load_context(context)

Detect device and load all linked MLflow models.

Parameters:

context (PythonModelContext)

Return type:

None

property mlflow_models_uri: dict[str, str]
property mlflow_torch_models_uri: dict[str, str]
predict_slice(model_input, slice_index, axis, **kwargs)
Parameters:
  • model_input (list[BaseResource])

  • slice_index (int)

  • axis (Literal['axial', 'sagittal', 'coronal'])

  • kwargs (Any)

Return type:

Sequence[Sequence[ImageSegmentation | ImageClassification]]

Task type enumeration for Datamint MLflow models.

class datamint.mlflow.flavors.task_type.TaskType(*values)

Medical-AI task categories for Datamint models.

The str mixin ensures values are JSON-serialisable in MLflow metadata without explicit .value calls.

ANOMALY_DETECTION = 'anomaly_detection'
IMAGE_CLASSIFICATION = 'image_classification'
IMAGE_SEGMENTATION = 'image_segmentation'
INSTANCE_SEGMENTATION = 'instance_segmentation'
LANDMARK_DETECTION = 'landmark_detection'
MULTILABEL_IMAGE_CLASSIFICATION = 'multilabel_image_classification'
OBJECT_DETECTION = 'object_detection'
REPORT_GENERATION = 'report_generation'
VIDEO_FRAME_CLASSIFICATION = 'video_frame_classification'
VIDEO_SEGMENTATION = 'video_segmentation'
VOLUME_CLASSIFICATION = 'volume_classification'
VOLUME_SEGMENTATION = 'volume_segmentation'
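The benefit of the str mixin mentioned above can be checked with the standard library alone. SketchTaskType below is a stand-in that copies two of the listed values; it is not the Datamint class.

```python
import json
from enum import Enum


class SketchTaskType(str, Enum):
    """Stand-in enum with two of the values listed above."""

    IMAGE_SEGMENTATION = "image_segmentation"
    OBJECT_DETECTION = "object_detection"


metadata = {"task_type": SketchTaskType.IMAGE_SEGMENTATION}
# No .value needed: the member is itself a str, so json can encode it.
encoded = json.dumps(metadata)

# The member also compares equal to its plain string value.
matches_plain_string = SketchTaskType.OBJECT_DETECTION == "object_detection"
```

A plain Enum without the str mixin would make json.dumps raise a TypeError here, which is the situation the mixin avoids.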

Registry-driven prediction dispatcher and @prediction_mode decorator.

Replaces the hardcoded mode_param_keys dict, stub-method pattern, and fragile _is_mode_implemented introspection previously in DatamintModel.

class datamint.mlflow.flavors.prediction_router.ModeSpec(mode, param_keys=(), fallback_to_default=True)

Metadata for a registered prediction mode handler.

Parameters:
  • mode (PredictionMode)

  • param_keys (tuple[str, ...])

  • fallback_to_default (bool)

fallback_to_default: bool
mode: PredictionMode
param_keys: tuple[str, ...]
class datamint.mlflow.flavors.prediction_router.PredictionRouter(model_instance, base_class=None)

Registry-driven prediction dispatcher.

Discovers mode handlers in two ways (in order of priority):

  1. Methods decorated with @prediction_mode.

  2. Convention-named methods predict_<mode_value> not defined on the abstract base (backward compatibility with old-style overrides).
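The two discovery steps can be sketched as follows. This is an illustration under stated assumptions, not the router's code: sketch_mode, discover_handlers, the _sketch_mode attribute, and the Base and Custom classes are all hypothetical, and modes are plain strings rather than PredictionMode members.

```python
from typing import Any, Callable


def sketch_mode(mode: str) -> Callable:
    """Hypothetical decorator: tag a method with the mode it handles."""
    def mark(func: Callable) -> Callable:
        func._sketch_mode = mode
        return func
    return mark


def discover_handlers(model: Any, base_class: type) -> dict[str, Callable]:
    registry: dict[str, Callable] = {}
    # 1. Decorated methods are registered first, so they take priority.
    for name in dir(model):
        member = getattr(model, name)
        mode = getattr(member, "_sketch_mode", None)
        if callable(member) and mode is not None:
            registry[mode] = member
    # 2. Convention-named predict_<mode> methods, skipping base-class stubs.
    for name in dir(model):
        if not name.startswith("predict_"):
            continue
        mode = name[len("predict_"):]
        impl = getattr(type(model), name)
        if mode in registry or impl is getattr(base_class, name, None):
            continue
        registry[mode] = getattr(model, name)
    return registry


class Base:
    def predict_volume(self, model_input):  # stub on the base: ignored
        raise NotImplementedError


class Custom(Base):
    @sketch_mode("slice")
    def run_on_slice(self, model_input):
        return "decorated"

    def predict_slice(self, model_input):  # loses to the decorated handler
        return "convention"

    def predict_frame(self, model_input):  # found by naming convention
        return "frame"


registry = discover_handlers(Custom(), Base)
```

Comparing the subclass attribute with the base-class attribute is one simple way to tell an inherited stub from a real override.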

Parameters:
  • model_instance (Any)

  • base_class (type | None)

static discover(model, base_class)

Build the mode -> handler registry from the model instance.

Parameters:
  • model (Any)

  • base_class (type | None)

Return type:

dict[PredictionMode, tuple[Callable, ModeSpec]]

dispatch(model_input, params)
Parameters:
  • model_input (list)

  • params (dict[str, Any])

Return type:

list

supported_modes()
Return type:

list[str]

update_registry(model_instance, base_class=None, overwrite=False)

Update the registry with handlers from a new model instance (e.g. linked model).

Parameters:
  • model_instance (Any)

  • base_class (type | None)

  • overwrite (bool)

Return type:

None

datamint.mlflow.flavors.prediction_router.prediction_mode(mode, *, param_keys=(), fallback_to_default=True)

Decorator that registers a method as a prediction mode handler.

Usage:

class MyModel(DatamintModel):
    @prediction_mode(PredictionMode.SLICE, param_keys=("slice_index", "axis"))
    def predict_slice(self, model_input, *, slice_index, axis="axial", **kw):
        ...
Parameters:
  • mode (PredictionMode)

  • param_keys (tuple[str, ...])

  • fallback_to_default (bool)

Return type:

Callable

Prediction mode enumeration for Datamint models.

class datamint.mlflow.flavors.prediction_modes.PredictionMode(*values)

Enumeration of supported prediction modes.

Each mode corresponds to a specific method signature in DatamintModel.

ALL_FRAMES = 'all_frames'
DEFAULT = 'default'
FEW_SHOT = 'few_shot'
FRAME = 'frame'
FRAME_RANGE = 'frame_range'
IMAGE = 'image'
INTERACTIVE = 'interactive'
PRIMARY_SLICE = 'primary_slice'
SLICE = 'slice'
SLICE_RANGE = 'slice_range'
TEMPORAL_SEQUENCE = 'temporal_sequence'
VOLUME = 'volume'
datamint.mlflow.flavors.datamint_flavor.load_model(model_uri, device=None)
Parameters:
  • model_uri (str)

  • device (str | None)

Return type:

DatamintModel

datamint.mlflow.flavors.datamint_flavor.log_model(datamint_model, task_type=None, supported_modes=None, name='datamint_model', data_path=None, code_paths=None, infer_code_paths=False, artifacts=None, registered_model_name=None, signature=None, input_example=None, pip_requirements=None, extra_pip_requirements=None, metadata=None, model_config=None, **kwargs)
Parameters:
  • datamint_model (BaseDatamintModel)

  • task_type (TaskType | str | None)

  • supported_modes (Sequence[str] | None)

  • name (str)

  • registered_model_name (str | None)

  • signature (ModelSignature | None)

  • input_example (DataFrame | ndarray | dict | list | csr_matrix | csc_matrix | str | bytes | tuple | None)

datamint.mlflow.flavors.datamint_flavor.save_model(datamint_model, path, task_type=None, supported_modes=None, data_path=None, code_paths=None, infer_code_paths=False, conda_env=None, mlflow_model=None, artifacts=None, signature=None, input_example=None, pip_requirements=None, extra_pip_requirements=None, metadata=None, model_config=None, streamable=None, **kwargs)
Parameters:
  • datamint_model (BaseDatamintModel)

  • task_type (TaskType | str | None)

  • supported_modes (Sequence[str] | None)

  • mlflow_model (Model | None)

  • signature (ModelSignature | None)

  • input_example (DataFrame | ndarray | dict | list | csr_matrix | csc_matrix | str | bytes | tuple | None)

Checkpointing

The MLflowModelCheckpoint extends Lightning’s ModelCheckpoint to automatically log checkpoints as MLflow artifacts.

MLflow Tracking

datamint.mlflow.tracking.fluent.set_project(project)

Set the active project for the current session.

Parameters:

project (Project | str) – The Project instance or project name/ID to set as active.

class datamint.mlflow.tracking.default_experiment.DatamintExperimentProvider
get_experiment_id()

Provide the MLflow Experiment ID for the current MLflow client context.

Assumes that in_context() is True.

Returns:

The ID of the MLflow Experiment associated with the current context.

in_context()

Determine if the MLflow client is running in a context where this provider can identify an associated MLflow Experiment ID.

Returns:

True if the MLflow client is running in a context where the provider can identify an associated MLflow Experiment ID. False otherwise.
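The in_context() / get_experiment_id() contract can be sketched with a stand-in provider. The sketch assumes that a provider is asked for an ID only after in_context() returns True, and that the first matching provider wins. The class, the project-to-experiment mapping, the use of DATAMINT_PROJECT_ID as the trigger, and the "0" fallback are illustrative assumptions, not Datamint's behaviour.

```python
import os


class SketchProjectProvider:
    """Stand-in provider; the project-to-experiment mapping is invented."""

    def __init__(self, experiments_by_project: dict[str, str]) -> None:
        self._experiments = experiments_by_project

    def in_context(self) -> bool:
        # In context only when the env var names a known project.
        return os.environ.get("DATAMINT_PROJECT_ID") in self._experiments

    def get_experiment_id(self) -> str:
        # Only valid after in_context() has returned True.
        return self._experiments[os.environ["DATAMINT_PROJECT_ID"]]


def resolve_experiment_id(providers: list, fallback: str = "0") -> str:
    """Return the ID from the first provider that is in context."""
    for provider in providers:
        if provider.in_context():
            return provider.get_experiment_id()
    return fallback


providers = [SketchProjectProvider({"proj-123": "42"})]

os.environ.pop("DATAMINT_PROJECT_ID", None)
without_project = resolve_experiment_id(providers)

os.environ["DATAMINT_PROJECT_ID"] = "proj-123"
with_project = resolve_experiment_id(providers)
```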

class datamint.mlflow.tracking.datamint_store.DatamintStore(store_uri, artifact_uri=None, force_valid=True)

DatamintStore is a subclass of RestStore that provides a tracking store implementation for Datamint.

Parameters:

store_uri (str)

create_experiment(name, artifact_location=None, tags=None, project_id=None)

Create a new experiment. If an experiment with the given name already exists, an exception is raised.

Parameters:
  • name – Desired name for an experiment.

  • artifact_location – Location to store run artifacts.

  • tags – A list of mlflow.entities.ExperimentTag instances to set for the experiment.

  • project_id (str | None)

Return type:

str

Returns:

experiment_id for the newly created experiment if successful, else None

get_experiment_by_name(experiment_name, project_id=None)

Fetch the experiment by name from the backend store.

Parameters:
  • experiment_name – Name of experiment

  • project_id (str | None)

Returns:

A single mlflow.entities.Experiment object if it exists.

Artifact Repository

class datamint.mlflow.artifact.datamint_artifacts_repo.DatamintArtifactsRepository(artifact_uri, tracking_uri=None, registry_uri=None)
Parameters:
  • artifact_uri (str)

  • tracking_uri (str | None)

  • registry_uri (str | None)

classmethod resolve_uri(artifact_uri, tracking_uri)

Usage Example

import lightning as L
from datamint.dataset import ImageDataset
from datamint.lightning import DatamintDataModule, UNetPPTrainer

# Create dataset and datamodule
dataset = ImageDataset(project="Liver Segmentation")
datamodule = DatamintDataModule(
    dataset,
    batch_size=16,
    split={'train': 0.8, 'val': 0.1, 'test': 0.1},
)

# Use trainer with automatic MLflow integration
trainer = UNetPPTrainer(
    project="Liver Segmentation",
    image_size=256,
    batch_size=16,
    max_epochs=50,
    accelerator="gpu",
    register_model=True,  # Auto-register in MLflow
)

results = trainer.fit()

# Model is automatically logged and registered
print(results["test_results"])

import mlflow

# Verify that the experiment exists, then list its runs
print(mlflow.get_experiment_by_name("Liver Segmentation"))
print(mlflow.search_runs(experiment_names=["Liver Segmentation"]))