Model Flavors
- class datamint.mlflow.flavors.model.BaseDatamintModel(settings=None)
Bases: PythonModel, ABC
Core prediction gateway that any MLflow PythonModel can build on.
Owns:
  - settings: hardware / deployment configuration.
  - Device detection (_detect_device / inference_device).
  - Prediction dispatch via PredictionRouter.
  - Pickle-safe serialization.
Use DatamintModel when you need to load linked models at serve time.
Subclasses only need to implement predict_default() (and optionally other predict_* hooks registered with @prediction_mode).
- Parameters:
settings (ModelSettings | dict[str, Any] | None)
- get_supported_modes()
Return the list of prediction modes supported by this model.
- Return type:
list[str]
- property inference_device: str
The device that will be used for inference.
Returns _inference_device if already set, then falls back to the MLFLOW_DEFAULT_PREDICTION_DEVICE environment variable, then 'cpu'.
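The fallback order described for inference_device can be sketched with a stand-in class. This is a minimal illustration, not the library's implementation: DeviceHolder is a hypothetical name, and only the _inference_device attribute and the MLFLOW_DEFAULT_PREDICTION_DEVICE variable come from the description above.

```python
import os


class DeviceHolder:
    """Hypothetical stand-in that mimics the documented fallback order."""

    def __init__(self, device=None):
        # None means "no device has been set explicitly".
        self._inference_device = device

    @property
    def inference_device(self) -> str:
        # 1. An explicitly set device wins.
        if self._inference_device is not None:
            return self._inference_device
        # 2. Otherwise use the environment variable; 3. otherwise 'cpu'.
        return os.environ.get("MLFLOW_DEFAULT_PREDICTION_DEVICE", "cpu")
```

Under these assumptions, exporting MLFLOW_DEFAULT_PREDICTION_DEVICE before serving changes the device only for models that have not already had one set.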
- load_context(context)
Detect the inference device and move any attached torch modules to it.
Override in subclasses to perform additional loading (e.g. linked models), but always call super().load_context(context) first so that inference_device is set before any model loading.
- Parameters:
context (PythonModelContext)
- Return type:
None
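The call order recommended for load_context overrides can be illustrated without the library. In this sketch BaseSketch and WithExtraLoading are hypothetical stand-ins, and the hard-coded 'cpu' replaces real device detection; only the "call super() first" pattern is taken from the text above.

```python
class BaseSketch:
    """Hypothetical stand-in for the base class; only the call order matters."""

    def __init__(self):
        self._inference_device = None

    def load_context(self, context) -> None:
        # Stand-in for real device detection.
        self._inference_device = "cpu"


class WithExtraLoading(BaseSketch):
    def __init__(self):
        super().__init__()
        self.loaded_on = None

    def load_context(self, context) -> None:
        # Call the parent first so the device is known before loading anything.
        super().load_context(context)
        # Any additional loading can now rely on the detected device.
        self.loaded_on = self._inference_device
```

If the super() call came last instead, loaded_on would still be None at the point where the extra loading runs.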
- predict(model_input, params=None)
Main prediction entry point.
Routes to the appropriate handler based on params['mode']. Do not override; implement predict_default() (or other predict_* hooks) instead.
- Parameters:
model_input (list[BaseResource])
params (dict[str, Any] | None)
- Return type:
list[list[Annotation]]
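Mode-based routing of this kind can be sketched as a small registry plus a dispatching predict method. Everything below is an assumption made for illustration: the registry, the stand-in prediction_mode decorator, the SketchModel class, and the mode names 'default' and 'slice' are hypothetical and do not reproduce PredictionRouter.

```python
from typing import Any, Callable, Optional

# Hypothetical registry: maps a mode name to the method name that handles it.
_MODE_HANDLERS: dict = {}


def prediction_mode(name: str) -> Callable:
    """Stand-in decorator that registers a method as the handler for `name`."""
    def register(func: Callable) -> Callable:
        _MODE_HANDLERS[name] = func.__name__
        return func
    return register


class SketchModel:
    def predict(self, model_input: list, params: Optional[dict] = None) -> list:
        # Copy so the caller's dict is not mutated, then pull out the mode.
        params = dict(params or {})
        mode = params.pop("mode", "default")
        if mode not in _MODE_HANDLERS:
            raise ValueError(f"Unsupported prediction mode: {mode!r}")
        handler = getattr(self, _MODE_HANDLERS[mode])
        # The remaining params are forwarded as keyword arguments.
        return handler(model_input, **params)

    @prediction_mode("default")
    def predict_default(self, model_input: list, **kwargs: Any) -> list:
        return [[f"default:{item}"] for item in model_input]

    @prediction_mode("slice")
    def predict_slice(self, model_input: list, slice_index: int, axis: str, **kwargs: Any) -> list:
        return [[f"slice:{item}:{axis}:{slice_index}"] for item in model_input]
```

With this layout a subclass adds behavior by defining or overriding a hook, while predict itself stays untouched, which matches the "do not override" guidance.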
- predict_default(model_input, **kwargs)
Default prediction on entire resources.
Override this in your subclass.
- Parameters:
model_input (list[BaseResource])
kwargs (Any)
- Return type:
list[list[Annotation]]
- predict_image(model_input, **kwargs)
- Parameters:
model_input (list[BaseResource])
kwargs (Any)
- Return type:
Sequence[Sequence[ImageSegmentation | ImageClassification]]
- predict_slice(model_input, slice_index, axis, **kwargs)
- Parameters:
model_input (list[BaseResource])
slice_index (int)
axis (Literal['axial', 'sagittal', 'coronal'])
kwargs (Any)
- Return type:
Sequence[Sequence[ImageSegmentation | ImageClassification]]
- predict_volume(model_input, **kwargs)
- Parameters:
model_input (list[BaseResource])
kwargs (Any)
- Return type:
list[list[Annotation]]
- property settings: ModelSettings
- task_type: ClassVar[TaskType | None] = None
Semantic task category for this model class. Subclasses should override it at the class body level (e.g. task_type = TaskType.IMAGE_SEGMENTATION).
- class datamint.mlflow.flavors.model.DatamintModel(settings=None, mlflow_torch_models_uri=None, mlflow_models_uri=None, torch_model=None)
Bases: BaseDatamintModel
Abstract adapter for wrapping ML models to produce Datamint annotations.
Extends BaseDatamintModel with support for loading external ("linked") MLflow models at serve time via LinkedModelLoader. Subclasses only need to override predict_default (and optionally other predict_* hooks).
Quick Start:

    from datamint.mlflow.flavors.model import DatamintModel, ModelSettings

    class MyModel(DatamintModel):
        def __init__(self):
            super().__init__(
                mlflow_models_uri={'model': 'models:/MyModel/latest'},
                settings=ModelSettings(need_gpu=True),
            )

        def predict_default(self, model_input, **kwargs):
            device = self.inference_device
            model = self.get_mlflow_models()['model'].get_raw_model().to(device)
            # Run `model` on model_input and build `predictions` (omitted here).
            return predictions

You can also pass pre-instantiated torch.nn.Module objects directly:

    class MyModel(DatamintModel):
        def __init__(self):
            net = MyTorchNet()  # placeholder for your own torch.nn.Module
            super().__init__(torch_model=net)

        def predict_default(self, model_input, **kwargs):
            net = self.get_mlflow_torch_models()['net']
            # `preprocess` is a placeholder for your own input preparation.
            return net(preprocess(model_input))

- Parameters:
settings (ModelSettings | dict[str, Any] | None)
mlflow_torch_models_uri (dict[str, str] | None)
mlflow_models_uri (dict[str, str] | None)
torch_model (Module | None)
- LINKED_MODELS_DIR: ClassVar[str] = 'linked_models'
- get_mlflow_models()
Access loaded MLflow pyfunc models.
- Return type:
dict[str, PyFuncModel]
- get_mlflow_torch_models()
Access loaded MLflow PyTorch models.
- Return type:
dict[str, Any]
- get_pytorch_model()
- Return type:
Module | None
- load_context(context)
Detect device and load all linked MLflow models.
- Parameters:
context (PythonModelContext)
- Return type:
None
- property mlflow_models_uri: dict[str, str]
- property mlflow_torch_models_uri: dict[str, str]
- predict_slice(model_input, slice_index, axis, **kwargs)
- Parameters:
model_input (list[BaseResource])
slice_index (int)
axis (Literal['axial', 'sagittal', 'coronal'])
kwargs (Any)
- Return type:
Sequence[Sequence[ImageSegmentation | ImageClassification]]
- class datamint.mlflow.flavors.model.ModelSettings(need_gpu=False)
Bases: object
Deployment and inference configuration for DatamintModel.
These settings are serialized with the model and used by remote MLflow servers to properly configure the runtime environment.
- Parameters:
need_gpu (bool)
- classmethod from_dict(data)
Create a config from a dictionary, raising an error on unknown keys.
- Parameters:
data (dict[str, Any])
- Return type:
ModelSettings
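A strict from_dict of this kind is commonly written by comparing the incoming keys with the declared fields. The sketch below is an assumption about how such a check could look, not the library's code: SettingsSketch is a hypothetical stand-in carrying the single documented field, and the choice of ValueError is a guess, since the source does not name the exception type.

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class SettingsSketch:
    """Hypothetical stand-in with the single documented field."""

    need_gpu: bool = False

    @classmethod
    def from_dict(cls, data: dict) -> "SettingsSketch":
        # Collect the declared field names and reject anything else.
        known = {f.name for f in fields(cls)}
        unknown = set(data) - known
        if unknown:
            # The exception type is an assumption made for this sketch.
            raise ValueError(f"Unknown settings keys: {sorted(unknown)}")
        return cls(**data)
```

Rejecting unknown keys turns a misspelled option such as need_gpus into an immediate error instead of a silently ignored setting.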
- need_gpu: bool = False
Whether a GPU is required for inference.