art.data_generators
Module defining an interface for data generators and providing concrete implementations for the supported frameworks.
Their purpose is to allow for data loading and batching on the fly, as well as dynamic data augmentation.
The generators can be used with the fit_generator function in the Classifier interface. Users can define their own generators following the DataGenerator interface. For large, numpy array-based datasets, the NumpyDataGenerator class can be used with fit_generator on framework-specific classifiers.
Base Class
- class art.data_generators.DataGenerator(size: Optional[int], batch_size: int)
Base class for data generators.
- property batch_size: int
- Returns:
Return the batch size.
- abstract get_batch() → tuple
Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.
- Returns:
A tuple containing a batch of data (x, y).
- property iterator
- Returns:
Return the framework’s iterable data generator.
- property size: Optional[int]
- Returns:
Return the dataset size.
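As a minimal sketch of a user-defined generator, the following mirrors the documented interface (size, batch_size, get_batch) without importing ART. SketchDataGenerator is a hypothetical stand-in for DataGenerator, and NoisyListGenerator, its seed parameter, and the noise range are assumptions made for illustration. With ART installed, one would presumably subclass art.data_generators.DataGenerator instead.

```python
import random
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple


class SketchDataGenerator(ABC):
    """Hypothetical stand-in that mirrors the documented DataGenerator interface."""

    def __init__(self, size: Optional[int], batch_size: int) -> None:
        self._size = size
        self._batch_size = batch_size

    @property
    def size(self) -> Optional[int]:
        return self._size

    @property
    def batch_size(self) -> int:
        return self._batch_size

    @abstractmethod
    def get_batch(self) -> tuple:
        """Return the next batch as a tuple (x, y), looping indefinitely."""


class NoisyListGenerator(SketchDataGenerator):
    """Hypothetical generator: serves batches from lists and perturbs x on the fly."""

    def __init__(self, x: List[float], y: List[int], batch_size: int, seed: int = 0) -> None:
        super().__init__(size=len(x), batch_size=batch_size)
        self._x, self._y = x, y
        self._n = len(x)
        self._pos = 0
        self._rng = random.Random(seed)

    def get_batch(self) -> Tuple[List[float], List[int]]:
        # Wrap around at the end of the data so the generator never runs out.
        idx = [(self._pos + i) % self._n for i in range(self.batch_size)]
        self._pos = (self._pos + self.batch_size) % self._n
        # Dynamic augmentation: add small noise each time a sample is served.
        x_batch = [self._x[i] + self._rng.uniform(-0.01, 0.01) for i in idx]
        y_batch = [self._y[i] for i in idx]
        return x_batch, y_batch


gen = NoisyListGenerator(x=[0.0, 1.0, 2.0, 3.0, 4.0], y=[0, 1, 0, 1, 0], batch_size=2)
batches = [gen.get_batch() for _ in range(4)]
```

The third batch wraps from the last sample back to the first, which is one simple way to satisfy the requirement that the generator loop over the data indefinitely.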
Framework-Specific Data Generators
- class art.data_generators.KerasDataGenerator(iterator: Union[keras.utils.Sequence, tf.keras.utils.Sequence, keras.preprocessing.image.ImageDataGenerator, tf.keras.preprocessing.image.ImageDataGenerator, Generator], size: Optional[int], batch_size: int)
Wrapper class on top of the Keras-native data generators. These can either be generator functions, keras.utils.Sequence or Keras-specific data generators (keras.preprocessing.image.ImageDataGenerator).
- get_batch() → tuple
Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.
- Returns:
A tuple containing a batch of data (x, y).
- class art.data_generators.MXDataGenerator(iterator: mxnet.gluon.data.DataLoader, size: int, batch_size: int)
Wrapper class on top of the MXNet/Gluon native data loader mxnet.gluon.data.DataLoader.
- get_batch() → tuple
Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.
- Returns:
A tuple containing a batch of data (x, y).
- class art.data_generators.NumpyDataGenerator(x: ndarray, y: ndarray, batch_size: int = 1, drop_remainder: bool = True, shuffle: bool = False)
Simple numpy data generator backed by numpy arrays. Can be useful for applying numpy data to estimators in other frameworks, e.g., when converting the entire numpy dataset to GPU tensors would cause an out-of-memory (OOM) error.
- get_batch() → tuple
Provide the next batch for training in the form of a tuple (x, y). The generator will loop over the data indefinitely. If drop_remainder is False, the last minibatch in each epoch may be smaller than batch_size.
- Returns:
A tuple containing a batch of data (x, y).
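The effect of drop_remainder on the number and size of batches per epoch can be sketched as follows. This is an illustration rather than ART's implementation: it assumes, as the parameter name suggests, that drop_remainder=True omits the final incomplete minibatch, and it uses plain Python lists in place of numpy arrays. The helper names batches_per_epoch and epoch_batches are hypothetical.

```python
import math
from typing import List, Tuple


def batches_per_epoch(size: int, batch_size: int, drop_remainder: bool) -> int:
    """Number of batches in one pass over the data."""
    if drop_remainder:
        return size // batch_size        # incomplete final batch is omitted
    return math.ceil(size / batch_size)  # incomplete final batch is kept


def epoch_batches(
    x: List[int], y: List[int], batch_size: int, drop_remainder: bool
) -> List[Tuple[List[int], List[int]]]:
    """Slice x and y into consecutive batches for a single epoch."""
    n = batches_per_epoch(len(x), batch_size, drop_remainder)
    return [
        (x[i * batch_size:(i + 1) * batch_size], y[i * batch_size:(i + 1) * batch_size])
        for i in range(n)
    ]


x = list(range(10))
y = [v % 2 for v in x]
dropped = epoch_batches(x, y, batch_size=4, drop_remainder=True)
kept = epoch_batches(x, y, batch_size=4, drop_remainder=False)
print([len(bx) for bx, _ in dropped])  # [4, 4]
print([len(bx) for bx, _ in kept])     # [4, 4, 2]
```

With 10 samples and a batch size of 4, dropping the remainder leaves two full batches per epoch, while keeping it adds a third batch of size 2.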
- class art.data_generators.PyTorchDataGenerator(iterator: torch.utils.data.DataLoader, size: int, batch_size: int)
Wrapper class on top of the PyTorch native data loader torch.utils.data.DataLoader.
- get_batch() → tuple
Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.
- Returns:
A tuple containing a batch of data (x, y).
- Return type:
tuple
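One common way for a wrapper to loop indefinitely over a finite loader is to restart iteration whenever a pass ends. The sketch below shows that pattern in plain Python. LoopingBatches is a hypothetical class, not part of ART, and a list of (x, y) tuples stands in for a real data loader. Whether the ART wrappers use exactly this mechanism is an assumption.

```python
from typing import Any, Iterable, Iterator


class LoopingBatches:
    """Hypothetical wrapper: turns a finite, re-iterable loader into an endless batch source."""

    def __init__(self, loader: Iterable[Any]) -> None:
        self._loader = loader
        self._current: Iterator[Any] = iter(loader)

    def get_batch(self) -> Any:
        try:
            return next(self._current)
        except StopIteration:
            # One pass is finished: start a fresh pass and serve its first batch.
            self._current = iter(self._loader)
            return next(self._current)


# A list of (x, y) batches stands in for a framework data loader.
loader = [([1, 2], [0, 1]), ([3, 4], [1, 0]), ([5], [1])]
source = LoopingBatches(loader)
served = [source.get_batch() for _ in range(5)]
```

After the three batches of the first pass, the fourth call returns the first batch again, so callers never see the end of the data.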
- class art.data_generators.TensorFlowDataGenerator(sess: tf.Session, iterator: tf.data.Iterator, iterator_type: str, iterator_arg: Union[Dict, Tuple, tf.Operation], size: int, batch_size: int)
Wrapper class on top of the TensorFlow native iterators tf.data.Iterator.
- get_batch() → tuple
Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.
- Returns:
A tuple containing a batch of data (x, y).
- Raises:
ValueError – If the iterator has reached the end.
- class art.data_generators.TensorFlowV2DataGenerator(iterator: tf.data.Dataset, size: int, batch_size: int)
Wrapper class on top of the TensorFlow v2 native iterators tf.data.Iterator.
- get_batch() → tuple
Provide the next batch for training in the form of a tuple (x, y). The generator should loop over the data indefinitely.
- Returns:
A tuple containing a batch of data (x, y).
- Raises:
ValueError – If the iterator has reached the end.
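Because both TensorFlow wrappers document a ValueError when the iterator reaches its end, a consuming loop may want to bound the number of get_batch calls and handle exhaustion explicitly. The sketch below illustrates this with a hypothetical FiniteGenerator stand-in that needs no TensorFlow. The consume function and the choice of size // batch_size steps per epoch are assumptions, not ART's documented training loop.

```python
from typing import Iterator, List, Tuple

Batch = Tuple[List[float], List[int]]


class FiniteGenerator:
    """Hypothetical stand-in whose get_batch raises ValueError when exhausted."""

    def __init__(self, batches: List[Batch], batch_size: int) -> None:
        self.size = sum(len(x) for x, _ in batches)
        self.batch_size = batch_size
        self._it: Iterator[Batch] = iter(batches)

    def get_batch(self) -> Batch:
        try:
            return next(self._it)
        except StopIteration:
            raise ValueError("Reached the end of the iterator.") from None


def consume(generator: FiniteGenerator, nb_epochs: int) -> int:
    """Pull batches for nb_epochs and return the number of samples seen."""
    steps = generator.size // generator.batch_size
    seen = 0
    for _ in range(nb_epochs):
        for _ in range(steps):
            try:
                x, y = generator.get_batch()
            except ValueError:
                return seen  # iterator exhausted: stop instead of crashing
            seen += len(x)
    return seen


data = [([0.1, 0.2], [0, 1]), ([0.3, 0.4], [1, 0])]
one_epoch = consume(FiniteGenerator(data, batch_size=2), nb_epochs=1)
two_epochs = consume(FiniteGenerator(data, batch_size=2), nb_epochs=2)
```

Here the underlying data holds only one epoch, so requesting two epochs stops after the first four samples rather than propagating the error. With a real dataset, configuring it to repeat would be the more usual way to avoid exhaustion.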