
Load data on AI Runtime

Public Preview

AI Runtime for single-node tasks is in Public Preview. The distributed training API for multi-GPU workloads remains in Beta.

This section describes how to load data on AI Runtime for machine learning (ML) and deep learning (DL) applications. To learn more about loading and transforming data with the Spark Python API, see the tutorial.

note

Unity Catalog is required. All data access on AI Runtime goes through Unity Catalog. Your tables and volumes must be registered in Unity Catalog and accessible to your user or service principal.

Load tabular data

Use Spark Connect to load tabular machine learning data from Delta tables.

For single-node training, you can convert an Apache Spark DataFrame to a pandas DataFrame with the PySpark method toPandas(), and then optionally convert it to a NumPy array with the pandas method to_numpy().
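The following is a minimal sketch of the pandas-to-NumPy step. It has to run without a Spark session, so a small hand-built pandas DataFrame stands in for the output of toPandas(). The table name and the column names feature_a, feature_b, and label are hypothetical.

```python
import numpy as np
import pandas as pd

# On AI Runtime, pdf would typically come from Spark Connect, for example:
#   pdf = spark.table("catalog.schema.my_table").select("feature_a", "feature_b", "label").toPandas()
# Here a hypothetical in-memory DataFrame stands in for that result.
pdf = pd.DataFrame(
    {
        "feature_a": [0.5, 1.5, 2.5, 3.5],
        "feature_b": [10, 20, 30, 40],
        "label": [0, 1, 0, 1],
    }
)

feature_cols = ["feature_a", "feature_b"]

# pandas DataFrame.to_numpy() returns a 2-D array; dtype=np.float32 casts the
# mixed float/int feature columns to a single type that DL frameworks commonly expect
X = pdf[feature_cols].to_numpy(dtype=np.float32)
y = pdf["label"].to_numpy()

print(X.shape, X.dtype, y.shape)
```

Selecting only the columns you need before calling toPandas() reduces the amount of data that is collected into the pandas DataFrame.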

note

Spark Connect defers analysis and name resolution to execution time, which may change the behavior of your code. See Compare Spark Connect to Spark Classic.

Spark Connect supports most PySpark APIs, including Spark SQL, Pandas API on Spark, Structured Streaming, and MLlib (DataFrame-based). See the PySpark API reference documentation for the latest supported APIs.

For other limitations, see Serverless compute limitations.

Load large Delta tables using volumes

For large Delta tables that are too big to convert with toPandas(), export the data to a Unity Catalog volume and load it directly using PyTorch or Hugging Face:

Python
# Step 1: Export the Delta table to Parquet files in a UC volume
output_path = "/Volumes/catalog/schema/my_volume/training_data"
spark.table("catalog.schema.my_table").write.mode("overwrite").parquet(output_path)
Python
# Step 2: Load the exported data directly using Hugging Face datasets
from datasets import load_dataset

dataset = load_dataset("parquet", data_files="/Volumes/catalog/schema/my_volume/training_data/*.parquet")

This approach avoids Spark overhead during training and works well for both single-GPU and distributed training workflows.

Load unstructured data from volumes

For unstructured data such as images, audio, and text files, use Unity Catalog volumes. The following example shows how to read files from a volume and use them with a PyTorch Dataset:

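Python
# Read image files from a UC volume
import os

from PIL import Image
from torch.utils.data import Dataset

volume_path = "/Volumes/catalog/schema/my_volume/images/"


class ImageDataset(Dataset):
    def __init__(self, root_dir):
        # Keep regular files only; os.listdir also returns subdirectory names
        self.file_list = sorted(
            os.path.join(root_dir, f)
            for f in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, f))
        )

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        # Convert to RGB so grayscale and RGBA files share the same channel layout
        img = Image.open(self.file_list[idx]).convert("RGB")
        return img


dataset = ImageDataset(volume_path)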
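The default DataLoader collation cannot stack PIL images into a batch. A dataset that returns PIL images therefore needs a transform or a custom collate_fn before it can be batched. The following is a minimal sketch of a hypothetical collate function, pil_collate, that resizes each image to a common size and stacks the results into one float tensor using NumPy instead of torchvision.

```python
import numpy as np
import torch


def pil_collate(images, size=(224, 224)):
    """Resize a list of PIL images and stack them into a float tensor of shape (N, 3, H, W).

    size is (width, height), the order that PIL's Image.resize expects.
    """
    arrays = []
    for img in images:
        # Force 3 channels and a common size so the images can be stacked
        img = img.convert("RGB").resize(size)
        arr = np.asarray(img, dtype=np.float32) / 255.0  # (H, W, 3), values in [0, 1]
        arrays.append(arr.transpose(2, 0, 1))  # channels-first: (3, H, W)
    return torch.from_numpy(np.stack(arrays))
```

For example: DataLoader(ImageDataset(volume_path), batch_size=32, collate_fn=pil_collate). When torchvision is available, its transforms are the more common way to do this conversion.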

Load data inside the @distributed decorator

When you use the Serverless GPU API for distributed training, move data loading code inside the function decorated with @distributed. Data created outside the decorated function has to be pickled before it reaches the workers, and a large dataset can exceed the maximum size that pickle allows. Load or generate the dataset inside the decorated function instead, as shown below:

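Python
from serverless_gpu import distributed

# get_dataset and file_path are placeholders for your own loading code.

# Avoid: loading the dataset here, outside the decorated function,
# can cause a pickle error if the dataset is too large
dataset = get_dataset(file_path)

@distributed(gpus=8, gpu_type='H100')
def run_train():
    # Recommended: load data inside the decorated function to avoid pickle serialization issues
    dataset = get_dataset(file_path)
    ...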
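To show why the placement of get_dataset() matters, the following sketch uses the standard pickle module to compare the serialized size of an in-memory dataset with the serialized size of only the path that a worker could load it from. It does not use the serverless_gpu API. The 200,000-element list and the path are hypothetical stand-ins.

```python
import pickle

# Hypothetical stand-ins: an in-memory dataset and the volume path it was read from
dataset = [float(i) for i in range(200_000)]
file_path = "/Volumes/catalog/schema/my_volume/training_data"


def pickled_size(obj):
    """Return the number of bytes pickle produces for obj."""
    return len(pickle.dumps(obj))


# The pickled data grows with the dataset; the pickled path stays tiny,
# and each worker can call its own loader on file_path
print("dataset:", pickled_size(dataset), "bytes")
print("path:", pickled_size(file_path), "bytes")
```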

Data loading performance

The /Workspace and /Volumes directories are hosted on remote Unity Catalog storage. When your dataset is stored in Unity Catalog, data loading speed is limited by the available network bandwidth. For multi-epoch training, the recommended approach is to first copy the data to the local /tmp directory, which is backed by fast NVMe SSD storage.
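The following is a minimal sketch of the copy-to-/tmp step. It skips the copy when a complete local copy already exists, for example when a training cell is rerun. The cache_locally name and the .copy_complete marker file are hypothetical conventions and are not part of any Databricks API. The data is copied to a temporary sibling directory first, so an interrupted copy is never treated as complete.

```python
import os
import shutil


def cache_locally(src_dir, dst_dir):
    """Copy src_dir to dst_dir once and return dst_dir; later calls reuse the local copy."""
    marker = os.path.join(dst_dir, ".copy_complete")  # hypothetical marker-file convention
    if os.path.exists(marker):
        return dst_dir  # a complete copy already exists on this node
    # Copy into a temporary sibling directory, then rename it into place
    tmp_dir = dst_dir + ".partial"
    shutil.rmtree(tmp_dir, ignore_errors=True)
    shutil.copytree(src_dir, tmp_dir)
    open(os.path.join(tmp_dir, ".copy_complete"), "w").close()
    shutil.rmtree(dst_dir, ignore_errors=True)  # remove any earlier incomplete copy
    os.rename(tmp_dir, dst_dir)
    return dst_dir
```

For example, call local_path = cache_locally("/Volumes/catalog/schema/volume/dataset", "/tmp/dataset") and point your dataset at local_path. When several training processes run on one node, call the function from a single process (for example, local rank 0) and have the others wait, because concurrent copies to the same destination can conflict.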

If your dataset is large, the following techniques can improve performance:

  • Cache data locally for multi-epoch training. Copy datasets to /tmp for faster access across epochs:

    Python
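    import shutil

    # dirs_exist_ok=True (Python 3.8+) lets the copy be rerun when /tmp/dataset already exists
    shutil.copytree("/Volumes/catalog/schema/volume/dataset", "/tmp/dataset", dirs_exist_ok=True)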
  • Parallelize data fetching. Use the PyTorch DataLoader with multiple workers to overlap data loading with GPU computation. Set num_workers to at least 2. To improve performance further, increase num_workers (more parallel reads) or prefetch_factor (the number of batches each worker loads in advance):

    Python
    from torch.utils.data import DataLoader

    loader = DataLoader(
        dataset,
        batch_size=32,
        num_workers=2,
        prefetch_factor=2,
        pin_memory=True,
    )
  • Use Spark Connect for large tabular datasets. Spark Connect supports most PySpark APIs and handles distributed reads efficiently.
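The DataLoader settings above can be wrapped in a small helper. This is a sketch, and make_loader is a hypothetical name. The helper passes prefetch_factor and persistent_workers only when num_workers > 0, because PyTorch's DataLoader raises a ValueError for persistent_workers=True when num_workers is 0, and recent versions also reject any prefetch_factor in that case. It enables pin_memory only when a CUDA device is available. persistent_workers=True keeps the worker processes alive between epochs instead of restarting them for every epoch.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_loader(dataset, batch_size=32, num_workers=2, prefetch_factor=2):
    """Build a DataLoader, passing worker-only options only when worker processes are used."""
    kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory speeds up host-to-GPU copies; skip it on CPU-only machines
        pin_memory=torch.cuda.is_available(),
    )
    if num_workers > 0:
        # Each worker keeps up to prefetch_factor batches ready ahead of the training loop
        kwargs["prefetch_factor"] = prefetch_factor
        # Reuse worker processes across epochs
        kwargs["persistent_workers"] = True
    return DataLoader(dataset, **kwargs)


# Toy usage: 10 one-feature samples in batches of 4, loaded in the main process
toy = TensorDataset(torch.arange(10, dtype=torch.float32).unsqueeze(1))
loader = make_loader(toy, batch_size=4, num_workers=0)
for (batch,) in loader:
    print(batch.shape)
```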

Streaming datasets

For very large datasets that do not fit in memory, use streaming approaches: