Checkpointing
- class datamint.mlflow.lightning.callbacks.modelcheckpoint.MLFlowDatamintModelCheckpoint(*args, register_model_name=None, register_model_on='test', code_paths=None, log_model_at_end_only=True, additional_metadata=None, extra_pip_requirements=None, log_model_metrics=True, **kwargs)
Bases: _BaseMLFlowModelCheckpoint
MLflow model checkpoint for BaseDatamintModel-based Lightning modules.
Logs models using the datamint custom flavor (which wraps mlflow.pyfunc). Signature inference is delegated to the datamint flavor via predict_type_hints, so no forward-wrapping is performed.
- Parameters:
register_model_name (str | None)
register_model_on (Literal['train', 'val', 'test', 'predict'])
code_paths (list[str] | None)
log_model_at_end_only (bool)
additional_metadata (dict[str, Any] | None)
extra_pip_requirements (list[str] | None)
log_model_metrics (bool)
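Signature inference for this callback is delegated to the datamint flavor via predict_type_hints. The sketch below only illustrates the general idea of reading input and output types from the annotations on a predict method, using the standard library's typing.get_type_hints. ToyPredictor and describe_predict_hints are hypothetical names; this is not the datamint flavor's implementation and it does not build a real MLflow signature.

```python
from typing import Any, get_type_hints


class ToyPredictor:
    """Hypothetical stand-in for a model whose predict() carries type hints."""

    def predict(self, model_input: list[float]) -> list[int]:
        # Threshold each score at 0.5.
        return [int(x >= 0.5) for x in model_input]


def describe_predict_hints(model: object) -> dict[str, Any]:
    """Collect the annotated input and return types of model.predict.

    Illustrative only: reads annotations with typing.get_type_hints
    rather than producing an MLflow ModelSignature.
    """
    hints = get_type_hints(type(model).predict)
    return {
        "input": hints.get("model_input"),
        "output": hints.get("return"),
    }


print(describe_predict_hints(ToyPredictor()))
```

Because the types come from annotations rather than from running the model, no call to forward has to be intercepted, which matches the "no forward-wrapping" note above.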
- log_model_to_mlflow(model, run_id)
Log the model to MLflow using the datamint flavor.
- Parameters:
model (Module | LightningModule | BaseDatamintModel)
run_id (str | MLFlowLogger)
- Return type:
None
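Because run_id accepts either a run-ID string or an MLFlowLogger, a caller has to normalize it to a plain string at some point. A minimal sketch of that normalization follows, assuming the logger exposes the active run through a run_id attribute as Lightning's MLFlowLogger.run_id property does. FakeMLFlowLogger and resolve_run_id are hypothetical and stand in for whatever the callback does internally.

```python
from typing import Union


class FakeMLFlowLogger:
    """Hypothetical stand-in for Lightning's MLFlowLogger (exposes run_id)."""

    def __init__(self, run_id: str) -> None:
        self.run_id = run_id


def resolve_run_id(run_id: Union[str, FakeMLFlowLogger]) -> str:
    """Return a plain run-ID string from either accepted argument type."""
    if isinstance(run_id, str):
        return run_id
    # Assumption: logger-like objects carry the active run ID in .run_id.
    resolved = getattr(run_id, "run_id", None)
    if not isinstance(resolved, str):
        raise TypeError(f"cannot resolve a run ID from {type(run_id).__name__}")
    return resolved


print(resolve_run_id("abc123"))
print(resolve_run_id(FakeMLFlowLogger("def456")))
```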
- datamint.mlflow.lightning.callbacks.modelcheckpoint.MLFlowModelCheckpoint
alias of MLFlowPyTorchModelCheckpoint
- class datamint.mlflow.lightning.callbacks.modelcheckpoint.MLFlowPyTorchModelCheckpoint(*args, register_model_name=None, register_model_on='test', code_paths=None, log_model_at_end_only=True, additional_metadata=None, extra_pip_requirements=None, log_model_metrics=True, **kwargs)
Bases: _BaseMLFlowModelCheckpoint
MLflow model checkpoint for standard PyTorch Lightning modules.
Logs models using mlflow.pytorch.log_model() and infers the MLflow model signature by intercepting the first call to pl_module.forward.
- Parameters:
register_model_name (str | None)
register_model_on (Literal['train', 'val', 'test', 'predict'])
code_paths (list[str] | None)
log_model_at_end_only (bool)
additional_metadata (dict[str, Any] | None)
extra_pip_requirements (list[str] | None)
log_model_metrics (bool)
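The PyTorch variant infers the signature by intercepting the first call to pl_module.forward. Below is a framework-free sketch of that interception pattern, assuming only that forward is a method on the module instance: the wrapper records the first call's arguments and output, then puts the original method back so later calls run unwrapped. ToyModule and capture_first_forward are hypothetical names; a real implementation would then hand the captured values to something like mlflow.models.infer_signature, a step omitted here.

```python
import functools
from typing import Any, Callable


class ToyModule:
    """Hypothetical stand-in for a LightningModule with a forward method."""

    def forward(self, x: list[float]) -> list[float]:
        return [2 * v for v in x]


def capture_first_forward(module: Any) -> dict[str, Any]:
    """Wrap module.forward so the first call's inputs and output are recorded."""
    captured: dict[str, Any] = {}
    original: Callable[..., Any] = module.forward  # bound method

    @functools.wraps(original)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        output = original(*args, **kwargs)
        if not captured:
            captured["args"] = args
            captured["kwargs"] = kwargs
            captured["output"] = output
            # Restore the original bound method after the first call.
            module.forward = original
        return output

    # Assigning on the instance shadows the class-level forward method.
    module.forward = wrapper
    return captured


model = ToyModule()
record = capture_first_forward(model)
model.forward([1.0, 2.0])  # intercepted and recorded
model.forward([3.0])       # runs the restored, unwrapped method
print(record)
```

Restoring the method right after the first call keeps the overhead to a single wrapped invocation, which is why this pattern suits one-off signature inference during training.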
- log_model_to_mlflow(model, run_id)
Log the model to MLflow using the pytorch flavor.
- Parameters:
model (Module | LightningModule | BaseDatamintModel)
run_id (str | MLFlowLogger)