DatamintDataModule

DatamintDataModule — LightningDataModule wrapper for Datamint datasets.

Wraps any DatamintBaseDataset subclass and provides train_dataloader, val_dataloader, test_dataloader, and predict_dataloader for use with a Lightning Trainer.

class datamint.lightning.datamodule.DatamintDataModule(dataset, batch_size=32, train_batch_size=None, val_batch_size=None, test_batch_size=None, num_workers=0, pin_memory=True, shuffle_train=True, drop_last_train=False, split=True, split_seed=None, split_as_of_timestamp=None, use_server_splits=None, train_transform=None, eval_transform=None)

Bases: LightningDataModule

A LightningDataModule that wraps a DatamintBaseDataset.

The dataset must already be fully constructed (project loaded, filters applied). Splitting is delegated to DatamintBaseDataset.split(). Stage-specific transforms are applied to each split after splitting.
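The split-then-transform order described above can be pictured with plain Python. This is a minimal sketch under stated assumptions: `SplitView` is a hypothetical wrapper, not a Datamint class, and the lambdas stand in for Albumentations pipelines. It only shows why attaching a transform per split keeps training augmentations out of the evaluation splits.

```python
from typing import Callable, Optional, Sequence


class SplitView:
    """Hypothetical view over a subset of a base dataset with its own transform."""

    def __init__(self, base: Sequence, indices: Sequence[int]):
        self.base = base
        self.indices = list(indices)
        self.transform: Optional[Callable] = None

    def set_transform(self, transform: Optional[Callable]) -> None:
        # Each split keeps its own transform; the base dataset stays untouched.
        self.transform = transform

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int):
        item = self.base[self.indices[i]]
        return self.transform(item) if self.transform is not None else item


base = [1, 2, 3, 4, 5]
train = SplitView(base, [0, 1, 2])
val = SplitView(base, [3, 4])
train.set_transform(lambda x: x * 10)  # stand-in for an augmentation pipeline
val.set_transform(lambda x: x)          # stand-in for an eval-only pipeline

print([train[i] for i in range(len(train))])  # [10, 20, 30]
print([val[i] for i in range(len(val))])      # [4, 5]
```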

Parameters:
  • dataset (DatamintBaseDataset) – A fully initialised Datamint dataset (without transforms; those are applied per-split via train_transform / eval_transform).

  • batch_size (int) – Default batch size for every stage.

  • train_batch_size (int | None) – Override batch size for training.

  • val_batch_size (int | None) – Override batch size for validation.

  • test_batch_size (int | None) – Override batch size for testing.

  • num_workers (int) – Number of DataLoader workers.

  • pin_memory (bool) – Whether to pin memory in DataLoaders.

  • shuffle_train (bool) – Shuffle the training dataloader.

  • drop_last_train (bool) – Drop last incomplete training batch.

  • split (dict[str, float] | bool | None) – Split ratios forwarded to DatamintBaseDataset.split() (e.g. {'train': 0.7, 'val': 0.15, 'test': 0.15}). When None, the full dataset is used for every stage.

  • split_seed (int | None) – Random seed for reproducible local splits.

  • split_as_of_timestamp (str | None) – Historical timestamp forwarded to DatamintBaseDataset.split() when reusing project-scoped split assignments.

  • use_server_splits (bool | None) – If True, use server-side split:* tags instead of local random splitting.

  • train_transform (Callable | None) – Albumentations transform applied only to the training split (e.g. augmentations). It is attached by calling set_transform() on the train split after setup() resolves the splits.

  • eval_transform (Callable | None) – Albumentations transform applied to the validation and test splits (typically resize/normalise only, no augmentation).
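To make the split and split_seed semantics concrete, here is a minimal sketch of a seeded, ratio-based index split. It is an assumption about the general technique, not the actual DatamintBaseDataset.split() implementation, which may round sizes differently or use server-side split:* tags when use_server_splits=True. `split_indices` is a hypothetical helper.

```python
import random
from typing import Dict, List, Optional


def split_indices(n: int, ratios: Dict[str, float],
                  seed: Optional[int] = None) -> Dict[str, List[int]]:
    """Hypothetical helper: shuffle range(n) with a seeded RNG and cut it by ratios."""
    if abs(sum(ratios.values()) - 1.0) > 1e-6:
        raise ValueError("split ratios must sum to 1")
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # same seed -> same permutation

    names = list(ratios)
    result: Dict[str, List[int]] = {}
    start = 0
    for k, name in enumerate(names):
        # The last split takes the remainder so every index is assigned exactly once.
        end = n if k == len(names) - 1 else start + round(ratios[name] * n)
        result[name] = indices[start:end]
        start = end
    return result


parts = split_indices(100, {'train': 0.8, 'val': 0.1, 'test': 0.1}, seed=42)
print({k: len(v) for k, v in parts.items()})  # {'train': 80, 'val': 10, 'test': 10}
```

Reusing the same seed reproduces the same assignment, which is the property split_seed is documented to provide.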

Example:

import albumentations as A
import lightning as L
from datamint.lightning.datamodule import DatamintDataModule

# HorizontalFlip is the Albumentations flip transform (RandomHorizontalFlip is torchvision's name)
train_tfm = A.Compose([A.HorizontalFlip(p=0.5), A.Normalize()])
eval_tfm  = A.Compose([A.Normalize()])

dataset = ImageDataset(project='my_project', ...)
dm = DatamintDataModule(
    dataset,
    batch_size=8,
    split={'train': 0.8, 'val': 0.1, 'test': 0.1},
    split_seed=42,
    train_transform=train_tfm,
    eval_transform=eval_tfm,
)

# The Trainer calls prepare_data() and setup(), which fetch the data and resolve the splits.
# `model` is your LightningModule.
trainer = L.Trainer(...)
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)

get_dataset_split(split)

Return the Datamint dataset for the given split. Falls back to the full dataset when the requested split is not available.

Parameters:

split (str) – One of 'train', 'val', or 'test'.

Return type:

DatamintBaseDataset | None
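The fallback rule amounts to a lookup with a default. This is a minimal sketch assuming the resolved splits are kept in a mapping; `resolve_split` and its arguments are hypothetical, not the module's actual attributes. In this sketch a None result only occurs when neither the split nor the full dataset is available, which is one plausible reading of the `| None` return type.

```python
from typing import Mapping, Optional, TypeVar

T = TypeVar("T")


def resolve_split(splits: Mapping[str, Optional[T]], full: Optional[T],
                  name: str) -> Optional[T]:
    """Hypothetical: return the named split, else fall back to the full dataset."""
    ds = splits.get(name)
    # Compare with None explicitly: an empty dataset is falsy but still a valid split.
    return ds if ds is not None else full


splits = {'train': 'train-ds', 'val': 'val-ds'}
print(resolve_split(splits, 'full-ds', 'val'))   # val-ds
print(resolve_split(splits, 'full-ds', 'test'))  # full-ds (no test split)
```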

get_mlflow_dataset()

Return an MLflow dataset for the full dataset (without split context).

get_mlflow_dataset_split(split)

Return a DatamintMLflowDataset for the given split.

Delegates to the corresponding split dataset’s build_mlflow_dataset(). Falls back to the full dataset when the requested split is not available.

Parameters:

split (str) – One of 'train', 'val', or 'test'.

Return type:

DatamintMLflowDataset | None

property has_val_split: bool

Return True if a validation split is expected to be available.

After setup() has run, reflects the actual resolved state. Before that, infers from the split configuration.
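The two-phase behaviour of has_val_split (infer from configuration before setup(), report the resolved state after) can be sketched as follows. `ValSplitInfo` is a hypothetical model, not the real class. It only infers from the dict form of split and treats every other value as "no validation split"; how the real property handles split=True or server-side splits is not stated here.

```python
from typing import Any, Dict, Optional, Union


class ValSplitInfo:
    """Hypothetical model of the has_val_split rule (not the real DatamintDataModule)."""

    def __init__(self, split: Union[Dict[str, float], bool, None]):
        self.split = split
        self._is_setup = False
        self._val_dataset: Optional[Any] = None

    def setup(self, val_dataset: Optional[Any] = None) -> None:
        # After setup, the resolved validation dataset is authoritative.
        self._val_dataset = val_dataset
        self._is_setup = True

    @property
    def has_val_split(self) -> bool:
        if self._is_setup:
            return self._val_dataset is not None
        # Before setup: infer from configuration (dict ratios only, by assumption).
        if isinstance(self.split, dict):
            return self.split.get('val', 0) > 0
        return False


info = ValSplitInfo({'train': 0.8, 'val': 0.2})
print(info.has_val_split)  # True (inferred from the ratios)
info.setup(val_dataset=None)
print(info.has_val_split)  # False (resolved state wins)
```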

prepare_data()

Use this to download and prepare data. Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures this method is called only within a single process, so you can safely add your downloading logic within.

Warning

Do not assign state here (use setup() instead), since this method is NOT called on every device.

Example:

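def prepare_data(self):
    # good: download or write data to disk without assigning attributes
    download_data()
    tokenize()

    # bad: attributes assigned here are not available on other processes
    self.split = data_split
    self.some_state = some_other_state()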

In a distributed environment, prepare_data can be called in two ways, controlled by prepare_data_per_node:

  1. Once per node. This is the default and is only called on LOCAL_RANK=0.

  2. Once in total. Only called on GLOBAL_RANK=0.
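The two modes above reduce to a single rank check. This is an illustrative sketch, not Lightning's internal code; `runs_prepare_data` is a hypothetical name, and the ranks correspond to the LOCAL_RANK and GLOBAL_RANK values mentioned in the list.

```python
def runs_prepare_data(prepare_data_per_node: bool, local_rank: int, global_rank: int) -> bool:
    """Hypothetical helper: does this process call prepare_data()?"""
    if prepare_data_per_node:
        return local_rank == 0   # once per node
    return global_rank == 0      # once in total


# Two nodes with two processes each, as (global_rank, local_rank) pairs.
procs = [(0, 0), (1, 1), (2, 0), (3, 1)]
print([g for g, l in procs if runs_prepare_data(True, l, g)])   # [0, 2]
print([g for g, l in procs if runs_prepare_data(False, l, g)])  # [0]
```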

Example:

from lightning.pytorch import LightningDataModule

# DEFAULT
# called once per node on LOCAL_RANK=0 of that node
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True


# call on GLOBAL_RANK=0 (great for shared file systems)
class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = False

This is called before requesting the dataloaders:

model.prepare_data()
initialize_distributed()
model.setup(stage)
model.train_dataloader()
model.val_dataloader()
model.test_dataloader()
model.predict_dataloader()

Return type:

None
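The hook sequence listed above can be traced with a stub that records each call. This is a simplified, single-process sketch: `RecordingDataModule` and `simulate_fit` are hypothetical, the distributed initialization step is omitted, and the real Trainer may request dataloaders in a different order (for example, during a sanity check).

```python
from typing import List, Optional


class RecordingDataModule:
    """Hypothetical stub that records which hooks run, in order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def prepare_data(self) -> None:
        self.calls.append("prepare_data")

    def setup(self, stage: Optional[str] = None) -> None:
        self.calls.append(f"setup({stage})")

    def train_dataloader(self):
        self.calls.append("train_dataloader")
        return []

    def val_dataloader(self):
        self.calls.append("val_dataloader")
        return []


def simulate_fit(dm: RecordingDataModule) -> None:
    # Follows the order listed above, not Lightning's actual fit loop.
    dm.prepare_data()
    dm.setup("fit")
    dm.train_dataloader()
    dm.val_dataloader()


dm = RecordingDataModule()
simulate_fit(dm)
print(dm.calls)  # ['prepare_data', 'setup(fit)', 'train_dataloader', 'val_dataloader']
```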

setup(stage=None)

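Called at the beginning of fit (train + validate), validate, test, or predict. This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP. In DatamintDataModule, this is where the dataset splits are resolved and train_transform / eval_transform are applied to them.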

Parameters:

stage (str | None) – Either 'fit', 'validate', 'test', or 'predict'.

Return type:

None

Example:

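import torch.nn as nn
from lightning.pytorch import LightningModule


# download_data, tokenize, load_data and something_else are placeholders
class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.l1 = None

    def prepare_data(self):
        download_data()
        tokenize()

        # don't do this: state assigned here is not set on every process
        self.something = something_else

    def setup(self, stage):
        data = load_data(...)
        self.l1 = nn.Linear(28, data.num_classes)

test_dataloader()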

An iterable or collection of iterables specifying test samples.

For more information about multiple dataloaders, see the multiple-dataloaders section of the Lightning documentation.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

See also:

  • Trainer.test()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a test dataset and a test_step(), you don’t need to implement this method.

Return type:

DataLoader
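Given the batch_size, train_batch_size, val_batch_size, and test_batch_size parameters, each dataloader presumably uses the stage-specific override when set and the shared default otherwise. A minimal sketch under that assumption; `resolve_batch_size` is hypothetical.

```python
from typing import Dict, Optional


def resolve_batch_size(stage: str, batch_size: int = 32,
                       overrides: Optional[Dict[str, Optional[int]]] = None) -> int:
    """Hypothetical: per-stage override if set, otherwise the shared default."""
    override = (overrides or {}).get(stage)
    return override if override is not None else batch_size


overrides = {'train': 16, 'val': None, 'test': 64}
print(resolve_batch_size('train', 8, overrides))  # 16
print(resolve_batch_size('val', 8, overrides))    # 8
print(resolve_batch_size('test', 8, overrides))   # 64
```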

train_dataloader()

An iterable or collection of iterables specifying training samples.

For more information about multiple dataloaders, see the multiple-dataloaders section of the Lightning documentation.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs argument to a positive integer.

For data processing use the following pattern:

  • download in prepare_data()

  • process and split in setup()

However, the above are only necessary for distributed processing.

Warning

Do not assign state in prepare_data().

See also:

  • Trainer.fit()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Return type:

DataLoader
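Putting the training-related constructor parameters together, the training loader is plausibly configured as below. This sketch only builds the keyword arguments; batch_size, shuffle, drop_last, num_workers, and pin_memory are real torch.utils.data.DataLoader arguments, but the mapping from DatamintDataModule parameters is an assumption, and `train_loader_kwargs` is hypothetical.

```python
from typing import Any, Dict, Optional


def train_loader_kwargs(batch_size: int = 32,
                        train_batch_size: Optional[int] = None,
                        num_workers: int = 0,
                        pin_memory: bool = True,
                        shuffle_train: bool = True,
                        drop_last_train: bool = False) -> Dict[str, Any]:
    """Hypothetical mapping from DatamintDataModule arguments to DataLoader kwargs."""
    return {
        # Training override wins over the shared default when provided.
        "batch_size": train_batch_size if train_batch_size is not None else batch_size,
        "shuffle": shuffle_train,
        "drop_last": drop_last_train,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
    }


kwargs = train_loader_kwargs(batch_size=8, train_batch_size=16, num_workers=4)
print(kwargs)
# {'batch_size': 16, 'shuffle': True, 'drop_last': False, 'num_workers': 4, 'pin_memory': True}
```

With PyTorch installed, these would be passed as DataLoader(train_split, **kwargs); drop_last=True discards a final batch smaller than batch_size, matching the description of drop_last_train.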

val_dataloader()

An iterable or collection of iterables specifying validation samples.

For more information about multiple dataloaders, see the multiple-dataloaders section of the Lightning documentation.

The dataloader you return will not be reloaded unless you set the Trainer's reload_dataloaders_every_n_epochs argument to a positive integer.

It’s recommended that all data downloads and preparation happen in prepare_data().

See also:

  • Trainer.fit()

  • Trainer.validate()

  • prepare_data()

  • setup()

Note

Lightning tries to add the correct sampler for distributed and arbitrary hardware. There is no need to set it yourself.

Note

If you don’t need a validation dataset and a validation_step(), you don’t need to implement this method.