MLflow Integration
The datamint.mlflow module integrates MLflow with Datamint for experiment
tracking, model logging, and model deployment.
MLflow Components
Overview
Key features:
Automatic Experiment Tracking: Experiments are automatically associated with Datamint projects
Model Registration: Trained models are registered in MLflow and can be deployed to Datamint
Dataset Versioning: Datasets are logged as MLflow artifacts with metadata
Custom Checkpointing: MLflowModelCheckpoint integrates with Lightning callbacks
Flavor Support: Custom Datamint flavors for segmentation and classification models
Automatic Configuration
The MLflow module auto-configures itself on first import:
# This automatically sets up MLflow environment
import datamint.mlflow
# Or explicitly:
from datamint.mlflow import ensure_mlflow_configured
ensure_mlflow_configured()
Environment Setup
Utility functions for automatically configuring MLflow environment variables based on Datamint configuration.
- datamint.mlflow.env_utils.ensure_mlflow_configured()
Ensure MLflow environment is properly configured. Raises ValueError if configuration is incomplete.
- Return type:
None
- datamint.mlflow.env_utils.setup_mlflow_environment(overwrite=False, set_mlflow=True)
Set up MLflow environment variables based on Datamint configuration.
- Parameters:
overwrite (bool) – If True, overwrite existing MLflow environment variables.
set_mlflow (bool) – If True, set the MLflow tracking URI using mlflow.set_tracking_uri().
- Returns:
True if success, False otherwise.
- Return type:
bool
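The overwrite behaviour described above can be illustrated with a small standalone sketch. It does not call Datamint or MLflow. The helper name apply_env_defaults and the example URIs are hypothetical, and MLFLOW_TRACKING_URI is used only because it is MLflow's standard tracking-URI variable; the real setup_mlflow_environment() may set other variables and return True under different conditions.

```python
import os


def apply_env_defaults(values: dict, overwrite: bool = False) -> bool:
    """Hypothetical helper: copy `values` into os.environ.

    Existing variables are kept unless `overwrite` is True.
    Returns True if at least one variable was written.
    """
    changed = False
    for key, value in values.items():
        if overwrite or key not in os.environ:
            os.environ[key] = value
            changed = True
    return changed


# Placeholder URIs, not real Datamint endpoints.
os.environ["MLFLOW_TRACKING_URI"] = "http://existing.example"

apply_env_defaults({"MLFLOW_TRACKING_URI": "http://new.example"})
kept = os.environ["MLFLOW_TRACKING_URI"]      # unchanged because overwrite=False

apply_env_defaults({"MLFLOW_TRACKING_URI": "http://new.example"}, overwrite=True)
replaced = os.environ["MLFLOW_TRACKING_URI"]  # now holds the new value
```

With overwrite left at its default, a tracking URI exported in the shell before the import is preserved.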
MLflow Dataset
MLflow Dataset adapter for Datamint project splits.
- class datamint.mlflow.data.datamint_dataset.DatamintMLflowDataset(project_id, project_name, split, resources, extra_params=None)
Bases: Dataset
MLflow Dataset wrapping a Datamint project split for lineage tracking.
- Parameters:
project_id (str)
project_name (str)
split (str | None)
resources (Sequence[str] | Sequence[Resource])
extra_params (dict[str, Any] | None)
- property profile: Any | None
Optional summary statistics for the dataset, such as the number of rows in a table, the mean / median / std of each table column, etc.
- property schema
Optional dataset schema, such as an instance of mlflow.types.Schema representing the features and targets of the dataset.
- to_dict()
Create config dictionary for the dataset.
Subclasses should override this method to provide additional fields in the config dict, e.g., schema, profile, etc.
Returns a string dictionary containing the following fields: name, digest, source, source type.
- Return type:
dict[str,str]
The DatamintMLflowDataset wraps a DatamintBaseDataset for MLflow artifact logging.
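As a rough illustration of the config dictionary that to_dict() describes (name, digest, source, source type, all strings), the sketch below builds one from a list of resource IDs. Everything in it is an assumption made for illustration: the helper name dataset_config, the digest scheme (a short SHA-256 over sorted IDs), the JSON-encoded source, and the "datamint" source-type label. It is not the library's implementation.

```python
import hashlib
import json


def dataset_config(name: str, resource_ids: list, source: str) -> dict:
    """Hypothetical helper producing the four string fields named above."""
    # Assumption: the digest is a short hash of the sorted IDs,
    # so the same set of resources always yields the same digest.
    joined = "\n".join(sorted(resource_ids)).encode("utf-8")
    digest = hashlib.sha256(joined).hexdigest()[:8]
    return {
        "name": name,
        "digest": digest,
        "source": json.dumps({"project": source}),
        "source_type": "datamint",  # assumed label, not confirmed by the docs
    }


config_a = dataset_config("train", ["res-2", "res-1"], "Liver Segmentation")
config_b = dataset_config("train", ["res-1", "res-2"], "Liver Segmentation")
```

A digest that does not depend on ordering lets two runs over the same split be recognised as using the same data.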
Model Flavors
- class datamint.mlflow.flavors.model.BaseDatamintModel(settings=None)
Bases: PythonModel, ABC
Core prediction gateway that any MLflow PythonModel can build on.
Owns:
- settings: hardware / deployment configuration.
- Device detection (_detect_device / inference_device).
- Prediction dispatch via PredictionRouter.
- Pickle-safe serialization.
Use DatamintModel when you need to load linked models at serve time.
Subclasses only need to implement predict_default() (and optionally other predict_* hooks registered with @prediction_mode).
- Parameters:
settings (ModelSettings | dict[str, Any] | None)
- get_supported_modes()
Return the list of prediction modes supported by this model.
- Return type:
list[str]
- property inference_device: str
The device that will be used for inference.
Returns _inference_device if already set, then falls back to the MLFLOW_DEFAULT_PREDICTION_DEVICE environment variable, then 'cpu'.
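The fallback order for inference_device can be sketched with a stand-in class. DeviceFallbackSketch is hypothetical and only mirrors the three-step lookup described above (explicit value, then environment variable, then 'cpu'); the real property may involve additional device detection.

```python
import os
from typing import Optional


class DeviceFallbackSketch:
    """Stand-in for illustration only; not the real BaseDatamintModel."""

    def __init__(self) -> None:
        self._inference_device: Optional[str] = None

    @property
    def inference_device(self) -> str:
        # 1. An explicitly set device wins.
        if self._inference_device is not None:
            return self._inference_device
        # 2. Otherwise use the environment variable; 3. else fall back to 'cpu'.
        return os.environ.get("MLFLOW_DEFAULT_PREDICTION_DEVICE", "cpu")


model = DeviceFallbackSketch()

os.environ.pop("MLFLOW_DEFAULT_PREDICTION_DEVICE", None)
default_device = model.inference_device   # nothing set: 'cpu'

os.environ["MLFLOW_DEFAULT_PREDICTION_DEVICE"] = "cuda:0"
env_device = model.inference_device       # taken from the environment

model._inference_device = "cuda:1"
explicit_device = model.inference_device  # explicit value takes priority
```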
- load_context(context)
Detect the inference device and move any attached torch modules to it.
Override in subclasses to perform additional loading (e.g. linked models), but always call super().load_context(context) first so that inference_device is set before any model loading.
- Parameters:
context (PythonModelContext)
- Return type:
None
- predict(model_input, params=None)
Main prediction entry point.
Routes to the appropriate handler based on params['mode']. Do not override; implement predict_default() (or other predict_* hooks) instead.
- Parameters:
model_input (list[BaseResource])
params (dict[str, Any] | None)
- Return type:
list[list[Annotation]]
- predict_default(model_input, **kwargs)
Default prediction on entire resources.
Override this in your subclass.
- Parameters:
model_input (list[BaseResource])
kwargs (Any)
- Return type:
list[list[Annotation]]
- predict_image(model_input, **kwargs)
- Parameters:
model_input (list[BaseResource])
kwargs (Any)
- Return type:
Sequence[Sequence[ImageSegmentation | ImageClassification]]
- predict_slice(model_input, slice_index, axis, **kwargs)
- Parameters:
model_input (list[BaseResource])
slice_index (int)
axis (Literal['axial', 'sagittal', 'coronal'])
kwargs (Any)
- Return type:
Sequence[Sequence[ImageSegmentation | ImageClassification]]
- predict_volume(model_input, **kwargs)
- Parameters:
model_input (list[BaseResource])
kwargs (Any)
- Return type:
list[list[Annotation]]
- property settings: ModelSettings
- class datamint.mlflow.flavors.model.DatamintModel(settings=None, mlflow_torch_models_uri=None, mlflow_models_uri=None, torch_model=None)
Bases: BaseDatamintModel
Abstract adapter for wrapping ML models to produce Datamint annotations.
Extends BaseDatamintModel with support for loading external (“linked”) MLflow models at serve time via LinkedModelLoader. Subclasses only need to override predict_default (and optionally other predict_* hooks).
Quick Start:
class MyModel(DatamintModel):
    def __init__(self):
        super().__init__(
            mlflow_models_uri={'model': 'models:/MyModel/latest'},
            settings=ModelSettings(need_gpu=True),
        )

    def predict_default(self, model_input, **kwargs):
        device = self.inference_device
        model = self.get_mlflow_models()['model'].get_raw_model().to(device)
        # Run inference with `model` here to build `predictions`
        # (this step is not shown in the original example).
        return predictions
You can also pass pre-instantiated torch.nn.Module objects directly:
class MyModel(DatamintModel):
    def __init__(self):
        net = MyTorchNet()
        super().__init__(torch_model=net)

    def predict_default(self, model_input, **kwargs):
        net = self.get_mlflow_torch_models()['net']
        return net(preprocess(model_input))
- Parameters:
settings (ModelSettings | dict[str, Any] | None)
mlflow_torch_models_uri (dict[str, str] | None)
mlflow_models_uri (dict[str, str] | None)
torch_model (Module | None)
- LINKED_MODELS_DIR: ClassVar[str] = 'linked_models'
- get_mlflow_models()
Access loaded MLflow pyfunc models.
- Return type:
dict[str,PyFuncModel]
- get_mlflow_torch_models()
Access loaded MLflow PyTorch models.
- Return type:
dict[str,Any]
- get_pytorch_model()
- Return type:
Module|None
- load_context(context)
Detect device and load all linked MLflow models.
- Parameters:
context (PythonModelContext)
- Return type:
None
- property mlflow_models_uri: dict[str, str]
- property mlflow_torch_models_uri: dict[str, str]
- predict_slice(model_input, slice_index, axis, **kwargs)
- Parameters:
model_input (list[BaseResource])
slice_index (int)
axis (Literal['axial', 'sagittal', 'coronal'])
kwargs (Any)
- Return type:
Sequence[Sequence[ImageSegmentation | ImageClassification]]
Task type enumeration for Datamint MLflow models.
- class datamint.mlflow.flavors.task_type.TaskType(*values)
Medical-AI task categories for Datamint models.
The str mixin ensures values are JSON-serialisable in MLflow metadata without explicit .value calls.
- ANOMALY_DETECTION = 'anomaly_detection'
- IMAGE_CLASSIFICATION = 'image_classification'
- IMAGE_SEGMENTATION = 'image_segmentation'
- INSTANCE_SEGMENTATION = 'instance_segmentation'
- LANDMARK_DETECTION = 'landmark_detection'
- MULTILABEL_IMAGE_CLASSIFICATION = 'multilabel_image_classification'
- OBJECT_DETECTION = 'object_detection'
- REPORT_GENERATION = 'report_generation'
- VIDEO_FRAME_CLASSIFICATION = 'video_frame_classification'
- VIDEO_SEGMENTATION = 'video_segmentation'
- VOLUME_CLASSIFICATION = 'volume_classification'
- VOLUME_SEGMENTATION = 'volume_segmentation'
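The effect of the str mixin can be seen with the standard library alone. The sketch below defines two stand-in enums (TaskTypeSketch and PlainTaskType are hypothetical names; the member values are copied from the list above) and compares how json.dumps treats them.

```python
import json
from enum import Enum


class TaskTypeSketch(str, Enum):
    """Stand-in with the str mixin, mirroring how TaskType is described."""
    IMAGE_SEGMENTATION = "image_segmentation"
    VOLUME_SEGMENTATION = "volume_segmentation"


class PlainTaskType(Enum):
    """Same value without the str mixin, for comparison."""
    IMAGE_SEGMENTATION = "image_segmentation"


# With the mixin, the member is serialised as its string value.
encoded = json.dumps({"task_type": TaskTypeSketch.IMAGE_SEGMENTATION})

# Without the mixin, json.dumps rejects the member unless .value is used.
try:
    json.dumps({"task_type": PlainTaskType.IMAGE_SEGMENTATION})
    plain_is_serialisable = True
except TypeError:
    plain_is_serialisable = False
```

Members with the mixin also compare equal to their plain string values, so metadata read back from JSON can be compared against the enum directly.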
Registry-driven prediction dispatcher and @prediction_mode decorator.
Replaces the hardcoded mode_param_keys dict, stub-method pattern,
and fragile _is_mode_implemented introspection previously in DatamintModel.
- class datamint.mlflow.flavors.prediction_router.ModeSpec(mode, param_keys=(), fallback_to_default=True)
Metadata for a registered prediction mode handler.
- Parameters:
mode (PredictionMode)
param_keys (tuple[str, ...])
fallback_to_default (bool)
- fallback_to_default: bool
- mode: PredictionMode
- param_keys: tuple[str, ...]
- class datamint.mlflow.flavors.prediction_router.PredictionRouter(model_instance, base_class=None)
Registry-driven prediction dispatcher.
Discovers mode handlers in two ways (in order of priority):
1. Methods decorated with @prediction_mode.
2. Convention-named methods predict_<mode_value> not defined on the abstract base (backward compatibility with old-style overrides).
- Parameters:
model_instance (Any)
base_class (type | None)
- static discover(model, base_class)
Build the mode -> handler registry from the model instance.
- Parameters:
model (Any)
base_class (type | None)
- Return type:
dict[PredictionMode, tuple[Callable, ModeSpec]]
- dispatch(model_input, params)
- Parameters:
model_input (list)
params (dict[str, Any])
- Return type:
list
- supported_modes()
- Return type:
list[str]
- update_registry(model_instance, base_class=None, overwrite=False)
Update the registry with handlers from a new model instance (e.g. linked model).
- Parameters:
model_instance (Any)
base_class (type | None)
overwrite (bool)
- Return type:
None
- datamint.mlflow.flavors.prediction_router.prediction_mode(mode, *, param_keys=(), fallback_to_default=True)
Decorator that registers a method as a prediction mode handler.
Usage:
class MyModel(DatamintModel):
    @prediction_mode(PredictionMode.SLICE, param_keys=("slice_index", "axis"))
    def predict_slice(self, model_input, *, slice_index, axis="axial", **kw):
        ...
- Parameters:
mode (PredictionMode)
param_keys (tuple[str, ...])
fallback_to_default (bool)
- Return type:
Callable
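The general idea of a registry-driven dispatcher with a registering decorator can be sketched without Datamint. The names below (Mode, mode_handler, RouterSketch, ToyModel) are all hypothetical stand-ins, and the sketch covers only decorator-based discovery and fallback to the default handler. It is not the actual PredictionRouter, which also discovers convention-named methods.

```python
from enum import Enum
from typing import Any, Callable


class Mode(str, Enum):  # stand-in for PredictionMode
    DEFAULT = "default"
    SLICE = "slice"


def mode_handler(mode: Mode, *, param_keys: tuple = ()) -> Callable:
    """Hypothetical decorator: tag a method with the mode it handles."""
    def decorate(func: Callable) -> Callable:
        func._mode_spec = (mode, tuple(param_keys))
        return func
    return decorate


class RouterSketch:
    def __init__(self, model: Any) -> None:
        # Build the mode -> (handler, param_keys) registry from tagged methods.
        self.registry: dict = {}
        for name in dir(model):
            attr = getattr(model, name)
            spec = getattr(attr, "_mode_spec", None)
            if callable(attr) and spec is not None:
                mode, keys = spec
                self.registry[mode] = (attr, keys)

    def dispatch(self, model_input: list, params: dict) -> list:
        mode = Mode(params.get("mode", "default"))
        # Unregistered modes fall back to the default handler.
        handler, keys = self.registry.get(mode, self.registry[Mode.DEFAULT])
        kwargs = {k: params[k] for k in keys if k in params}
        return handler(model_input, **kwargs)


class ToyModel:
    @mode_handler(Mode.DEFAULT)
    def predict_default(self, model_input):
        return [f"default:{item}" for item in model_input]

    @mode_handler(Mode.SLICE, param_keys=("slice_index", "axis"))
    def predict_slice(self, model_input, *, slice_index, axis="axial"):
        return [f"{axis}[{slice_index}]:{item}" for item in model_input]


router = RouterSketch(ToyModel())
whole = router.dispatch(["scan"], {})
one_slice = router.dispatch(["scan"], {"mode": "slice", "slice_index": 3})
```

Declaring param_keys next to the handler keeps the list of accepted parameters in one place, so only those keys are forwarded from params.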
Prediction mode enumeration for Datamint models.
- class datamint.mlflow.flavors.prediction_modes.PredictionMode(*values)
Enumeration of supported prediction modes.
Each mode corresponds to a specific method signature in DatamintModel.
- ALL_FRAMES = 'all_frames'
- DEFAULT = 'default'
- FEW_SHOT = 'few_shot'
- FRAME = 'frame'
- FRAME_RANGE = 'frame_range'
- IMAGE = 'image'
- INTERACTIVE = 'interactive'
- PRIMARY_SLICE = 'primary_slice'
- SLICE = 'slice'
- SLICE_RANGE = 'slice_range'
- TEMPORAL_SEQUENCE = 'temporal_sequence'
- VOLUME = 'volume'
- datamint.mlflow.flavors.datamint_flavor.load_model(model_uri, device=None)
- Parameters:
model_uri (str)
device (str | None)
- Return type:
- datamint.mlflow.flavors.datamint_flavor.log_model(datamint_model, task_type=None, supported_modes=None, name='datamint_model', data_path=None, code_paths=None, infer_code_paths=False, artifacts=None, registered_model_name=None, signature=None, input_example=None, pip_requirements=None, extra_pip_requirements=None, metadata=None, model_config=None, **kwargs)
- Parameters:
datamint_model (BaseDatamintModel)
task_type (TaskType | str | None)
supported_modes (Sequence[str] | None)
name (str)
registered_model_name (str | None)
signature (ModelSignature | None)
input_example (DataFrame | ndarray | dict | list | csr_matrix | csc_matrix | str | bytes | tuple | None)
- datamint.mlflow.flavors.datamint_flavor.save_model(datamint_model, path, task_type=None, supported_modes=None, data_path=None, code_paths=None, infer_code_paths=False, conda_env=None, mlflow_model=None, artifacts=None, signature=None, input_example=None, pip_requirements=None, extra_pip_requirements=None, metadata=None, model_config=None, streamable=None, **kwargs)
- Parameters:
datamint_model (BaseDatamintModel)
task_type (TaskType | str | None)
supported_modes (Sequence[str] | None)
mlflow_model (Model | None)
signature (ModelSignature | None)
input_example (DataFrame | ndarray | dict | list | csr_matrix | csc_matrix | str | bytes | tuple | None)
Checkpointing
The MLflowModelCheckpoint extends Lightning’s ModelCheckpoint to automatically log checkpoints as MLflow artifacts.
MLflow Tracking
- datamint.mlflow.tracking.fluent.set_project(project)
Set the active project for the current session.
- Parameters:
project (Project | str) – The Project instance or project name/ID to set as active.
- class datamint.mlflow.tracking.default_experiment.DatamintExperimentProvider
- get_experiment_id()
Provide the MLflow Experiment ID for the current MLflow client context.
Assumes that in_context() is True.
- Returns:
The ID of the MLflow Experiment associated with the current context.
- in_context()
Determine if the MLflow client is running in a context where this provider can identify an associated MLflow Experiment ID.
- Returns:
True if the MLflow client is running in a context where the provider can identify an associated MLflow Experiment ID. False otherwise.
- class datamint.mlflow.tracking.datamint_store.DatamintStore(store_uri, artifact_uri=None, force_valid=True)
DatamintStore is a subclass of RestStore that provides a tracking store implementation for Datamint.
- Parameters:
store_uri (str)
- create_experiment(name, artifact_location=None, tags=None, project_id=None)
Create a new experiment. If an experiment with the given name already exists, an exception is raised.
- Parameters:
name – Desired name for an experiment.
artifact_location – Location to store run artifacts.
tags – A list of mlflow.entities.ExperimentTag instances to set for the experiment.
project_id (str | None)
- Return type:
str
- Returns:
experiment_id for the newly created experiment if successful, else None
- get_experiment_by_name(experiment_name, project_id=None)
Fetch the experiment by name from the backend store.
- Parameters:
experiment_name – Name of experiment
project_id (str | None)
- Returns:
A single mlflow.entities.Experiment object if it exists.
Artifact Repository
Usage Example
import lightning as L
from datamint.dataset import ImageDataset
from datamint.lightning import DatamintDataModule, UNetPPTrainer
# Create dataset and datamodule
dataset = ImageDataset(project="Liver Segmentation")
datamodule = DatamintDataModule(
dataset,
batch_size=16,
split={'train': 0.8, 'val': 0.1, 'test': 0.1},
)
# Use trainer with automatic MLflow integration
trainer = UNetPPTrainer(
project="Liver Segmentation",
image_size=256,
batch_size=16,
max_epochs=50,
accelerator="gpu",
register_model=True, # Auto-register in MLflow
)
results = trainer.fit()
# Model is automatically logged and registered
print(results["test_results"])
import mlflow
mlflow.get_experiment_by_name("Liver Segmentation") # Verify experiment exists
mlflow.search_runs(experiment_names=["Liver Segmentation"]) # View runs