DatamintDataModule
DatamintDataModule — LightningDataModule wrapper for Datamint datasets.
Wraps any DatamintBaseDataset subclass and
provides train_dataloader, val_dataloader, test_dataloader, and
predict_dataloader for use with a Lightning Trainer.
- class datamint.lightning.datamodule.DatamintDataModule(dataset, batch_size=32, train_batch_size=None, val_batch_size=None, test_batch_size=None, num_workers=0, pin_memory=True, shuffle_train=True, drop_last_train=False, split=True, split_seed=None, split_as_of_timestamp=None, use_server_splits=None, train_transform=None, eval_transform=None)
Bases: LightningDataModule
A LightningDataModule that wraps a DatamintBaseDataset.
The dataset must already be fully constructed (project loaded, filters applied). Splitting is delegated to split(). Stage-specific transforms are applied to each split after splitting.
- Parameters:
dataset (DatamintBaseDataset) – A fully initialised Datamint dataset (without transforms; those are applied per-split via train_transform / eval_transform).
batch_size (int) – Default batch size for every stage.
train_batch_size (int | None) – Override batch size for training.
val_batch_size (int | None) – Override batch size for validation.
test_batch_size (int | None) – Override batch size for testing.
num_workers (int) – Number of DataLoader workers.
pin_memory (bool) – Whether to pin memory in DataLoaders.
shuffle_train (bool) – Shuffle the training dataloader.
drop_last_train (bool) – Drop the last incomplete training batch.
split (dict[str, float] | bool | None) – Split ratios forwarded to DatamintBaseDataset.split() (e.g. {'train': 0.7, 'val': 0.15, 'test': 0.15}). When None, the full dataset is used for every stage.
split_seed (int | None) – Random seed for reproducible local splits.
split_as_of_timestamp (str | None) – Historical timestamp forwarded to DatamintBaseDataset.split() when reusing project-scoped split assignments.
use_server_splits (bool | None) – If True, use server-side split:* tags instead of local random splitting.
train_transform (Callable | None) – Albumentations transform applied only to the training split (e.g. augmentations). Calls set_transform() on the train split after setup() resolves the splits.
eval_transform (Callable | None) – Albumentations transform applied to the validation and test splits (typically resize/normalise only, no augmentation).
Example:
    import albumentations as A
    import lightning as L

    # Albumentations names the flip transform HorizontalFlip.
    train_tfm = A.Compose([A.HorizontalFlip(), A.Normalize()])
    eval_tfm = A.Compose([A.Normalize()])

    dataset = ImageDataset(project='my_project', ...)
    dm = DatamintDataModule(
        dataset,
        batch_size=8,
        split={'train': 0.8, 'val': 0.1, 'test': 0.1},
        split_seed=42,
        train_transform=train_tfm,
        eval_transform=eval_tfm,
    )

    # prepare_data() / setup() fetch data and make attributes available:
    trainer = L.Trainer(...)
    trainer.fit(model, datamodule=dm)
    trainer.test(datamodule=dm)
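The relationship between batch_size and the per-stage overrides can be sketched as below. This is a minimal illustration, assuming that an override of None simply falls back to the default; resolve_batch_size is a hypothetical helper and not part of the Datamint API.

```python
from typing import Optional


def resolve_batch_size(default: int, override: Optional[int]) -> int:
    """Return the stage override when given, else the default (assumed rule)."""
    return default if override is None else override


# Mirrors batch_size=32 with only train_batch_size overridden.
batch_size = 32
stage_overrides = {"train": 64, "val": None, "test": None}

resolved = {
    stage: resolve_batch_size(batch_size, override)
    for stage, override in stage_overrides.items()
}
print(resolved)  # {'train': 64, 'val': 32, 'test': 32}
```

Under this assumption, only the stages that need a different size have to be specified; every other stage inherits batch_size.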
- get_dataset_split(split)
Return the Datamint dataset for the given split. Falls back to the full dataset when the requested split is not available.
- Parameters:
split (str)
- Return type:
DatamintBaseDataset | None
- get_mlflow_dataset()
Return an MLflow dataset for the full dataset (without split context).
- get_mlflow_dataset_split(split)
Return a DatamintMLflowDataset for the given split.
Delegates to the corresponding split dataset’s build_mlflow_dataset(). Falls back to the full dataset when the requested split is not available.
- Parameters:
split (str) – One of 'train', 'val', or 'test'.
- Return type:
DatamintMLflowDataset | None
- property has_val_split: bool
Return True if a validation split is expected to be available.
After setup() has run, reflects the actual resolved state. Before that, infers from the split configuration.
- prepare_data()
Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.
Warning
DO NOT set state to the model (use setup instead) since this is NOT called on every device.
Example:
    def prepare_data(self):
        # good
        download_data()
        tokenize()
        etc()

        # bad
        self.split = data_split
        self.some_state = some_other_state()
In a distributed environment, prepare_data can be called in two ways (using prepare_data_per_node):
  1. Once per node. This is the default and is only called on LOCAL_RANK=0.
  2. Once in total. Only called on GLOBAL_RANK=0.
Example:
    # DEFAULT
    # called once per node on LOCAL_RANK=0 of that node
    class LitDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self.prepare_data_per_node = True


    # call on GLOBAL_RANK=0 (great for shared file systems)
    class LitDataModule(LightningDataModule):
        def __init__(self):
            super().__init__()
            self.prepare_data_per_node = False
This is called before requesting the dataloaders:
    model.prepare_data()
    initialize_distributed()
    model.setup(stage)
    model.train_dataloader()
    model.val_dataloader()
    model.test_dataloader()
    model.predict_dataloader()
- Return type:
None
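The hook order described above, and the rule that state belongs in setup() rather than prepare_data(), can be illustrated with a toy stand-in. This sketch does not use Lightning; ToyDataModule and run_fit are hypothetical names that only mimic the documented call sequence.

```python
class ToyDataModule:
    """Stand-in for a data module; records which hooks ran, in order."""

    def __init__(self):
        self.calls = []          # call log kept only for this demonstration
        self.train_items = None  # real state, assigned in setup()

    def prepare_data(self):
        # Download-style work only; no dataset state is assigned here.
        self.calls.append("prepare_data")

    def setup(self, stage=None):
        # Runs on every process, so per-process state is created here.
        self.calls.append(f"setup({stage})")
        self.train_items = [0, 1, 2, 3]

    def train_dataloader(self):
        self.calls.append("train_dataloader")
        return iter(self.train_items)


def run_fit(dm):
    """Hypothetical driver that mimics the documented order for fit."""
    dm.prepare_data()
    dm.setup("fit")
    return list(dm.train_dataloader())


dm = ToyDataModule()
items = run_fit(dm)
print(dm.calls)  # ['prepare_data', 'setup(fit)', 'train_dataloader']
```

Because train_items is created in setup(), it would exist on every process even when prepare_data() runs on a single rank.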
- setup(stage=None)
Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
stage (str | None) – either 'fit', 'validate', 'test', or 'predict'
- Return type:
None
Example:
    class LitModel(...):
        def __init__(self):
            self.l1 = None

        def prepare_data(self):
            download_data()
            tokenize()

            # don't do this: prepare_data() must not assign state
            self.something = something_else

        def setup(self, stage):
            data = load_data(...)
            self.l1 = nn.Linear(28, data.num_classes)
- test_dataloader()
An iterable or collection of iterables specifying test samples.
For more information about multiple dataloaders, see this section.
For data processing use the following pattern:
  - download in prepare_data()
  - process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
  - test()
  - prepare_data()
  - setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a test dataset and a test_step(), you don’t need to implement this method.
- Return type:
DataLoader
- train_dataloader()
An iterable or collection of iterables specifying training samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.
For data processing use the following pattern:
  - download in prepare_data()
  - process and split in setup()
However, the above are only necessary for distributed processing.
Warning
Do not assign state in prepare_data().
  - fit()
  - prepare_data()
  - setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
- Return type:
DataLoader
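The effect of the drop_last_train constructor option on the training batches can be pictured with plain lists. This is a sketch of the general drop-last batching idea, not the DataLoader implementation; make_batches is a hypothetical helper.

```python
def make_batches(items, batch_size, drop_last=False):
    """Chunk items into batches, optionally dropping a short final batch."""
    batches = [items[i:i + batch_size] for i in range(0, len(items), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the incomplete trailing batch
    return batches


samples = list(range(10))
kept = make_batches(samples, batch_size=4)
dropped = make_batches(samples, batch_size=4, drop_last=True)

print([len(b) for b in kept])     # [4, 4, 2]
print([len(b) for b in dropped])  # [4, 4]
```

Dropping the short batch keeps every training step at the same batch size, at the cost of skipping a few samples per epoch.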
- val_dataloader()
An iterable or collection of iterables specifying validation samples.
For more information about multiple dataloaders, see this section.
The dataloader you return will not be reloaded unless you set the Trainer argument reload_dataloaders_every_n_epochs to a positive integer.
It’s recommended that all data downloads and preparation happen in prepare_data().
  - fit()
  - validate()
  - prepare_data()
  - setup()
Note
Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.
Note
If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.