# Source code for mlx_graphs.loaders.dataloaders
from typing import Sequence, Union
import numpy as np
from mlx_graphs.data import GraphData, GraphDataBatch, batch
from mlx_graphs.datasets.dataset import Dataset
# [docs]
class Dataloader:
    """
    Default data loader to batch and iterate over multiple graphs.

    The loader is its own iterator: ``__iter__`` returns the loader itself
    and ``__next__`` yields consecutive batches. Once every graph has been
    served, ``StopIteration`` is raised and the loader rewinds so the same
    instance can be iterated again.

    Args:
        dataset: `Dataset` or list of `GraphData` to batch and iterate over
        batch_size: Number of graphs to load per batch. The last batch of a
            pass may hold fewer graphs. Defaults to 1
        shuffle: Whether to randomize the order in which graphs are visited.
            The order is shuffled once at construction and again after each
            complete pass over the dataset. Defaults to False

    Raises:
        ValueError: If `batch_size` is smaller than 1.
    """

    def __init__(
        self,
        dataset: Union[Dataset, Sequence[GraphData]],
        batch_size: int = 1,
        shuffle: bool = False,
    ):
        # A batch size of 0 would never advance the cursor, and a negative
        # one would slice the index list incorrectly, so reject both early.
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )

        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle

        self._indices = list(range(len(dataset)))
        self._current_index = 0

        # Shuffle up front so that the very first pass is randomized too;
        # later passes are reshuffled when the previous one is exhausted.
        if self.shuffle:
            self._shuffle_indices()

    def _shuffle_indices(self):
        """Shuffle the visiting order of the graphs in place."""
        np.random.shuffle(self._indices)

    def __iter__(self):
        """
        Return the loader itself as the iterator.

        The read position is not rewound here; it is only reset once a pass
        has been fully consumed in ``__next__``.
        """
        return self

    def __next__(self) -> GraphDataBatch:
        """
        Get the next batch of data.

        Returns:
            A batch of graphs.

        Raises:
            StopIteration: When every graph of the current pass has been
                served. The loader is rewound (and reshuffled if `shuffle`
                is set) before the exception is raised.
        """
        # Compare against the index list that is actually sliced below, so
        # the end-of-pass check and the slice can never disagree.
        if self._current_index >= len(self._indices):
            self._current_index = 0
            if self.shuffle:
                self._shuffle_indices()
            raise StopIteration

        batch_indices = self._indices[
            self._current_index : self._current_index + self.batch_size
        ]
        batched_data = batch(
            [
                self.dataset[i]  # type: ignore - this is a list[GraphData]
                for i in batch_indices
            ]
        )

        self._current_index += self.batch_size
        return batched_data