
Distributed XGBoost with GPUs using Ray Tune on Serverless GPU Compute

This notebook demonstrates end-to-end distributed XGBoost training with hyperparameter optimization using Ray Tune on Databricks Serverless GPU Compute. It covers:

  1. Synthetic Dataset Generation: Creating a large-scale synthetic dataset using dbldatagen
  2. Distributed Training: Using Ray Train with XGBoost for distributed data parallelism
  3. Hyperparameter Optimization: Leveraging Ray Tune with Optuna for automated hyperparameter search
  4. MLflow Integration: Logging experiments, metrics, and registering models to Unity Catalog

This demo uses 30M rows x 100 feature columns x 1 target column (2 classes) for binary classification. At float32 precision this dataset is ~12GB in memory and provides an excellent foundation for distributed training experiments.

Serverless GPU Compute Benefits

  • On-Demand Scaling: Automatically provision and scale GPU resources based on workload demands
  • Cost Optimization: Pay only for compute time used, with automatic resource cleanup
  • No Infrastructure Management: Focus on ML training without managing underlying hardware
  • High Availability: Built-in fault tolerance and automatic failover capabilities

FAQs

When do I switch to a distributed version of XGBoost?

  • Large XGBoost datasets should use distributed data parallelism (DDP). This example uses 30M rows for demonstration purposes.
  • Escalate in stages: start with single-node training multi-threaded across all CPUs, then move to DDP across multiple CPU nodes, and finally to DDP across multiple GPUs.

If I'm using GPUs, how much memory (VRAM) do I need for my dataset?

  • 30M rows x 100 columns x 4 bytes (float32) = ~12GB
  • Plan for a total of 2-4x the data footprint in VRAM across GPUs (at 2x, ~24GB) to train the model
  • This extra memory accounts for boosting rounds, model size, gradients, and intermediate computations
  • A10G GPUs (24GB VRAM each) are a good fit for this workload: 1-2 GPUs per model
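The back-of-envelope math above can be checked in a few lines (pure arithmetic; the 2-4x multiplier is the rule of thumb from this FAQ, not a hard limit):

```python
import math

# In-memory footprint of the training matrix (float32 = 4 bytes per value)
rows, cols, bytes_per_value = 30_000_000, 100, 4
data_gb = rows * cols * bytes_per_value / 1e9
print(f"Data footprint: ~{data_gb:.0f} GB")  # ~12 GB

# Rule of thumb: budget 2-4x the data footprint in total VRAM
min_vram_gb, max_vram_gb = 2 * data_gb, 4 * data_gb
print(f"VRAM budget: ~{min_vram_gb:.0f}-{max_vram_gb:.0f} GB across GPUs")  # ~24-48 GB

# With 24 GB A10G GPUs, the low end of the budget fits on a single GPU
a10g_vram_gb = 24
print(f"A10Gs needed at 2x: {math.ceil(min_vram_gb / a10g_vram_gb)}")  # 1
```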

Serverless GPU Compute Specifications

Databricks Serverless GPU Compute:

  • GPU Types: NVIDIA A10G (24GB VRAM), H100 (80GB VRAM)
  • Auto-scaling: Automatically scales based on workload demands
  • Billing: Pay-per-second billing with automatic resource cleanup
  • Availability: Multi-region support with high availability
  • Integration: Seamless integration with Unity Catalog and MLflow

Recommended Configuration:

  • Workers: 2-4 Ray workers for optimal performance
  • GPU Allocation: 1 A10G GPU per worker (24GB VRAM)
  • Memory: 32GB RAM per worker for data preprocessing
  • Storage: Unity Catalog integration for data access

Connect to serverless GPU compute and install dependencies

Connect your notebook to serverless A10 GPU:

  1. Click the Connect dropdown at the top.
  2. Select Serverless GPU.
  3. Open the Environment side panel on the right side of the notebook.
  4. Set Accelerator to A10 for this demo.
  5. Select Apply and click Confirm to apply this environment to your notebook.
Python
%pip install -qU dbldatagen "ray[all]==2.48.0" xgboost optuna "mlflow>=3.0"
# Note: serverless_gpu package should be pre-installed in the Serverless GPU environment
# If not available, install it using: %pip install databricks-serverless-gpu
dbutils.library.restartPython()

Use widgets to set your Unity Catalog paths and training parameters

Create widgets for the Unity Catalog path and training configuration. The code below uses placeholder values that you can customize.

Python
# Define job inputs
dbutils.widgets.text("catalog", "main", "Unity Catalog Name")
dbutils.widgets.text("schema", "dev", "Unity Catalog Schema Name")
dbutils.widgets.text("num_training_rows", "30000000", "Number of training rows to generate")
dbutils.widgets.text("num_training_columns", "100", "Number of feature columns")
dbutils.widgets.text("num_labels", "2", "Number of labels in the target column")
dbutils.widgets.text("warehouse_id", "93a682dcf60dae13", "SQL Warehouse ID (optional, for reading from UC)")
dbutils.widgets.text("num_workers", "2", "Number of Ray workers per trial")
dbutils.widgets.text("num_hpo_trials", "8", "Number of hyperparameter optimization trials")
dbutils.widgets.text("max_concurrent_trials", "4", "Maximum concurrent HPO trials")

# Get parameter values (will override widget defaults if run by job)
UC_CATALOG = dbutils.widgets.get("catalog")
UC_SCHEMA = dbutils.widgets.get("schema")
NUM_TRAINING_ROWS = int(dbutils.widgets.get("num_training_rows"))
NUM_TRAINING_COLUMNS = int(dbutils.widgets.get("num_training_columns"))
NUM_LABELS = int(dbutils.widgets.get("num_labels"))
WAREHOUSE_ID = dbutils.widgets.get("warehouse_id")
NUM_WORKERS = int(dbutils.widgets.get("num_workers"))
NUM_HPO_TRIALS = int(dbutils.widgets.get("num_hpo_trials"))
MAX_CONCURRENT_TRIALS = int(dbutils.widgets.get("max_concurrent_trials"))

print(f"UC_CATALOG: {UC_CATALOG}")
print(f"UC_SCHEMA: {UC_SCHEMA}")
print(f"NUM_TRAINING_ROWS: {NUM_TRAINING_ROWS}")
print(f"NUM_TRAINING_COLUMNS: {NUM_TRAINING_COLUMNS}")
print(f"NUM_LABELS: {NUM_LABELS}")
print(f"NUM_WORKERS: {NUM_WORKERS}")
print(f"NUM_HPO_TRIALS: {NUM_HPO_TRIALS}")

Step 1: Generate Synthetic Dataset

We'll create a synthetic dataset using dbldatagen for demonstration purposes. In production, you would use your actual training data.

Python
import dbldatagen as dg
from pyspark.sql.types import FloatType, IntegerType
import os
from databricks.sdk import WorkspaceClient
from databricks.sdk.service import catalog

# Create schema if it doesn't exist
w = WorkspaceClient()
try:
    created_schema = w.schemas.create(name=UC_SCHEMA, catalog_name=UC_CATALOG)
    print(f"Schema '{created_schema.name}' created successfully")
except Exception as e:
    # Handle the case where the schema already exists
    print(f"Schema '{UC_SCHEMA}' already exists in catalog '{UC_CATALOG}'. Skipping schema creation.")

# Create volume for data storage
parquet_write_path = f'/Volumes/{UC_CATALOG}/{UC_SCHEMA}/synthetic_data'
if not os.path.exists(parquet_write_path):
    created_volume = w.volumes.create(
        catalog_name=UC_CATALOG,
        schema_name=UC_SCHEMA,
        name='synthetic_data',
        volume_type=catalog.VolumeType.MANAGED
    )
    print(f"Volume 'synthetic_data' at {parquet_write_path} created successfully")
else:
    print(f"Volume {parquet_write_path} already exists. Skipping volume creation.")

# Create volume for Ray storage
ray_storage_path = f'/Volumes/{UC_CATALOG}/{UC_SCHEMA}/ray_data_tmp_dir'
if not os.path.exists(ray_storage_path):
    created_volume = w.volumes.create(
        catalog_name=UC_CATALOG,
        schema_name=UC_SCHEMA,
        name='ray_data_tmp_dir',
        volume_type=catalog.VolumeType.MANAGED
    )
    print(f"Volume 'ray_data_tmp_dir' at {ray_storage_path} created successfully")
else:
    print(f"Volume {ray_storage_path} already exists. Skipping volume creation.")
Python
# Generate synthetic dataset
table_name = f"synthetic_data_{NUM_TRAINING_ROWS}_rows_{NUM_TRAINING_COLUMNS}_columns_{NUM_LABELS}_labels"
label = "target"

print(f"Generating {NUM_TRAINING_ROWS} synthetic rows")
print(f"Generating {NUM_TRAINING_COLUMNS} synthetic columns")
print(f"Generating {NUM_LABELS} synthetic labels")

testDataSpec = (
    dg.DataGenerator(spark, name="synthetic_data", rows=NUM_TRAINING_ROWS)
    .withIdOutput()
    .withColumn(
        "r",
        FloatType(),
        expr="rand()",
        numColumns=NUM_TRAINING_COLUMNS,
    )
    .withColumn(
        "target",
        IntegerType(),
        expr=f"floor(rand()*{NUM_LABELS})",
        numColumns=1,
    )
)

df = testDataSpec.build()
df = df.repartition(50)

# Write to Delta table
df.write.format("delta").mode("overwrite").option("delta.enableDeletionVectors", "true").saveAsTable(
    f"{UC_CATALOG}.{UC_SCHEMA}.{table_name}"
)

# Write to Parquet for Ray dataset reading (backup option)
df.write.mode("overwrite").format("parquet").save(
    f"{parquet_write_path}/{table_name}"
)

print(f"Dataset created successfully: {UC_CATALOG}.{UC_SCHEMA}.{table_name}")

Step 2: Set Up Ray Dataset and Training Functions

Configure Ray to read from Unity Catalog and set up the distributed training functions.

Python
import ray
import os

# Set up environment variables for MLflow
os.environ['DATABRICKS_HOST'] = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().get()

def read_ray_dataset(catalog, schema, table, warehouse_id=None):
    """
    Read data from Unity Catalog into a Ray Dataset.

    Args:
        catalog: Unity Catalog name
        schema: Schema name
        table: Table name
        warehouse_id: Optional SQL Warehouse ID for reading from UC

    Returns:
        train_dataset, val_dataset: Ray Datasets for training and validation
    """
    try:
        ## Option 1 (PREFERRED): Build a Ray Dataset using a Databricks SQL Warehouse
        if warehouse_id and warehouse_id.strip():
            ds = ray.data.read_databricks_tables(
                warehouse_id=warehouse_id,
                catalog=catalog,
                schema=schema,
                query=f'SELECT * FROM {catalog}.{schema}.{table}',
            )
            print('Read directly from Unity Catalog using SQL Warehouse')
        else:
            raise ValueError("Warehouse ID not provided - falling back to Parquet")
    except Exception as e:
        print(f"Note: {e}")
        ## Option 2: Build a Ray Dataset using Parquet files
        # If you have too many Ray nodes, you may not be able to create a Ray dataset
        # using the warehouse method above because of rate limits. One backup solution
        # is to create parquet files from the delta table and build a ray dataset from that.
        parquet_path = f'/Volumes/{catalog}/{schema}/synthetic_data/{table}'
        ds = ray.data.read_parquet(parquet_path)
        print('Read directly from Parquet files')

    train_dataset, val_dataset = ds.train_test_split(test_size=0.25)
    return train_dataset, val_dataset

Step 3: Define XGBoost Training Functions

Define the per-worker training function and the driver function that orchestrates distributed training.

Python
import xgboost
import ray.train
from ray.train.xgboost import XGBoostTrainer, RayTrainReportCallback

def train_fn_per_worker(params: dict):
    """
    Trains an XGBoost model on the shard of the distributed dataset assigned to this worker.

    This function is designed to be executed by individual Ray Train workers.
    It retrieves the training and validation data shards, converts them to DMatrix format,
    and performs a portion of the distributed XGBoost training. Ray Train handles
    the inter-worker communication.

    Args:
        params (dict): A dictionary of XGBoost training parameters, including
            'num_estimators', 'eval_metric', and potentially other
            XGBoost-specific parameters.
    """
    # Get dataset shards for this worker
    train_shard = ray.train.get_dataset_shard("train")
    val_shard = ray.train.get_dataset_shard("val")

    # Convert shards to pandas DataFrames
    train_df = train_shard.materialize().to_pandas()
    val_df = val_shard.materialize().to_pandas()

    train_X = train_df.drop(label, axis=1)
    train_y = train_df[label]
    val_X = val_df.drop(label, axis=1)
    val_y = val_df[label]

    dtrain = xgboost.DMatrix(train_X, label=train_y)
    deval = xgboost.DMatrix(val_X, label=val_y)

    # Do distributed data-parallel training.
    # Ray Train sets up the necessary coordinator processes and
    # environment variables for workers to communicate with each other.
    evals_results = {}
    bst = xgboost.train(
        params,
        dtrain=dtrain,
        evals=[(deval, "validation")],
        num_boost_round=params['num_estimators'],
        evals_result=evals_results,
        callbacks=[RayTrainReportCallback(
            metrics={params['eval_metric']: f"validation-{params['eval_metric']}"},
            frequency=1
        )],
    )

Python
def train_driver_fn(config: dict, train_dataset, val_dataset, ray_storage_path: str):
    """
    Drives the distributed XGBoost training process using Ray Train.

    This function sets up the XGBoostTrainer, configures scaling (number of workers,
    GPU usage, and resources per worker), and initiates the distributed training by
    calling `trainer.fit()`. It also propagates metrics back to Ray Tune if integrated.

    Args:
        config (dict): A dictionary containing run-level hyperparameters such as
            'num_workers', 'use_gpu', and a nested 'params' dictionary
            for XGBoost training parameters.
        train_dataset: The Ray Dataset for training.
        val_dataset: The Ray Dataset for validation.
        ray_storage_path: Path for Ray storage.

    Returns:
        None: The function reports metrics to Ray Tune but does not explicitly return
        a value. The trained model artifact is handled by Ray Train's checkpointing.
    """
    # Unpack run-level hyperparameters.
    num_workers = config["num_workers"]
    use_gpu = config["use_gpu"]
    params = config['params']

    # Initialize the XGBoostTrainer, which orchestrates the distributed training using Ray.
    trainer = XGBoostTrainer(
        train_loop_per_worker=train_fn_per_worker,  # The function executed on each worker
        train_loop_config=params,
        # By default Ray uses 1 GPU and 1 CPU per worker if resources_per_worker is not specified.
        # XGBoost is multi-threaded, so multiple CPUs can be assigned per worker, but not GPUs.
        scaling_config=ray.train.ScalingConfig(
            num_workers=num_workers,
            use_gpu=use_gpu,
            resources_per_worker={"CPU": 12, "GPU": 1}
        ),
        datasets={"train": train_dataset, "val": val_dataset},  # Ray Datasets used by the trainer + workers
        run_config=ray.train.RunConfig(storage_path=ray_storage_path)
    )

    result = trainer.fit()

    # Propagate metrics back up for Ray Tune.
    # Ensure the metric key matches your eval_metric.
    ray.tune.report(
        {params['eval_metric']: result.metrics.get('mlogloss', result.metrics.get('validation-mlogloss', 0.0))},
        checkpoint=result.checkpoint
    )

Step 4: Hyperparameter Optimization with Ray Tune

Use Ray Tune with Optuna for automated hyperparameter search, integrated with MLflow for experiment tracking.

Python
import mlflow
from ray import tune
from ray.tune.tuner import Tuner
from ray.tune.search.optuna import OptunaSearch
from ray.air.integrations.mlflow import MLflowLoggerCallback

# Import serverless GPU launcher
try:
    from serverless_gpu.ray import ray_launch
except ImportError:
    raise ImportError(
        "serverless_gpu package not found. Please ensure you're running on Serverless GPU compute "
        "or install it using: %pip install databricks-serverless-gpu"
    )

# Set up MLflow experiment
username = spark.sql("SELECT current_user()").collect()[0][0]
notebook_name = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get().split("/")[-1]
experiment_name = f"/Users/{username}/{notebook_name}"

# Configure MLflow to use Unity Catalog
mlflow.set_registry_uri("databricks-uc")

@ray.remote(num_cpus=1)  # Ensure the main task is not scheduled on the head node
class TaskRunner:
    def run(self):
        # Load dataset as distributed Ray Dataset
        train_dataset, val_dataset = read_ray_dataset(
            UC_CATALOG, UC_SCHEMA, table_name, WAREHOUSE_ID if WAREHOUSE_ID and WAREHOUSE_ID.strip() else None
        )

        # Define the hyperparameter search space.
        param_space = {
            "num_workers": NUM_WORKERS,
            "use_gpu": True,
            "params": {
                "objective": "multi:softmax",
                "eval_metric": "mlogloss",
                "tree_method": "hist",
                "device": "cuda",
                "num_class": NUM_LABELS,
                "learning_rate": tune.uniform(0.01, 0.3),
                "num_estimators": tune.randint(20, 30)
            }
        }

        # Set up the search algorithm. Here we use Optuna with its default
        # Bayesian sampler (TPE: Tree-structured Parzen Estimator).
        optuna = OptunaSearch(
            metric=param_space['params']['eval_metric'],
            mode="min"
        )

        with mlflow.start_run() as run:
            # Set up Tuner job and run.
            tuner = tune.Tuner(
                tune.with_parameters(
                    train_driver_fn,
                    train_dataset=train_dataset,
                    val_dataset=val_dataset,
                    ray_storage_path=ray_storage_path
                ),
                run_config=tune.RunConfig(
                    name='xgboost_raytune_run',
                    storage_path=ray_storage_path,
                    callbacks=[MLflowLoggerCallback(
                        save_artifact=True,
                        tags={"mlflow.parentRunId": run.info.run_id},
                        log_params_on_trial_end=True
                    )]
                ),
                tune_config=tune.TuneConfig(
                    num_samples=NUM_HPO_TRIALS,
                    max_concurrent_trials=MAX_CONCURRENT_TRIALS,
                    search_alg=optuna,
                ),
                param_space=param_space,
            )

            results = tuner.fit()
            return results

Step 5: Launch Distributed Training

Use the distributed serverless GPU API to run distributed model training on remote A10 GPUs.

Python
@ray_launch(gpus=8, gpu_type='A10', remote=True)
def my_ray_function():
    runner = TaskRunner.remote()
    return ray.get(runner.run.remote())

results = my_ray_function.distributed()

Step 6: Retrieve Best Model and Log to MLflow

Extract the best model from hyperparameter optimization and register it to Unity Catalog.

Python
# Get the best result
results = results[0] if isinstance(results, list) else results
best_result = results.get_best_result(metric="mlogloss", mode="min")
best_params = best_result.config

print(f"Best hyperparameters: {best_params}")
print(f"Best validation mlogloss: {best_result.metrics.get('mlogloss', 'N/A')}")

# Load the best model
booster = RayTrainReportCallback.get_model(best_result.checkpoint)

Python

# Sample data for input example
sample_data = spark.read.table(f"{UC_CATALOG}.{UC_SCHEMA}.{table_name}").limit(5).toPandas()

with mlflow.start_run() as run:
    logged_model = mlflow.xgboost.log_model(
        booster,
        "model",
        input_example=sample_data[[col for col in sample_data.columns if col != label]]
    )
    print(f"Model logged to MLflow: {logged_model.model_uri}")


Step 7: Test Model Inference

Load the logged model and test inference on sample data.

Python

# Load the model
loaded_model = mlflow.pyfunc.load_model(logged_model.model_uri)

# Test inference
test_data = spark.read.table(f"{UC_CATALOG}.{UC_SCHEMA}.{table_name}").limit(10).toPandas()
predictions = loaded_model.predict(test_data[[col for col in test_data.columns if col != label]])

print("Sample predictions:")
print(predictions)
print("\nSample actual labels:")
print(test_data[label].values)

Next steps

You've successfully completed distributed XGBoost training with hyperparameter optimization using Ray Tune on Databricks Serverless GPU.

Customization options

  • Scale your dataset: Adjust num_training_rows and num_training_columns widgets for larger datasets
  • Tune hyperparameters: Modify the param_space in Step 4 to explore different XGBoost parameters
  • Increase GPU resources: Update gpus parameter in @ray_launch decorator for more compute power
  • Adjust workers: Change num_workers widget to scale training parallelism
  • Try different models: Replace XGBoost with other Ray Train-compatible frameworks

Cleanup

GPU resources are automatically cleaned up when the notebook disconnects. To manually disconnect:

  1. Click Connected in the compute dropdown
  2. Hover over Serverless
  3. Select Terminate from the dropdown menu
