DataLoaders
A DataLoader is a crucial utility that wraps a Dataset and provides an iterable over it. Its primary responsibilities are to handle batching, shuffling, and multi-process data loading, ensuring that the GPU is fed with data efficiently without becoming a bottleneck.
While LibTorch provides a basic data loading API (torch::data::DataLoader), it can be complex to use and lacks some of the convenient features found in Python's torch.utils.data.DataLoader.
xTorch simplifies and enhances this process with its own high-performance implementations.
xt::dataloaders::ExtendedDataLoader
The ExtendedDataLoader is the primary, high-level data loader in xTorch. It is designed to be both easy to use and highly performant, mirroring the functionality and simplicity of its Python counterpart.
It abstracts away the complexity of parallel data fetching and provides a simple for loop interface for iterating over batches of data.
Key Features
- Simple API: Requires minimal setup and configuration.
- Multi-Process Data Loading: Uses multiple worker processes to load data in parallel, preventing CPU bottlenecks.
- Automatic Batching: Combines individual data samples into batches.
- Optional Shuffling: Can automatically shuffle the data at the beginning of each epoch.
- Prefetching: Pre-fetches batches in the background to keep the GPU saturated.
Usage
The ExtendedDataLoader is typically initialized with a dataset object and configuration options. It can then be used in a range-based for loop to retrieve data batches.
#include <xtorch/xtorch.h>
int main() {
// 1. Assume 'dataset' is an initialized xt::datasets::Dataset object
auto dataset = xt::datasets::MNIST("./data", xt::datasets::DataMode::TRAIN);
// 2. Instantiate the ExtendedDataLoader
xt::dataloaders::ExtendedDataLoader data_loader(
dataset,
/*batch_size=*/64,
/*shuffle=*/true,
/*num_workers=*/4,
/*prefetch_factor=*/2
);
// 3. Iterate over the data loader to get batches
torch::Device device(torch::kCUDA);
int batch_count = 0;
for (auto& batch : data_loader) {
// Each 'batch' is a pair of (data, target) tensors
torch::Tensor data = batch.first.to(device);
torch::Tensor target = batch.second.to(device);
if (batch_count == 0) {
std::cout << "Batch Data Shape: " << data.sizes() << std::endl;
std::cout << "Batch Target Shape: " << target.sizes() << std::endl;
}
batch_count++;
}
std::cout << "Total batches: " << batch_count << std::endl;
}Constructor Parameters
The ExtendedDataLoader is configured through its constructor:
ExtendedDataLoader(Dataset& dataset, size_t batch_size, bool shuffle = false, int num_workers = 0, int prefetch_factor = 2)
| Parameter | Type | Description |
|---|---|---|
dataset |
xt::datasets::Dataset& |
The dataset from which to load the data. |
batch_size |
size_t |
The number of samples per batch. |
shuffle |
bool |
If true, the data is reshuffled at every epoch. Defaults to false. |
num_workers |
int |
The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. Defaults to 0. |
prefetch_factor |
int |
The number of batches to prefetch in advance for each worker. This helps hide data loading latency. Defaults to 2. |
Integration with xt::Trainer
The ExtendedDataLoader is designed to work seamlessly with the xt::Trainer. You simply pass your initialized data loader instances to the trainer.fit() method.
// Assume model, optimizer, train_loader, and val_loader are initialized
xt::Trainer trainer;
trainer.set_max_epochs(10)
.set_optimizer(optimizer)
.set_loss_fn(torch::nll_loss);
// The trainer will automatically iterate over the data loaders
trainer.fit(model, train_loader, &val_loader, device);For most use cases, the xt::dataloaders::ExtendedDataLoader is the recommended and only data loader you will need.
