art.data_generators

Module defining an interface for data generators and providing concrete implementations for the supported frameworks. Data generators load and batch data on the fly and can apply dynamic data augmentation. They can be used with the fit_generator method of the Classifier interface. Users can define their own generators by following the DataGenerator interface. For large datasets stored as numpy arrays, the NumpyDataGenerator class can be used with fit_generator on framework-specific classifiers.
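As a rough sketch of how a training routine might consume any object that exposes size, batch_size and get_batch(), the loop below draws a fixed number of batches per epoch, assuming size is known. The names CountingGenerator, train_with_generator and train_step are hypothetical illustrations, not ART's fit_generator implementation. Using ceil(size / batch_size) steps per epoch is also an assumption of this sketch; another consumer might use floor instead.

```python
import math
from typing import Callable, List, Tuple


class CountingGenerator:
    """Hypothetical toy generator exposing size, batch_size and get_batch()."""

    def __init__(self, size: int, batch_size: int) -> None:
        self.size = size
        self.batch_size = batch_size
        self._start = 0

    def get_batch(self) -> Tuple[List[int], List[int]]:
        stop = min(self._start + self.batch_size, self.size)
        x = list(range(self._start, stop))
        # Wrap around after the last sample so the generator never runs out.
        self._start = 0 if stop == self.size else stop
        return x, [v % 2 for v in x]


def train_with_generator(generator, nb_epochs: int,
                         train_step: Callable[[list, list], None]) -> int:
    """Hypothetical consumer loop; returns the number of batches processed."""
    steps_per_epoch = math.ceil(generator.size / generator.batch_size)
    n_steps = 0
    for _ in range(nb_epochs):
        for _ in range(steps_per_epoch):
            x, y = generator.get_batch()  # relies on the generator looping indefinitely
            train_step(x, y)
            n_steps += 1
    return n_steps


batch_lengths = []
steps = train_with_generator(CountingGenerator(size=10, batch_size=4), nb_epochs=2,
                             train_step=lambda x, y: batch_lengths.append(len(x)))
print(steps, batch_lengths)  # 6 [4, 4, 2, 4, 4, 2]
```

Because the consumer only counts steps and never checks for exhaustion, it depends on each generator looping over its data indefinitely, which is why the interface below asks for that behavior.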

Base Class

class art.data_generators.DataGenerator(size: int | None, batch_size: int)

Base class for data generators.

property batch_size: int
Returns:

Return the batch size.

abstract get_batch() → tuple

Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.

Returns:

A tuple containing a batch of data (x, y).

property iterator
Returns:

Return the framework’s iterable data generator.

property size: int | None
Returns:

Return the dataset size.
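A minimal sketch of a user-defined generator following this interface is shown below. So that it runs without ART installed, DataGeneratorLike is a hypothetical stand-in that only mirrors the members documented above (size, batch_size, iterator, get_batch); it is not ART's source. NoisyArrayGenerator is likewise hypothetical and adds Gaussian noise to each batch as an example of on-the-fly augmentation.

```python
import abc
from typing import Any, Iterator, Optional

import numpy as np


class DataGeneratorLike(abc.ABC):
    """Hypothetical stand-in mirroring the documented DataGenerator members."""

    def __init__(self, size: Optional[int], batch_size: int) -> None:
        self._size = size
        self._batch_size = batch_size
        self._iterator: Optional[Iterator[Any]] = None

    @property
    def size(self) -> Optional[int]:
        return self._size

    @property
    def batch_size(self) -> int:
        return self._batch_size

    @property
    def iterator(self) -> Optional[Iterator[Any]]:
        return self._iterator

    @abc.abstractmethod
    def get_batch(self) -> tuple:
        """Return the next (x, y) batch; must loop over the data indefinitely."""


class NoisyArrayGenerator(DataGeneratorLike):
    """Hypothetical generator adding Gaussian noise to each batch (on-the-fly augmentation)."""

    def __init__(self, x: np.ndarray, y: np.ndarray, batch_size: int,
                 noise_std: float = 0.1, seed: int = 0) -> None:
        super().__init__(size=len(x), batch_size=batch_size)
        self._x, self._y = x, y
        self._noise_std = noise_std
        self._rng = np.random.default_rng(seed)
        self._iterator = self._batches()  # infinite Python generator

    def _batches(self) -> Iterator[tuple]:
        while True:  # restart from the first sample after every pass
            for start in range(0, len(self._x), self.batch_size):
                x_b = self._x[start:start + self.batch_size]
                y_b = self._y[start:start + self.batch_size]
                noise = self._rng.normal(0.0, self._noise_std, size=x_b.shape)
                yield x_b + noise, y_b

    def get_batch(self) -> tuple:
        return next(self._iterator)


gen = NoisyArrayGenerator(np.zeros((5, 3)), np.arange(5), batch_size=2)
for _ in range(4):
    x_b, y_b = gen.get_batch()
    print(x_b.shape, y_b.tolist())
```

In real code the subclass would derive from art.data_generators.DataGenerator and pass size and batch_size to its constructor, matching the signature documented above.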

Framework-Specific Data Generators

class art.data_generators.KerasDataGenerator(iterator: keras.utils.Sequence | tf.keras.utils.Sequence | keras.preprocessing.image.ImageDataGenerator | tf.keras.preprocessing.image.ImageDataGenerator | Generator, size: int | None, batch_size: int)

Wrapper class on top of the Keras-native data generators. These can be either generator functions, keras.utils.Sequence objects, or Keras-specific data generators (keras.preprocessing.image.ImageDataGenerator).
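Since a plain generator function is one of the accepted iterator types, the sketch below shows one that yields (x, y) minibatches from numpy arrays and restarts at the end of each pass, so it never raises StopIteration. The function name numpy_batch_generator and its shuffle and seed parameters are hypothetical. With ART installed, the resulting generator object would be passed as the iterator argument, for example KerasDataGenerator(iterator=gen, size=len(x), batch_size=4).

```python
from typing import Iterator, Tuple

import numpy as np


def numpy_batch_generator(
    x: np.ndarray, y: np.ndarray, batch_size: int, shuffle: bool = True, seed: int = 0
) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
    """Yield (x, y) minibatches forever, optionally reshuffling before every pass."""
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        order = rng.permutation(n) if shuffle else np.arange(n)
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield x[idx], y[idx]  # same indices keep inputs and labels aligned


x = np.arange(20, dtype=np.float32).reshape(10, 2)
y = np.arange(10)
gen = numpy_batch_generator(x, y, batch_size=4, shuffle=False)
for _ in range(4):
    xb, yb = next(gen)
    print(xb.shape, yb.tolist())
```

Indexing x and y with the same permutation is what keeps each sample paired with its label when shuffling is enabled.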

get_batch() → tuple

Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.

Returns:

A tuple containing a batch of data (x, y).

class art.data_generators.MXDataGenerator(iterator: mxnet.gluon.data.DataLoader, size: int, batch_size: int)

Wrapper class on top of the MXNet/Gluon native data loader mxnet.gluon.data.DataLoader.

get_batch() → tuple

Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.

Returns:

A tuple containing a batch of data (x, y).

class art.data_generators.NumpyDataGenerator(x: ndarray, y: ndarray, batch_size: int = 1, drop_remainder: bool = True, shuffle: bool = False)

Simple numpy data generator backed by numpy arrays.

This can be useful for feeding numpy data to estimators in other frameworks, e.g., when converting the entire numpy dataset to GPU tensors at once would cause an out-of-memory (OOM) error.

get_batch() → tuple
Provide the next batch for training in the form of a tuple (x, y).

The generator will loop over the data indefinitely. If drop_remainder is True, the last incomplete minibatch of each epoch is omitted; if it is False, the last minibatch in each epoch may be smaller than batch_size.

Returns:

A tuple containing a batch of data (x, y).
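The effect of drop_remainder on the minibatches produced in one epoch can be sketched with simple arithmetic. epoch_batch_sizes is a hypothetical helper that reproduces the semantics described above; it is not ART's implementation. With ART installed, the corresponding generator would be created as NumpyDataGenerator(x, y, batch_size=4, drop_remainder=False).

```python
import math
from typing import List


def epoch_batch_sizes(n_samples: int, batch_size: int, drop_remainder: bool) -> List[int]:
    """Sizes of the minibatches produced in one pass over n_samples samples."""
    n_full = n_samples // batch_size    # number of complete minibatches
    remainder = n_samples % batch_size  # samples left for a final partial minibatch
    sizes = [batch_size] * n_full
    if remainder and not drop_remainder:
        sizes.append(remainder)         # keep the smaller last minibatch
    return sizes


for drop in (True, False):
    print(f"drop_remainder={drop}: {epoch_batch_sizes(10, 4, drop)}")
# Without dropping, the number of minibatches per epoch is ceil(n / batch_size).
print(len(epoch_batch_sizes(10, 4, False)) == math.ceil(10 / 4))
```

With 10 samples and a batch size of 4, dropping the remainder yields [4, 4] per epoch (2 samples unused in that pass), while keeping it yields [4, 4, 2].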

class art.data_generators.PyTorchDataGenerator(iterator: torch.utils.data.DataLoader, size: int, batch_size: int)

Wrapper class on top of the PyTorch native data loader torch.utils.data.DataLoader.

get_batch() → tuple

Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.

Returns:

A tuple containing a batch of data (x, y).

Return type:

tuple
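A DataLoader is exhausted after one pass, so a get_batch that must loop indefinitely has to restart iteration when a pass ends. The sketch below shows one way to do that; InfiniteBatches is a hypothetical helper, not ART's implementation, and a plain list of (x, y) pairs stands in for torch.utils.data.DataLoader so the example runs without PyTorch.

```python
from typing import Any, Iterable, Iterator, Tuple


class InfiniteBatches:
    """Hypothetical helper: serves (x, y) batches from a re-iterable loader forever."""

    def __init__(self, loader: Iterable[Tuple[Any, Any]]) -> None:
        self._loader = loader
        self._it: Iterator[Tuple[Any, Any]] = iter(loader)

    def get_batch(self) -> tuple:
        try:
            x, y = next(self._it)
        except StopIteration:
            # One pass is over: request a fresh iterator from the loader. For a
            # DataLoader created with shuffle=True this starts a newly shuffled pass.
            # (An empty loader would raise StopIteration again here.)
            self._it = iter(self._loader)
            x, y = next(self._it)
        return x, y


# A list of (x, y) pairs stands in for torch.utils.data.DataLoader here.
loader = [([1, 2], [0, 1]), ([3], [1])]
batches = InfiniteBatches(loader)
print([batches.get_batch()[0] for _ in range(5)])  # [[1, 2], [3], [1, 2], [3], [1, 2]]
```

Calling iter() again, rather than using itertools.cycle, matters here: cycle replays the items cached from the first pass, so a shuffling DataLoader would never be reshuffled.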

class art.data_generators.TensorFlowDataGenerator(sess: tf.Session, iterator: tf.data.Iterator, iterator_type: str, iterator_arg: Dict | Tuple | tf.Operation, size: int, batch_size: int)

Wrapper class on top of the TensorFlow native iterator tf.data.Iterator.

get_batch() → tuple

Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.

Returns:

A tuple containing a batch of data (x, y).

Raises:

ValueError – If the iterator has reached the end.

class art.data_generators.TensorFlowV2DataGenerator(iterator: tf.data.Dataset, size: int, batch_size: int)

Wrapper class on top of the TensorFlow v2 native iterator tf.data.Iterator.

get_batch() → tuple

Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.

Returns:

A tuple containing a batch of data (x, y).

Raises:

ValueError – If the iterator has reached the end.