ArrayDataset
Custom Dataset class for JAX arrays based on PyTorch's Dataset.
Source code in aimz/utils/data/array_dataset.py
__init__
__init__(*arrays: ArrayLike) -> None
Parameters:
Name | Type | Description | Default |
---|---|---|---|
*arrays | ArrayLike | One or more JAX arrays or compatible array-like objects. | () |
Raises:
Type | Description |
---|---|
ValueError | If no arrays are provided or if the arrays do not have the same length. |
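The two documented failure conditions can be checked before any data is stored. The following is a minimal sketch, not the aimz implementation. It makes several assumptions: the helper name `validate_arrays` is hypothetical, NumPy arrays stand in for JAX arrays, and "length" is taken to mean `len()` of each array (the size of its first axis).

```python
import numpy as np
from numpy.typing import ArrayLike


def validate_arrays(*arrays: ArrayLike) -> int:
    """Hypothetical helper: check the inputs and return their shared length."""
    # Documented condition 1: at least one array must be provided.
    if not arrays:
        raise ValueError("At least one array must be provided.")
    # Documented condition 2: all arrays must have the same length.
    # Assumption: "length" means len(), i.e. the size of the first axis.
    lengths = [len(a) for a in arrays]
    if any(n != lengths[0] for n in lengths):
        raise ValueError(f"Arrays must have the same length, got {lengths}.")
    return lengths[0]


features = np.zeros((4, 3))  # 4 samples, 3 features each
targets = np.arange(4)       # 4 targets
n_samples = validate_arrays(features, targets)  # n_samples == 4
```

Checking lengths once at construction time means later indexing can assume every array has an entry at each valid index.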
__len__
__len__() -> int
Get the number of samples in the dataset.
Returns:
Type | Description |
---|---|
int | The number of samples, equal to the length of the first array. |
__getitem__
__getitem__(index: int) -> object
Retrieve the elements at the specified index.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
index | int | Index of the item to retrieve. | required |
Returns:
Type | Description |
---|---|
object | A pytree containing the elements at the given index from each array. |
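The length and indexing behaviour documented above can be illustrated with a minimal map-style dataset. This is a sketch under stated assumptions, not the aimz class: `SimpleArrayDataset` is a hypothetical name, the returned pytree is assumed to be a tuple with one entry per input array (in the order the arrays were passed, as PyTorch's `TensorDataset` does), NumPy arrays stand in for JAX arrays, and the class does not subclass PyTorch's `Dataset` so that it runs without PyTorch installed.

```python
import numpy as np
from numpy.typing import ArrayLike


class SimpleArrayDataset:
    """Hypothetical map-style dataset mirroring the documented behaviour."""

    def __init__(self, *arrays: ArrayLike) -> None:
        # Same documented checks as the constructor: non-empty, equal lengths.
        if not arrays:
            raise ValueError("At least one array must be provided.")
        if any(len(a) != len(arrays[0]) for a in arrays):
            raise ValueError("All arrays must have the same length.")
        # Sketch choice (assumption): store inputs as NumPy arrays.
        self.arrays = tuple(np.asarray(a) for a in arrays)

    def __len__(self) -> int:
        # Number of samples: the length of the first array.
        return len(self.arrays[0])

    def __getitem__(self, index: int) -> tuple:
        # Assumed pytree structure: a tuple holding array[index] for each array.
        return tuple(a[index] for a in self.arrays)


x = np.arange(6).reshape(3, 2)  # 3 samples with 2 features each
y = np.array([10, 20, 30])      # 3 targets
ds = SimpleArrayDataset(x, y)
sample = ds[1]                  # x[1] is [2, 3] and y[1] is 20
```

Because a tuple is a JAX pytree node, a sample of this shape can be passed directly to utilities such as `jax.tree_util.tree_map`.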