
Get started: Serverless GPU compute with A10 GPUs

This notebook demonstrates how to use serverless GPU compute to run GPU workloads on A10 GPUs directly from Databricks notebooks. You'll learn how to use the serverless_gpu Python library to execute functions on single or multiple GPUs for distributed training.

Serverless GPU compute provides on-demand access to GPU resources without managing clusters. The serverless_gpu library enables seamless execution of GPU workloads with automatic resource provisioning. To learn more, see the Serverless GPU API documentation.

Requirements

Before running this notebook, connect it to Serverless GPU compute:

  1. From the compute selector, select Serverless GPU.
  2. In the Environment tab on the right side, select A10 as the Accelerator.

Verify GPU connection

Run the nvidia-smi command to confirm that your notebook is connected to an A10 GPU and view GPU specifications.

Python
%sh nvidia-smi
Output
Wed Jan 14 19:31:33 2026
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.03             Driver Version: 550.144.03     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A10G                    On  |   00000000:00:1E.0 Off |                    0 |
|  0%   22C    P8             23W /  300W |      1MiB /  23028MiB  |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

Import the serverless GPU library

Import the serverless_gpu library to access the API for running functions on GPU resources.

Python
import serverless_gpu
Output
Warning: serverless_gpu is in Beta. The API is subject to change.

Create a distributed function

Use the @distributed decorator to create a DistributedFunction that runs on GPU resources. This decorator accepts the following parameters:

  • gpus (int): Number of GPUs to use
  • gpu_type (Optional[Union[GPUType, str]]): The GPU type to use (required if remote=True). Available types: 'a10', 'h100'
  • remote (bool): Whether to run the function on remote GPUs (defaults to False)
  • run_async (bool): Whether to run the function asynchronously (defaults to False)
Python
from serverless_gpu import distributed

@distributed(gpus=1, gpu_type='a10', remote=True)
def foo(x):
    print('hello_world', x)
    return x


foo
Output
DistributedFunction(gpus=1, gpu_type=GPUType.A10, remote=True, run_async=False, func=foo)
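The @distributed decorator follows the standard Python decorator-factory pattern: calling distributed(...) returns a decorator that wraps your function in an object recording the configuration, rather than executing it immediately. To illustrate the shape only, here is a minimal, simulated stand-in (the FakeDistributedFunction and fake_distributed names below are hypothetical, not part of serverless_gpu, and run everything locally):

```python
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class FakeDistributedFunction:
    """Simulated stand-in for serverless_gpu's DistributedFunction."""
    gpus: int
    gpu_type: Optional[str]
    remote: bool
    run_async: bool
    func: Callable

    def distributed(self, **kwargs):
        # The real library dispatches the function to GPUs and collects
        # one result per GPU; here we simply call it once per "GPU".
        return [self.func(**kwargs) for _ in range(self.gpus)]

def fake_distributed(gpus, gpu_type=None, remote=False, run_async=False):
    """Decorator factory mirroring the @distributed parameters above."""
    def wrap(func):
        return FakeDistributedFunction(gpus, gpu_type, remote, run_async, func)
    return wrap

@fake_distributed(gpus=1, gpu_type='a10', remote=True)
def foo(x):
    return x

print(foo.distributed(x=5))  # → [5]
```

This is why the cell above prints a DistributedFunction object rather than running foo: the decorated name is a launchable handle, not the original function.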

Run the distributed function

Launch the DistributedFunction using the .distributed() method. Pass any required arguments as keyword parameters.

Python
foo.distributed(x=5)
Output
[5]

Distributed training with multiple GPUs

You can launch multiple A10 GPUs in parallel for distributed training workloads. The serverless_gpu.runtime module provides helper functions to manage distributed execution:

  • get_local_rank(): Get the local rank of the current GPU
  • get_global_rank(): Get the global rank across all GPUs
  • get_world_size(): Get the total number of GPUs

Note: Multi-node runs of up to 70 nodes can take up to 20 minutes to start, and startup time grows as more nodes are added. For larger runs, expect longer wait times and occasional failures.

Python
# The runtime module includes helpers to be used during the GPU runtime (i.e. in the function body).
# These helpers include get_local_rank, get_global_rank, get_world_size
from serverless_gpu import runtime as rt

@distributed(gpus=3, gpu_type='a10', remote=True)
def multi_a10():
    return rt.get_global_rank(), rt.get_world_size()


multi_a10.distributed() # returns a list, one element per GPU

Output
[(0, 3), (1, 3), (2, 3)]
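Inside the function body, the rank and world-size values are typically used to split work so each GPU processes a different slice of the data. The sharding arithmetic itself is plain Python; the shard helper below is a hedged sketch (hypothetical, not part of serverless_gpu), which inside a decorated function would be called with rt.get_global_rank() and rt.get_world_size():

```python
def shard(items, rank, world_size):
    """Return the contiguous slice of `items` owned by this rank.

    Earlier ranks receive one extra item when len(items) does not
    divide evenly by world_size.
    """
    base, extra = divmod(len(items), world_size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return items[start:end]

# With world_size=3 (matching the multi_a10 run above), 8 items
# split as 3 + 3 + 2 across ranks 0, 1, and 2:
data = list(range(8))
print([shard(data, r, 3) for r in range(3)])
# → [[0, 1, 2], [3, 4, 5], [6, 7]]
```

Contiguous sharding keeps each rank's slice cache-friendly and makes it easy to reassemble results in order, since concatenating the shards by rank reproduces the original list.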

Next steps

To learn more about serverless GPU compute, see the Serverless GPU API documentation.
