holisticai.datasets.DataLoader

class holisticai.datasets.DataLoader(dataset: Dataset, batch_size: int, dtype: Literal['jax', 'pandas', 'numpy'])

A data loader that iterates over a dataset in batches. Each batch is returned in the data type selected by dtype (jax, pandas, or numpy).

Parameters

dataset: Dataset

The dataset to load.

batch_size: int

The number of samples in each batch.

dtype: Literal['jax', 'pandas', 'numpy']

The data type in which each batch is returned.
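The parameters above control two things: how many rows go into each batch (batch_size) and what type each batch is converted to (dtype). The following is a minimal sketch of that idea on a plain pandas DataFrame, using a hypothetical helper iter_batches. It is not holisticai's implementation. It covers only the "pandas" and "numpy" options (jax is left out to avoid the extra dependency), and it assumes the final, partial batch is kept. Whether DataLoader keeps or drops that batch is not stated here.

```python
import math
from typing import Iterator, Literal, Union

import numpy as np
import pandas as pd


def iter_batches(
    frame: pd.DataFrame,
    batch_size: int,
    dtype: Literal["pandas", "numpy"] = "pandas",
) -> Iterator[Union[pd.DataFrame, np.ndarray]]:
    """Yield consecutive row slices of `frame`, each with at most `batch_size` rows.

    Hypothetical helper for illustration only; not part of holisticai.
    Assumption: the last batch is yielded even if it has fewer than `batch_size` rows.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")
    # Number of batches, rounding up so leftover rows form a final smaller batch.
    n_batches = math.ceil(len(frame) / batch_size)
    for i in range(n_batches):
        batch = frame.iloc[i * batch_size : (i + 1) * batch_size]
        # Convert each batch individually instead of converting the whole frame up front.
        yield batch.to_numpy() if dtype == "numpy" else batch


df = pd.DataFrame({"x": range(70), "y": range(70)})
print([b.shape for b in iter_batches(df, batch_size=32, dtype="numpy")])
# [(32, 2), (32, 2), (6, 2)]
```

With 70 rows and batch_size=32, this produces two full batches and one batch of 6 rows. Converting per batch keeps memory use proportional to batch_size rather than to the full dataset.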

Example

>>> from holisticai.datasets import DataLoader, load_dataset
>>> dataset = load_dataset("adult")
>>> dataloader = DataLoader(dataset, batch_size=32, dtype="jax")
>>> for batch in dataloader:
...     print(batch)