combined_loader
Classes
CombinedLoader | Combines different iterables under specific sampling modes.
- class lightning.pytorch.utilities.combined_loader.CombinedLoader(iterables, mode='min_size')
Bases: Iterable

Combines different iterables under specific sampling modes.
- Parameters:
  - iterables (Any) – the iterable or collection of iterables to sample from.
  - mode (Literal['min_size', 'max_size_cycle', 'max_size', 'sequential']) – the mode to use. The following modes are supported:
    - min_size: stops after the shortest iterable (the one with the lowest number of items) is done.
    - max_size_cycle: stops after the longest iterable (the one with most items) is done, while cycling through the rest of the iterables.
    - max_size: stops after the longest iterable (the one with most items) is done, while returning None for the exhausted iterables.
    - sequential: completely consumes each iterable sequentially, and returns a triplet (data, idx, iterable_idx)
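The four modes can be illustrated with a small pure-Python sketch built on itertools. This is not Lightning's implementation: the function name `combine` is hypothetical, it assumes a dict of finite, sized iterables, and it materializes every iterable into a list, which the real class is not assumed to do. It only mirrors the stopping and fill behavior described above.

```python
from itertools import cycle, islice, zip_longest


def combine(iterables, mode="min_size"):
    """Yield (batch, batch_idx, iterable_idx) triplets from a dict of iterables.

    Hypothetical helper for illustration only; not part of Lightning.
    """
    keys = list(iterables)
    # Assumption: every iterable is finite, so it can be copied into a list.
    values = [list(v) for v in iterables.values()]

    if mode == "sequential":
        # Consume each iterable fully before moving on to the next one.
        for iterable_idx, items in enumerate(values):
            for batch_idx, item in enumerate(items):
                yield item, batch_idx, iterable_idx
        return

    if mode == "min_size":
        # zip stops as soon as the shortest iterable is exhausted.
        rows = zip(*values)
    elif mode == "max_size":
        # zip_longest pads exhausted iterables with None.
        rows = zip_longest(*values, fillvalue=None)
    elif mode == "max_size_cycle":
        # Restart shorter iterables until the longest one is exhausted.
        longest = max(len(v) for v in values)
        rows = zip(*(islice(cycle(v), longest) for v in values))
    else:
        raise ValueError(f"unsupported mode: {mode!r}")

    for batch_idx, row in enumerate(rows):
        yield dict(zip(keys, row)), batch_idx, 0


# Plain lists stand in for the batches a DataLoader would produce.
batches = {
    "a": [[0, 1, 2, 3], [4, 5]],
    "b": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]],
}
for batch, batch_idx, iterable_idx in combine(batches, "max_size"):
    print(batch, batch_idx, iterable_idx)
```

With these inputs, min_size yields two combined batches, max_size and max_size_cycle yield three, and sequential yields five.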
Examples
>>> from torch.utils.data import DataLoader
>>> from lightning.pytorch.utilities.combined_loader import CombinedLoader
>>> iterables = {'a': DataLoader(range(6), batch_size=4),
...              'b': DataLoader(range(15), batch_size=5)}
>>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
>>> _ = iter(combined_loader)
>>> len(combined_loader)
3
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0
{'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}, batch_idx=2, dataloader_idx=0

>>> combined_loader = CombinedLoader(iterables, 'max_size')
>>> _ = iter(combined_loader)
>>> len(combined_loader)
3
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0
{'a': None, 'b': tensor([10, 11, 12, 13, 14])}, batch_idx=2, dataloader_idx=0

>>> combined_loader = CombinedLoader(iterables, 'min_size')
>>> _ = iter(combined_loader)
>>> len(combined_loader)
2
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0

>>> combined_loader = CombinedLoader(iterables, 'sequential')
>>> _ = iter(combined_loader)
>>> len(combined_loader)
5
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
tensor([0, 1, 2, 3]), batch_idx=0, dataloader_idx=0
tensor([4, 5]), batch_idx=1, dataloader_idx=0
tensor([0, 1, 2, 3, 4]), batch_idx=0, dataloader_idx=1
tensor([5, 6, 7, 8, 9]), batch_idx=1, dataloader_idx=1
tensor([10, 11, 12, 13, 14]), batch_idx=2, dataloader_idx=1
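
The lengths reported in the examples above (3, 3, 2 and 5) are consistent with simple arithmetic on the per-loader batch counts. The sketch below reproduces that arithmetic. The helper names `num_batches` and `combined_length` are hypothetical, and the calculation assumes sized loaders that keep a final partial batch (the DataLoader default, drop_last=False); it is not taken from the library's source.

```python
import math


def num_batches(dataset_size, batch_size):
    # Assumption: the final partial batch is kept (drop_last=False).
    return math.ceil(dataset_size / batch_size)


def combined_length(lengths, mode):
    """Expected number of combined batches for each mode (illustrative only)."""
    if mode == "min_size":
        return min(lengths)
    if mode in ("max_size", "max_size_cycle"):
        return max(lengths)
    if mode == "sequential":
        return sum(lengths)
    raise ValueError(f"unsupported mode: {mode!r}")


# Loader 'a': 6 items in batches of 4; loader 'b': 15 items in batches of 5.
lengths = [num_batches(6, 4), num_batches(15, 5)]
for mode in ("max_size_cycle", "max_size", "min_size", "sequential"):
    print(mode, combined_length(lengths, mode))
```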