Model Flavors

class datamint.mlflow.flavors.model.BaseDatamintModel(settings=None)

Bases: PythonModel, ABC

Core prediction gateway that any MLflow PythonModel can build on.

Owns:

Use DatamintModel when you need to load linked models at serve time.

Subclasses only need to implement predict_default() (and optionally other predict_* hooks registered with @prediction_mode).

Parameters:

settings (ModelSettings | dict[str, Any] | None)

get_supported_modes()

Return the list of prediction modes supported by this model.

Return type:

list[str]

property inference_device: str

The device that will be used for inference.

Returns _inference_device if it is already set; otherwise falls back to the MLFLOW_DEFAULT_PREDICTION_DEVICE environment variable and, if that is unset, to 'cpu'.
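The resolution order above can be sketched as a plain function. This is a minimal illustration, not the library's code: the helper name resolve_inference_device is hypothetical, and it assumes "already set" means "not None".

```python
import os
from typing import Optional

# Environment variable named in the documentation above.
DEVICE_ENV_VAR = "MLFLOW_DEFAULT_PREDICTION_DEVICE"


def resolve_inference_device(explicit: Optional[str] = None) -> str:
    """Hypothetical helper mirroring the documented fallback order."""
    # 1. A device that was already set takes precedence.
    if explicit is not None:
        return explicit
    # 2. Otherwise use the environment variable, 3. then fall back to 'cpu'.
    return os.environ.get(DEVICE_ENV_VAR, "cpu")
```

For instance, with MLFLOW_DEFAULT_PREDICTION_DEVICE=cuda:0 in the serving environment and no device set explicitly, the sketch returns 'cuda:0'.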

load_context(context)

Detect the inference device and move any attached torch modules to it.

Override in subclasses to perform additional loading (e.g. linked models), but always call super().load_context(context) first so that inference_device is set before any model is loaded.

Parameters:

context (PythonModelContext)

Return type:

None
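The override pattern described above might look like the following. To stay self-contained, the sketch uses a hypothetical stand-in base class, _BaseSketch, that mimics only the documented contract (load_context() sets the inference device); in real code you would subclass BaseDatamintModel or DatamintModel instead. The _load_linked_model helper is likewise hypothetical.

```python
import os
from typing import Any


class _BaseSketch:
    """Hypothetical stand-in that mimics only the documented load_context() contract."""

    def load_context(self, context: Any) -> None:
        # Simplified device detection; the real base class also moves torch modules.
        self.inference_device = os.environ.get("MLFLOW_DEFAULT_PREDICTION_DEVICE", "cpu")


class MyModel(_BaseSketch):
    def load_context(self, context: Any) -> None:
        # 1. Let the base class set inference_device before anything is loaded.
        super().load_context(context)
        # 2. Subclass-specific loading can now rely on the device.
        self.loaded_on = self._load_linked_model()

    def _load_linked_model(self) -> str:
        # Hypothetical loader: reading self.inference_device here would raise
        # AttributeError if super().load_context() had not run first.
        return self.inference_device
```

Reversing the two steps in MyModel.load_context() makes the sketch fail, which is exactly the ordering problem the documentation warns about.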

predict(model_input, params=None)

Main prediction entry point.

Routes each call to the appropriate handler based on params['mode']. Do not override this method; implement predict_default() (or other predict_* hooks) instead.

Parameters:
  • model_input (list[BaseResource])

  • params (dict[str, Any] | None)

Return type:

list[list[Annotation]]
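The routing that predict() performs can be illustrated with a small mode registry. This is a sketch of the general technique (a decorator that tags handler methods with a mode name, plus a dispatcher that looks up params['mode']), not Datamint's actual implementation. The names mode_handler, ModeDispatchSketch and EchoModel are hypothetical stand-ins (mode_handler plays the role of @prediction_mode), and treating a missing mode as 'default' is an assumption.

```python
from typing import Any, Callable, Optional


def mode_handler(name: str) -> Callable:
    """Hypothetical decorator: tag a method as the handler for a prediction mode."""
    def decorator(func: Callable) -> Callable:
        func._prediction_mode = name
        return func
    return decorator


class ModeDispatchSketch:
    @classmethod
    def _handlers(cls) -> dict[str, str]:
        # Map mode name -> method name; walking the MRO base-first lets subclasses win.
        handlers: dict[str, str] = {}
        for klass in reversed(cls.__mro__):
            for attr_name, attr in vars(klass).items():
                mode = getattr(attr, "_prediction_mode", None)
                if mode is not None:
                    handlers[mode] = attr_name
        return handlers

    def get_supported_modes(self) -> list[str]:
        return sorted(self._handlers())

    def predict(self, model_input: list, params: Optional[dict[str, Any]] = None) -> Any:
        params = dict(params or {})  # copy so the caller's dict is not mutated
        mode = params.pop("mode", "default")  # assumed default mode name
        handlers = self._handlers()
        if mode not in handlers:
            raise ValueError(f"Unsupported mode {mode!r}; expected one of {sorted(handlers)}")
        # Remaining params become keyword arguments of the handler.
        return getattr(self, handlers[mode])(model_input, **params)

    @mode_handler("default")
    def predict_default(self, model_input: list, **kwargs: Any) -> Any:
        raise NotImplementedError


class EchoModel(ModeDispatchSketch):
    @mode_handler("default")
    def predict_default(self, model_input, **kwargs):
        return [["default"] for _ in model_input]

    @mode_handler("slice")
    def predict_slice(self, model_input, slice_index, axis, **kwargs):
        return [[f"{axis}:{slice_index}"] for _ in model_input]
```

In this sketch a subclass only defines handlers, and callers pick one through params, e.g. params={'mode': 'slice', 'slice_index': 3, 'axis': 'axial'}.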

predict_default(model_input, **kwargs)

Default prediction on entire resources.

Override this in your subclass.

Parameters:
  • model_input (list[BaseResource])

  • kwargs (Any)

Return type:

list[list[Annotation]]

predict_image(model_input, **kwargs)
Parameters:
  • model_input (list[BaseResource])

  • kwargs (Any)

Return type:

Sequence[Sequence[ImageSegmentation | ImageClassification]]

predict_slice(model_input, slice_index, axis, **kwargs)
Parameters:
  • model_input (list[BaseResource])

  • slice_index (int)

  • axis (Literal['axial', 'sagittal', 'coronal'])

  • kwargs (Any)

Return type:

Sequence[Sequence[ImageSegmentation | ImageClassification]]
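A predict_slice() override usually has to pull one 2D plane out of a 3D volume. The helper below is a hypothetical NumPy sketch, not part of Datamint; it assumes the volume array is ordered (x, y, z) in a RAS-like orientation, so that 'sagittal' indexes x, 'coronal' indexes y and 'axial' indexes z. Check your own loader's axis order before reusing it.

```python
from typing import Literal

import numpy as np

Axis = Literal["axial", "sagittal", "coronal"]

# Assumed mapping from anatomical plane to array dimension for an (x, y, z) volume.
AXIS_TO_DIM: dict[str, int] = {"sagittal": 0, "coronal": 1, "axial": 2}


def extract_slice(volume: np.ndarray, slice_index: int, axis: Axis) -> np.ndarray:
    """Return the 2D plane at slice_index along the given anatomical axis."""
    if volume.ndim != 3:
        raise ValueError(f"expected a 3D volume, got shape {volume.shape}")
    dim = AXIS_TO_DIM[axis]
    if not 0 <= slice_index < volume.shape[dim]:
        raise IndexError(f"slice_index {slice_index} out of range for axis {axis!r}")
    # A scalar index makes np.take drop that dimension, leaving a 2D array.
    return np.take(volume, slice_index, axis=dim)
```

Keeping the plane-to-dimension mapping in one table makes it easy to adapt if your volumes are stored as (z, y, x) instead.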

predict_volume(model_input, **kwargs)
Parameters:
  • model_input (list[BaseResource])

  • kwargs (Any)

Return type:

list[list[Annotation]]

property settings: ModelSettings
task_type: ClassVar[TaskType | None] = None

Semantic task category for this model class. Subclasses should override at the class body level (e.g. task_type = TaskType.IMAGE_SEGMENTATION).

class datamint.mlflow.flavors.model.DatamintModel(settings=None, mlflow_torch_models_uri=None, mlflow_models_uri=None, torch_model=None)

Bases: BaseDatamintModel

Abstract adapter for wrapping ML models to produce Datamint annotations.

Extends BaseDatamintModel with support for loading external (“linked”) MLflow models at serve time via LinkedModelLoader. Subclasses only need to override predict_default (and optionally other predict_* hooks).

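Quick Start:

from datamint.mlflow.flavors.model import DatamintModel, ModelSettings

class MyModel(DatamintModel):
    def __init__(self):
        super().__init__(
            mlflow_models_uri={'model': 'models:/MyModel/latest'},
            settings=ModelSettings(need_gpu=True),
        )

    def predict_default(self, model_input, **kwargs):
        device = self.inference_device
        # Linked models are loaded in load_context(), keyed as in mlflow_models_uri
        model = self.get_mlflow_models()['model'].get_raw_model().to(device)
        predictions = []
        for resource in model_input:
            # Run `model` on `resource` and convert its output into a list of
            # Datamint annotations; the conversion depends on your task.
            predictions.append([])
        return predictions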

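You can also pass a pre-instantiated torch.nn.Module directly:

from datamint.mlflow.flavors.model import DatamintModel

class MyModel(DatamintModel):
    def __init__(self):
        net = MyTorchNet()  # your own torch.nn.Module subclass
        super().__init__(torch_model=net)

    def predict_default(self, model_input, **kwargs):
        # Assumed: the module passed as torch_model is exposed via get_pytorch_model()
        net = self.get_pytorch_model()
        outputs = net(preprocess(model_input))  # preprocess: your own input pipeline
        return postprocess(outputs)  # postprocess: your own conversion to list[list[Annotation]]
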
Parameters:
  • settings (ModelSettings | dict[str, Any] | None)

  • mlflow_torch_models_uri (dict[str, str] | None)

  • mlflow_models_uri (dict[str, str] | None)

  • torch_model (Module | None)

LINKED_MODELS_DIR: ClassVar[str] = 'linked_models'
get_mlflow_models()

Access loaded MLflow pyfunc models.

Return type:

dict[str, PyFuncModel]

get_mlflow_torch_models()

Access loaded MLflow PyTorch models.

Return type:

dict[str, Any]

get_pytorch_model()
Return type:

Module | None

load_context(context)

Detect the inference device and load all linked MLflow models.

Parameters:

context (PythonModelContext)

Return type:

None

property mlflow_models_uri: dict[str, str]
property mlflow_torch_models_uri: dict[str, str]
predict_slice(model_input, slice_index, axis, **kwargs)
Parameters:
  • model_input (list[BaseResource])

  • slice_index (int)

  • axis (Literal['axial', 'sagittal', 'coronal'])

  • kwargs (Any)

Return type:

Sequence[Sequence[ImageSegmentation | ImageClassification]]

class datamint.mlflow.flavors.model.ModelSettings(need_gpu=False)

Bases: object

Deployment and inference configuration for DatamintModel.

These settings are serialized with the model and used by remote MLflow servers to properly configure the runtime environment.

Parameters:

need_gpu (bool)

classmethod from_dict(data)

Create a ModelSettings instance from a dictionary, raising an error on unknown keys.

Parameters:

data (dict[str, Any])

Return type:

ModelSettings
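The strict behavior of from_dict() can be sketched with a dataclass. This is an illustrative reimplementation, not the library's code: the class name ModelSettingsSketch is hypothetical, the documentation only says that an error is raised (ValueError is an assumption), and the JSON round trip merely stands in for however the settings are actually serialized with the model.

```python
import dataclasses
import json
from dataclasses import dataclass
from typing import Any


@dataclass
class ModelSettingsSketch:
    need_gpu: bool = False  # whether a GPU is required for inference

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "ModelSettingsSketch":
        known = {f.name for f in dataclasses.fields(cls)}
        unknown = set(data) - known
        if unknown:
            # Fail loudly instead of silently ignoring typos such as 'needs_gpu'.
            raise ValueError(f"Unknown settings keys: {sorted(unknown)}")
        return cls(**data)


# Settings travel with the model (JSON here) and are rebuilt on the serving side.
payload = json.dumps(dataclasses.asdict(ModelSettingsSketch(need_gpu=True)))
restored = ModelSettingsSketch.from_dict(json.loads(payload))
```

Rejecting unknown keys means a misspelled setting surfaces when the model is loaded rather than as a silently CPU-only deployment.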

need_gpu: bool = False

Whether a GPU is required for inference.