ArrayLoader

Custom DataLoader class for JAX arrays based on PyTorch's DataLoader.

Source code in aimz/utils/data/array_loader.py
class ArrayLoader(DataLoader):
    """Custom DataLoader class for JAX arrays based on PyTorch's DataLoader."""

    def __init__(
        self,
        dataset: ArrayDataset,
        *,
        batch_size: int = 1,
        shuffle: bool = False,
        sampler: "Sampler | None" = None,
        num_workers: int = 0,
        collate_fn: "Callable | None" = None,
        pin_memory: bool = False,
        drop_last: bool = False,
    ) -> None:
        """Initializes an ArrayLoader instance."""
        super().__init__(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            sampler=sampler,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=pin_memory,
            drop_last=drop_last,
        )

    @staticmethod
    def calculate_padding(batch_size: int, num_devices: int) -> int:
        """Calculate the number of padding needed.

        Args:
            batch_size (int): The size of the batch.
            num_devices (int): The number of devices.

        Returns:
            int: The number of padding rows (or elements) needed to make the batch size
                divisible by the number of devices.
        """
        remainder = batch_size % num_devices
        return 0 if remainder == 0 else num_devices - remainder

    @staticmethod
    def pad_array(x: ArrayLike, n_pad: int, axis: int) -> Array:
        """Pad an array to ensure compatibility with sharding.

        Args:
            x (ArrayLike): The input array to be padded.
            n_pad (int): The number of padding elements to add.
            axis (int): The axis along which to apply the padding.

        Returns:
            Array: The padded array with padding applied along the specified axis.

        Raises:
            ValueError: If padding is requested along an unsupported axis for a 1D
                array.
        """
        if x.ndim == 1:
            if axis == 0:
                return jnp.pad(x, pad_width=(0, n_pad), mode="edge")
            msg = "Padding 1D arrays is only supported along axis 0."
            raise ValueError(msg)

        # Initialize all axes with no padding
        pad_width: list[tuple[int, int]] = [(0, 0)] * x.ndim
        # Apply padding to the specified axis
        pad_width[axis] = (0, n_pad)

        return jnp.pad(x, pad_width=pad_width, mode="edge")

    @staticmethod
    def collate_without_output(
        batch: list[tuple],
        device: "NamedSharding | None" = None,
    ) -> tuple:
        """Collate function to process batches with sharding and padding.

        This function unpacks the batch of data, converts it into JAX arrays, and
        applies padding to ensure the batch size is compatible with the number of
        devices, if sharding is necessary. When a device is provided, the data is
        automatically distributed across the available devices.

        Args:
            batch (list[tuple]): A list of tuples, where each tuple contains the input
                data and array-like keyword arguments.
            device (NamedSharding | None, optional): Sharding using named axes for
                parallel data distribution across devices. Defaults to `None`, meaning
                no sharding is applied.

        Returns:
            tuple: A tuple containing:
                - n_pad (int): The number of padding rows/elements added (0 if no
                    padding was required).
                - x_batch (Array): The input batch with padding applied if necessary.
                - kwargs_batch (list[Array]): A list of keyword arguments with
                    padding applied if necessary.
        """
        x_batch, *kwargs_batch = map(jnp.asarray, zip(*batch, strict=True))

        n_pad = (
            ArrayLoader.calculate_padding(
                len(x_batch),
                num_devices=device.num_devices,
            )
            if device
            else 0
        )
        if n_pad:
            x_batch = ArrayLoader.pad_array(x_batch, n_pad=n_pad, axis=0)
            kwargs_batch = [
                ArrayLoader.pad_array(x, n_pad=n_pad, axis=0) for x in kwargs_batch
            ]

        if device:
            x_batch = device_put(x_batch, device=device)
            kwargs_batch = [device_put(x, device=device) for x in kwargs_batch]

        return n_pad, x_batch, *kwargs_batch

    @staticmethod
    def collate_with_sharding(
        batch: list[tuple],
        device: "NamedSharding | None" = None,
    ) -> tuple:
        """Collate function to process batches with sharding and padding.

        This function unpacks the batch of data, converts it into JAX arrays, and
        applies padding to ensure the batch size is compatible with the number of
        devices, if sharding is necessary. When a device is provided, the data is
        automatically distributed across the available devices.

        Args:
            batch (list[tuple]): A list of tuples, where each tuple contains the input
                data, optional target data, and array-like keyword arguments.
            device (NamedSharding | None, optional): Sharding using named axes for
                parallel data distribution across devices. Defaults to `None`, meaning
                no sharding is applied.

        Returns:
            tuple: A tuple containing:
                - n_pad (int): The number of padding rows/elements added (0 if no
                    padding was required).
                - x_batch (Array): The input batch with padding applied if necessary.
                - y_batch (Array): The target batch with padding applied.
                - kwargs_batch (list[Array]): A list of keyword arguments with padding
                    applied if necessary.
        """
        x_batch, y_batch, *kwargs_batch = map(jnp.asarray, zip(*batch, strict=True))

        n_pad = (
            ArrayLoader.calculate_padding(
                len(x_batch),
                num_devices=device.num_devices,
            )
            if device
            else 0
        )
        if n_pad:
            x_batch = ArrayLoader.pad_array(x_batch, n_pad=n_pad, axis=0)
            y_batch = ArrayLoader.pad_array(y_batch, n_pad=n_pad, axis=0)
            kwargs_batch = [
                ArrayLoader.pad_array(x, n_pad=n_pad, axis=0) for x in kwargs_batch
            ]

        if device:
            x_batch = device_put(x_batch, device=device)
            y_batch = device_put(y_batch, device=device)
            kwargs_batch = [device_put(x, device=device) for x in kwargs_batch]

        return n_pad, x_batch, y_batch, *kwargs_batch
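
Example: a minimal usage sketch. The import path and constructor of ArrayDataset are assumptions made for illustration (it is taken to wrap equal-length arrays and yield per-row tuples); with no sharding device bound into the collate function, no padding or device placement happens.

import jax.numpy as jnp

from aimz.utils.data.array_loader import ArrayLoader
from aimz.utils.data import ArrayDataset  # import path assumed

x = jnp.arange(10.0).reshape(5, 2)
y = jnp.arange(5.0)

# ArrayDataset is assumed to yield (x[i], y[i]) tuples.
loader = ArrayLoader(
    ArrayDataset(x, y),
    batch_size=2,
    collate_fn=ArrayLoader.collate_with_sharding,  # device=None: no padding, no device_put
)

for n_pad, x_batch, y_batch in loader:
    # n_pad is 0 here because no sharding device was provided.
    ...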

__init__

__init__(dataset: ArrayDataset, *, batch_size: int = 1, shuffle: bool = False, sampler: Sampler | None = None, num_workers: int = 0, collate_fn: Callable | None = None, pin_memory: bool = False, drop_last: bool = False) -> None
Source code in aimz/utils/data/array_loader.py
def __init__(
    self,
    dataset: ArrayDataset,
    *,
    batch_size: int = 1,
    shuffle: bool = False,
    sampler: "Sampler | None" = None,
    num_workers: int = 0,
    collate_fn: "Callable | None" = None,
    pin_memory: bool = False,
    drop_last: bool = False,
) -> None:
    """Initializes an ArrayLoader instance."""
    super().__init__(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        num_workers=num_workers,
        collate_fn=collate_fn,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )

calculate_padding staticmethod

calculate_padding(batch_size: int, num_devices: int) -> int

Calculate the number of padding elements needed.

Parameters:

    batch_size (int): The size of the batch. Required.
    num_devices (int): The number of devices. Required.

Returns:

    int: The number of padding rows (or elements) needed to make the batch size divisible by the number of devices.

Source code in aimz/utils/data/array_loader.py
@staticmethod
def calculate_padding(batch_size: int, num_devices: int) -> int:
    """Calculate the number of padding needed.

    Args:
        batch_size (int): The size of the batch.
        num_devices (int): The number of devices.

    Returns:
        int: The number of padding rows (or elements) needed to make the batch size
            divisible by the number of devices.
    """
    remainder = batch_size % num_devices
    return 0 if remainder == 0 else num_devices - remainder
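
Example: with a batch of 10 rows spread across 4 devices, 2 padding rows bring the batch to the next multiple of 4.

from aimz.utils.data.array_loader import ArrayLoader

ArrayLoader.calculate_padding(10, num_devices=4)  # 2
ArrayLoader.calculate_padding(12, num_devices=4)  # 0, already divisible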

pad_array staticmethod

pad_array(x: ArrayLike, n_pad: int, axis: int) -> Array

Pad an array to ensure compatibility with sharding.

Parameters:

    x (ArrayLike): The input array to be padded. Required.
    n_pad (int): The number of padding elements to add. Required.
    axis (int): The axis along which to apply the padding. Required.

Returns:

    Array: The padded array with padding applied along the specified axis.

Raises:

    ValueError: If padding is requested along an unsupported axis for a 1D array.

Source code in aimz/utils/data/array_loader.py
@staticmethod
def pad_array(x: ArrayLike, n_pad: int, axis: int) -> Array:
    """Pad an array to ensure compatibility with sharding.

    Args:
        x (ArrayLike): The input array to be padded.
        n_pad (int): The number of padding elements to add.
        axis (int): The axis along which to apply the padding.

    Returns:
        Array: The padded array with padding applied along the specified axis.

    Raises:
        ValueError: If padding is requested along an unsupported axis for a 1D
            array.
    """
    if x.ndim == 1:
        if axis == 0:
            return jnp.pad(x, pad_width=(0, n_pad), mode="edge")
        msg = "Padding 1D arrays is only supported along axis 0."
        raise ValueError(msg)

    # Initialize all axes with no padding
    pad_width: list[tuple[int, int]] = [(0, 0)] * x.ndim
    # Apply padding to the specified axis
    pad_width[axis] = (0, n_pad)

    return jnp.pad(x, pad_width=pad_width, mode="edge")
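
Example: padding uses jnp.pad with mode="edge", so the last row (or element) along the padded axis is repeated.

import jax.numpy as jnp

from aimz.utils.data.array_loader import ArrayLoader

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
ArrayLoader.pad_array(x, n_pad=1, axis=0)
# shape (4, 2); the appended row repeats [5., 6.]

v = jnp.array([1.0, 2.0, 3.0])
ArrayLoader.pad_array(v, n_pad=2, axis=0)
# [1., 2., 3., 3., 3.]; 1D arrays can only be padded along axis 0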

collate_without_output staticmethod

collate_without_output(batch: list[tuple], device: NamedSharding | None = None) -> tuple

Collate function to process batches with sharding and padding.

This function unpacks the batch of data, converts it into JAX arrays, and applies padding to ensure the batch size is compatible with the number of devices, if sharding is necessary. When a device is provided, the data is automatically distributed across the available devices.

Parameters:

    batch (list[tuple]): A list of tuples, where each tuple contains the input data and array-like keyword arguments. Required.
    device (NamedSharding | None): Sharding using named axes for parallel data distribution across devices. Defaults to None, meaning no sharding is applied.

Returns:

    tuple: A tuple containing:
        - n_pad (int): The number of padding rows/elements added (0 if no padding was required).
        - x_batch (Array): The input batch with padding applied if necessary.
        - kwargs_batch (list[Array]): A list of keyword arguments with padding applied if necessary.

Source code in aimz/utils/data/array_loader.py
@staticmethod
def collate_without_output(
    batch: list[tuple],
    device: "NamedSharding | None" = None,
) -> tuple:
    """Collate function to process batches with sharding and padding.

    This function unpacks the batch of data, converts it into JAX arrays, and
    applies padding to ensure the batch size is compatible with the number of
    devices, if sharding is necessary. When a device is provided, the data is
    automatically distributed across the available devices.

    Args:
        batch (list[tuple]): A list of tuples, where each tuple contains the input
            data and array-like keyword arguments.
        device (NamedSharding | None, optional): Sharding using named axes for
            parallel data distribution across devices. Defaults to `None`, meaning
            no sharding is applied.

    Returns:
        tuple: A tuple containing:
            - n_pad (int): The number of padding rows/elements added (0 if no
                padding was required).
            - x_batch (Array): The input batch with padding applied if necessary.
            - kwargs_batch (list[Array]): A list of keyword arguments with
                padding applied if necessary.
    """
    x_batch, *kwargs_batch = map(jnp.asarray, zip(*batch, strict=True))

    n_pad = (
        ArrayLoader.calculate_padding(
            len(x_batch),
            num_devices=device.num_devices,
        )
        if device
        else 0
    )
    if n_pad:
        x_batch = ArrayLoader.pad_array(x_batch, n_pad=n_pad, axis=0)
        kwargs_batch = [
            ArrayLoader.pad_array(x, n_pad=n_pad, axis=0) for x in kwargs_batch
        ]

    if device:
        x_batch = device_put(x_batch, device=device)
        kwargs_batch = [device_put(x, device=device) for x in kwargs_batch]

    return n_pad, x_batch, *kwargs_batch
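
Example: a sketch of binding a NamedSharding into this collate function with functools.partial, so each batch is padded to a multiple of the device count and placed across devices. The mesh axis name "batch" and the dataset variable are assumptions for illustration.

from functools import partial

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from aimz.utils.data.array_loader import ArrayLoader

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

# `dataset` is assumed to yield (x, *kwargs) tuples, i.e. no target column.
loader = ArrayLoader(
    dataset,
    batch_size=8,
    collate_fn=partial(ArrayLoader.collate_without_output, device=sharding),
)

for n_pad, x_batch, *kwargs_batch in loader:
    # The leading dimension of x_batch is now divisible by sharding.num_devices;
    # drop the last n_pad rows from any per-batch results before aggregating.
    ...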

collate_with_sharding staticmethod

collate_with_sharding(batch: list[tuple], device: NamedSharding | None = None) -> tuple

Collate function to process batches with sharding and padding.

This function unpacks the batch of data, converts it into JAX arrays, and applies padding to ensure the batch size is compatible with the number of devices, if sharding is necessary. When a device is provided, the data is automatically distributed across the available devices.

Parameters:

    batch (list[tuple]): A list of tuples, where each tuple contains the input data, optional target data, and array-like keyword arguments. Required.
    device (NamedSharding | None): Sharding using named axes for parallel data distribution across devices. Defaults to None, meaning no sharding is applied.

Returns:

    tuple: A tuple containing:
        - n_pad (int): The number of padding rows/elements added (0 if no padding was required).
        - x_batch (Array): The input batch with padding applied if necessary.
        - y_batch (Array): The target batch with padding applied.
        - kwargs_batch (list[Array]): A list of keyword arguments with padding applied if necessary.

Source code in aimz/utils/data/array_loader.py
@staticmethod
def collate_with_sharding(
    batch: list[tuple],
    device: "NamedSharding | None" = None,
) -> tuple:
    """Collate function to process batches with sharding and padding.

    This function unpacks the batch of data, converts it into JAX arrays, and
    applies padding to ensure the batch size is compatible with the number of
    devices, if sharding is necessary. When a device is provided, the data is
    automatically distributed across the available devices.

    Args:
        batch (list[tuple]): A list of tuples, where each tuple contains the input
            data, optional target data, and array-like keyword arguments.
        device (NamedSharding | None, optional): Sharding using named axes for
            parallel data distribution across devices. Defaults to `None`, meaning
            no sharding is applied.

    Returns:
        tuple: A tuple containing:
            - n_pad (int): The number of padding rows/elements added (0 if no
                padding was required).
            - x_batch (Array): The input batch with padding applied if necessary.
            - y_batch (Array): The target batch with padding applied.
            - kwargs_batch (list[Array]): A list of keyword arguments with padding
                applied if necessary.
    """
    x_batch, y_batch, *kwargs_batch = map(jnp.asarray, zip(*batch, strict=True))

    n_pad = (
        ArrayLoader.calculate_padding(
            len(x_batch),
            num_devices=device.num_devices,
        )
        if device
        else 0
    )
    if n_pad:
        x_batch = ArrayLoader.pad_array(x_batch, n_pad=n_pad, axis=0)
        y_batch = ArrayLoader.pad_array(y_batch, n_pad=n_pad, axis=0)
        kwargs_batch = [
            ArrayLoader.pad_array(x, n_pad=n_pad, axis=0) for x in kwargs_batch
        ]

    if device:
        x_batch = device_put(x_batch, device=device)
        y_batch = device_put(y_batch, device=device)
        kwargs_batch = [device_put(x, device=device) for x in kwargs_batch]

    return n_pad, x_batch, y_batch, *kwargs_batch
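
Example: the same wiring when each dataset item also carries a target. A common pattern is to drop the trailing n_pad rows from per-batch results, since they are repeated edge rows rather than real data (sketch under the same assumptions as the collate_without_output example; predict_fn is a hypothetical prediction function).

from functools import partial

from aimz.utils.data.array_loader import ArrayLoader

loader = ArrayLoader(
    dataset,  # assumed to yield (x, y, *kwargs) tuples
    batch_size=8,
    collate_fn=partial(ArrayLoader.collate_with_sharding, device=sharding),
)

for n_pad, x_batch, y_batch in loader:
    preds = predict_fn(x_batch)  # hypothetical model prediction
    if n_pad:
        preds, y_batch = preds[:-n_pad], y_batch[:-n_pad]  # discard padded rows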