Skip to content

ArrayDataset

Custom Dataset class for JAX arrays based on PyTorch's Dataset.

Source code in aimz/utils/data/array_dataset.py
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
class ArrayDataset(Dataset):
    """Dataset over one or more equally long arrays, built on PyTorch's Dataset.

    The inputs are kept as a tuple of NumPy arrays in ``self.arrays``; indexing
    the dataset indexes every stored array with the same index.
    """

    def __init__(self, *arrays: ArrayLike) -> None:
        """Validate the inputs and store them as NumPy arrays.

        Args:
            *arrays: One or more JAX arrays or compatible array-like objects. Each
                one must support ``len()``.

        Raises:
            ValueError: If called without any array, or if the arrays differ in
                length.
        """
        if len(arrays) == 0:
            msg = "At least one array must be provided."
            raise ValueError(msg)

        # The first array sets the sample count every other array has to match.
        expected = len(arrays[0])
        if not all(len(candidate) == expected for candidate in arrays):
            msg = "All arrays must have the same length."
            raise ValueError(msg)

        # Hold the data as NumPy arrays on the host so batching and slicing stay
        # CPU-side; conversion for the device is deferred until batch collation
        # to cut down on host-to-device transfer overhead.
        self.arrays = tuple(map(np.asarray, arrays))

    def __len__(self) -> int:
        """Return how many samples the dataset holds.

        Returns:
            The length of the first stored array (all stored arrays share it).
        """
        leading = self.arrays[0]
        return len(leading)

    def __getitem__(self, index: int) -> object:
        """Fetch the sample stored at ``index``.

        Args:
            index: Index of the item to retrieve.

        Returns:
            A pytree with the same structure as ``self.arrays`` whose leaves are
            the entries found at ``index`` in each array.
        """

        def _select(array):  # noqa: ANN001, ANN202
            # Applied to each stored array (the leaves of ``self.arrays``).
            return array[index]

        return tree.map(_select, self.arrays)

__init__

__init__(*arrays: ArrayLike) -> None

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `*arrays` | `ArrayLike` | One or more JAX arrays or compatible array-like objects. | `()` |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If no arrays are provided or if the arrays do not have the same length. |

Source code in aimz/utils/data/array_dataset.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def __init__(self, *arrays: ArrayLike) -> None:
    """Validate the inputs and store them as NumPy arrays.

    Args:
        *arrays: One or more JAX arrays or compatible array-like objects. Each
            one must support ``len()``.

    Raises:
        ValueError: If called without any array, or if the arrays differ in
            length.
    """
    if len(arrays) == 0:
        msg = "At least one array must be provided."
        raise ValueError(msg)

    # The first array sets the sample count every other array has to match.
    expected = len(arrays[0])
    if not all(len(candidate) == expected for candidate in arrays):
        msg = "All arrays must have the same length."
        raise ValueError(msg)

    # Hold the data as NumPy arrays on the host so batching and slicing stay
    # CPU-side; conversion for the device is deferred until batch collation
    # to cut down on host-to-device transfer overhead.
    self.arrays = tuple(map(np.asarray, arrays))

__len__

__len__() -> int

Get the number of samples in the dataset.

Returns:

| Type | Description |
| --- | --- |
| `int` | The number of samples, equal to the length of the first array. |

Source code in aimz/utils/data/array_dataset.py
48
49
50
51
52
53
54
55
def __len__(self) -> int:
    """Return how many samples the dataset holds.

    Returns:
        The length of the first array in ``self.arrays``.
    """
    # Only the first array is measured; it stands in for the sample count.
    leading = self.arrays[0]
    return len(leading)

__getitem__

__getitem__(index: int) -> object

Retrieve the elements at the specified index.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `index` | `int` | Index of the item to retrieve. | required |

Returns:

| Type | Description |
| --- | --- |
| `object` | A pytree containing the elements at the given index from each array. |

Source code in aimz/utils/data/array_dataset.py
57
58
59
60
61
62
63
64
65
66
def __getitem__(self, index: int) -> object:
    """Fetch the sample stored at ``index``.

    Args:
        index: Index of the item to retrieve.

    Returns:
        A pytree with the same structure as ``self.arrays`` whose leaves are
        the entries found at ``index`` in each array.
    """

    def _select(array):  # noqa: ANN001, ANN202
        # Applied to every leaf of ``self.arrays``.
        return array[index]

    return tree.map(_select, self.arrays)