ArrayLoader
A custom DataLoader class for JAX arrays, modeled on PyTorch's DataLoader.
Source code in aimz/utils/data/array_loader.py
__init__
`__init__(dataset: ArrayDataset, *, batch_size: int = 1, shuffle: bool = False, sampler: Sampler | None = None, num_workers: int = 0, collate_fn: Callable | None = None, pin_memory: bool = False, drop_last: bool = False) -> None`
Source code in aimz/utils/data/array_loader.py
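The snippet below is a minimal usage sketch. The `aimz.utils.data` import path and the `ArrayDataset` constructor signature are assumptions inferred from the signature above, not confirmed API.

```python
import jax.numpy as jnp

from aimz.utils.data import ArrayDataset, ArrayLoader  # import path assumed

# Hypothetical dataset: 10 samples with 3 features each, plus targets.
X = jnp.arange(30.0).reshape(10, 3)
y = jnp.arange(10.0)

dataset = ArrayDataset(X, y)  # constructor signature assumed
loader = ArrayLoader(dataset, batch_size=4, shuffle=True)

for batch in loader:
    ...  # each batch is produced by the configured collate_fn
```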
calculate_padding
staticmethod
`calculate_padding(batch_size: int, num_devices: int) -> int`
Calculate the number of padding rows (or elements) needed.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch_size` | `int` | The size of the batch. | *required* |
| `num_devices` | `int` | The number of devices. | *required* |
Returns:

| Name | Type | Description |
|---|---|---|
| `int` | `int` | The number of padding rows (or elements) needed to make the batch size divisible by the number of devices. |
Source code in aimz/utils/data/array_loader.py
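As a sketch of the arithmetic, the return value is presumably the smallest non-negative integer that tops the batch size up to a multiple of the device count; the one-liner below illustrates that rule, though the actual implementation may differ.

```python
def calculate_padding(batch_size: int, num_devices: int) -> int:
    # Smallest n >= 0 such that (batch_size + n) % num_devices == 0,
    # e.g. batch_size=10, num_devices=4 -> 2 (10 + 2 = 12 = 3 * 4).
    return -batch_size % num_devices

assert calculate_padding(10, 4) == 2
assert calculate_padding(8, 4) == 0
```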
pad_array
staticmethod
`pad_array(x: ArrayLike, n_pad: int, axis: int) -> Array`
Pad an array to ensure compatibility with sharding.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `ArrayLike` | The input array to be padded. | *required* |
| `n_pad` | `int` | The number of padding elements to add. | *required* |
| `axis` | `int` | The axis along which to apply the padding. | *required* |
Returns:

| Name | Type | Description |
|---|---|---|
| `Array` | `Array` | The padded array with padding applied along the specified axis. |
Raises:

| Type | Description |
|---|---|
| `ValueError` | If padding is requested along an unsupported axis for a 1D array. |
Source code in aimz/utils/data/array_loader.py
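A sketch of the padding semantics using `jax.numpy.pad`; the zero-fill strategy and the trailing-edge placement of the padding are assumptions, not the library's confirmed behavior.

```python
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def pad_array_sketch(x: ArrayLike, n_pad: int, axis: int) -> Array:
    """Illustrative stand-in for ArrayLoader.pad_array."""
    x = jnp.asarray(x)
    if x.ndim == 1 and axis != 0:
        msg = "1D arrays can only be padded along axis 0."
        raise ValueError(msg)
    widths = [(0, 0)] * x.ndim
    widths[axis] = (0, n_pad)  # append n_pad entries at the end of `axis`
    return jnp.pad(x, widths)  # zero-fill is an assumption

x = jnp.ones((10, 3))
print(pad_array_sketch(x, n_pad=2, axis=0).shape)  # (12, 3)
```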
collate_without_output
staticmethod
`collate_without_output(batch: list[tuple], device: NamedSharding | None = None) -> tuple`
Collate function to process batches with sharding and padding, returning inputs without target data.
This function unpacks the batch, converts it into JAX arrays, and, when sharding requires it, applies padding so the batch size is divisible by the number of devices. When a device is provided, the data is automatically distributed across the available devices.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `list[tuple]` | A list of tuples, where each tuple contains the input data, optional target data, and array-like keyword arguments. | *required* |
| `device` | `NamedSharding \| None` | Sharding using named axes for parallel data distribution across devices. Defaults to `None`. | `None` |
Returns:

| Name | Type | Description |
|---|---|---|
| `tuple` | `tuple` | A tuple containing: `n_pad` (`int`), the number of padding rows/elements added (0 if no padding was required); `x_batch` (`Array`), the input batch with padding applied if necessary; `kwargs_batch` (`list[Array]`), a list of keyword arguments with padding applied if necessary. |
Source code in aimz/utils/data/array_loader.py
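A hedged sketch of wiring this collate function into a loader with a `NamedSharding`; the mesh construction and the use of `functools.partial` to bind `device` are assumptions about intended usage, and the loop unpacking follows the documented return tuple.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from aimz.utils.data import ArrayDataset, ArrayLoader  # import path assumed

# One-dimensional device mesh; batches are sharded along the leading axis.
mesh = Mesh(jax.devices(), axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

dataset = ArrayDataset(jnp.ones((10, 3)))  # inputs only; signature assumed
loader = ArrayLoader(
    dataset,
    batch_size=4,
    collate_fn=partial(ArrayLoader.collate_without_output, device=sharding),
)

for n_pad, x_batch, kwargs_batch in loader:
    ...  # if n_pad > 0, drop the last n_pad rows from any per-row results
```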
collate_with_sharding
staticmethod
`collate_with_sharding(batch: list[tuple], device: NamedSharding | None = None) -> tuple`
Collate function to process batches with sharding and padding, including the target batch.
This function unpacks the batch, converts it into JAX arrays, and, when sharding requires it, applies padding so the batch size is divisible by the number of devices. When a device is provided, the data is automatically distributed across the available devices.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `batch` | `list[tuple]` | A list of tuples, where each tuple contains the input data, optional target data, and array-like keyword arguments. | *required* |
| `device` | `NamedSharding \| None` | Sharding using named axes for parallel data distribution across devices. Defaults to `None`. | `None` |
Returns:

| Name | Type | Description |
|---|---|---|
| `tuple` | `tuple` | A tuple containing: `n_pad` (`int`), the number of padding rows/elements added (0 if no padding was required); `x_batch` (`Array`), the input batch with padding applied if necessary; `y_batch` (`Array`), the target batch with padding applied; `kwargs_batch` (`list[Array]`), a list of keyword arguments with padding applied if necessary. |
Source code in aimz/utils/data/array_loader.py
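The supervised counterpart follows the same pattern; the sketch below repeats the assumed mesh setup from the previous example and shows trimming padded rows after a placeholder per-row computation (the computation itself is purely illustrative).

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

from aimz.utils.data import ArrayDataset, ArrayLoader  # import path assumed

mesh = Mesh(jax.devices(), axis_names=("batch",))
sharding = NamedSharding(mesh, PartitionSpec("batch"))

dataset = ArrayDataset(jnp.ones((10, 3)), jnp.ones(10))  # inputs and targets; signature assumed
loader = ArrayLoader(
    dataset,
    batch_size=4,
    collate_fn=partial(ArrayLoader.collate_with_sharding, device=sharding),
)

for n_pad, x_batch, y_batch, kwargs_batch in loader:
    preds = x_batch.sum(axis=1)  # placeholder per-row computation
    if n_pad:
        preds = preds[:-n_pad]  # exclude rows that exist only as padding
```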