YOLO11n object detection on Databricks AI Runtime
This notebook trains a YOLO11n object detection model end to end on Databricks Serverless GPU. You train the model on the COCO128 dataset, track the run with MLflow, register it to Unity Catalog, and optionally deploy it to a Model Serving endpoint.
The notebook prepares the dataset in Unity Catalog volumes and wraps the trained model in a custom MLflow PyFunc model that takes base64-encoded images and returns bounding boxes, so the same model works over a serving endpoint. It validates the model locally before deploying with the Databricks SDK and logging requests to AI Gateway inference tables.
⚠️ COCO128 is for demo only (128 images). The model will overfit. For production, use larger datasets (1K+ images). This workflow scales directly by updating the data paths. See NuInsSeg for a real-world example.
Reference: Train and deploy a YOLO vision model on Databricks AI Runtime (serverless GPU) (Databricks Technical Blog)
Connect to serverless GPU
- Click Serverless GPU (top-right connect button).
- Open the Environment panel in the right sidebar, then:
  - Set Accelerator: A10 (cost-efficient) or H100 (faster).
  - Set Base environment: AI v5 (check the release notes for updates).
  - Click Apply.
AI v5 pre-bundles mlflow>=3 (skinny), nvidia-ml-py, threadpoolctl, and torch, so only ultralytics needs to be installed. If the Model Serving steps fail with skinny MLflow, add %pip install mlflow back to the install cell. The Package Verification cell flags this case.
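The Package Verification cell confirms that mlflow or mlflow-skinny is installed, but it does not compare the installed version against mlflow>=3. The sketch below adds that check with importlib.metadata. It assumes the version string starts with a numeric major component. The helper names are hypothetical and are not part of this notebook.

```python
import importlib.metadata


def major_version(version: str) -> int:
    """Return the leading numeric component of a version string such as '3.1.4'."""
    head = version.split(".")[0]
    digits = ""
    # Keep only leading digits so pre-release strings like '3rc1' still parse
    for ch in head:
        if not ch.isdigit():
            break
        digits += ch
    return int(digits) if digits else 0


def installed_mlflow_version():
    """Return (distribution, version) for full or skinny MLflow, or None if neither is installed."""
    for dist in ("mlflow", "mlflow-skinny"):
        try:
            return dist, importlib.metadata.version(dist)
        except importlib.metadata.PackageNotFoundError:
            continue
    return None


found = installed_mlflow_version()
if found is None:
    print("mlflow: NOT INSTALLED")
elif major_version(found[1]) < 3:
    print(f"{found[0]} {found[1]} is older than 3.x; run %pip install 'mlflow>=3'")
else:
    print(f"{found[0]} {found[1]} satisfies mlflow>=3")
```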
Environment setup
Install required packages and configure Python environment for YOLO training on serverless GPU.
# ============================================================
# PACKAGE INSTALLATION — AI v5 (Serverless GPU, single GPU)
# ============================================================
# AI v5 pre-bundles mlflow>=3, nvidia-ml-py, threadpoolctl (see top cell).
# Only ultralytics needs to be installed.
%pip install ultralytics==8.3.204 -q
# Note: %pip automatically restarts the Python environment after install.
# YOLO_CONFIG_DIR is set in the library import cell that follows the package verification.
# ============================================================
# PACKAGE VERIFICATION
# ============================================================
import sys
import importlib.metadata
print("Checking required packages...\n")
missing_packages = []
installed_packages = {}
# Check each required package using importlib.metadata
# Check mlflow (full install) or mlflow-skinny (AI v5 pre-bundled) — either satisfies mlflow>=3.0
_mlflow_found = False
for _pkg in ['mlflow', 'mlflow-skinny']:
    try:
        _ver = importlib.metadata.version(_pkg)
        installed_packages['mlflow'] = _ver
        print(f"✓ mlflow ({_pkg}): {_ver}")
        _mlflow_found = True
        break
    except importlib.metadata.PackageNotFoundError:
        pass
if not _mlflow_found:
    missing_packages.append('mlflow>=3.0')
    print("✗ mlflow: NOT INSTALLED")
packages_to_check = [
('ultralytics', 'ultralytics==8.3.204'),
('opencv-python', 'opencv-python (provides cv2)'),
('nvidia-ml-py', 'nvidia-ml-py==13.580.82'),
('threadpoolctl', 'threadpoolctl==3.1.0')
]
for package_name, package_spec in packages_to_check:
    try:
        version = importlib.metadata.version(package_name)
        installed_packages[package_name] = version
        print(f"✓ {package_name}: {version}")
    except importlib.metadata.PackageNotFoundError:
        missing_packages.append(package_spec)
        print(f"✗ {package_name}: NOT INSTALLED")
print("\n" + "="*60)
if missing_packages:
    print("[ACTION REQUIRED] Missing packages detected!")
    print("\nInstall missing packages and restart kernel.")
    for pkg in missing_packages:
        print(f" - {pkg}")
else:
    print("[OK] All required packages are installed!")
print(f" Python version: {sys.version.split()[0]}")
print("="*60)
# --- AI v5 environment probe (added for v5 validation) ---
import torch
print("\n--- AI v5 environment probe ---")
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
try:
    print("GPU:", torch.cuda.get_device_name(0))
except Exception as e:
    print("GPU: (could not query)", e)
_dists = {d.metadata['Name'].lower() for d in importlib.metadata.distributions()}
print("mlflow-skinny present:", "mlflow-skinny" in _dists, "| full mlflow present:", "mlflow" in _dists)
try:
    import mlflow.deployments, mlflow.models  # serving/registration path (the skinny risk)
    print("[OK] mlflow.deployments + mlflow.models importable -> Model Serving step should work")
except Exception as e:
    print("[WARN] mlflow serving API import FAILED on skinny -> re-add `%pip install mlflow`:", e)
# Note: You may see a FutureWarning about pynvml being deprecated.
# This is expected - PyTorch internally imports pynvml, but we have
# the correct nvidia-ml-py package installed. The warning is harmless.
import os
import pandas as pd
import torch
import mlflow
import uuid
# Set Ultralytics config directory to a unique writable location
config_dir = f'/tmp/yolo_config_{uuid.uuid4().hex[:8]}'
os.environ['YOLO_CONFIG_DIR'] = config_dir
os.makedirs(config_dir, exist_ok=True)
from ultralytics import YOLO, settings
from mlflow.types.schema import Schema, ColSpec
from mlflow.models.signature import ModelSignature
print("[OK] All libraries imported successfully")
Helper functions
Utility functions for the complete YOLO training and deployment workflow:
Data Management:
- download_file(): Download models and configs to the UC volume
- download_and_extract_dataset(): Download and extract COCO128
- split_dataset(): Create reproducible train/val/test splits
MLflow Integration:
- infer_model_signature(): Infer the model signature from actual predictions
- setup_mlflow_experiment(): Configure the MLflow experiment with system metrics
- register_yolo_model(): Register the model to Unity Catalog with the custom wrapper
Model Evaluation:
- evaluate_model_on_split(): Evaluate and visualize predictions on a data split
Custom Wrapper:
- YOLOWrapper: MLflow PyFunc wrapper for YOLO models
  - Input: base64-encoded images (a plain-string format that works across network boundaries)
  - Output: DataFrame with class, confidence, and bounding boxes (11 columns)
  - Purpose: Enables deployment to Model Serving endpoints
  - Validation: Tested locally before deployment
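YOLOWrapper expects a DataFrame with an image_base64 column. A client that calls a serving endpoint sends that DataFrame as JSON. The sketch below builds such a request body. It assumes the endpoint accepts MLflow's dataframe_split input format (column names plus row-major data). The helper name build_serving_payload is hypothetical.

```python
import base64
import json


def build_serving_payload(image_paths):
    """Build a JSON request body carrying one base64-encoded image per row.

    Each file is read as raw bytes, so any format PIL can open (JPEG, PNG, ...)
    works; YOLOWrapper.predict decodes the 'image_base64' column on the server.
    """
    rows = []
    for path in image_paths:
        with open(path, "rb") as f:
            rows.append([base64.b64encode(f.read()).decode("utf-8")])
    return json.dumps({"dataframe_split": {"columns": ["image_base64"], "data": rows}})
```

You would POST the resulting string to the endpoint's invocations URL with an authorization token. MLflow's dataframe_records format (a list of {"image_base64": ...} objects) is an equivalent alternative.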
# ============================================================
# HELPER FUNCTIONS
# ============================================================
import os
import shutil
import requests
import zipfile
import io
import random
import yaml
import glob
import json
from datetime import datetime
import pandas as pd
import matplotlib.pyplot as plt
from ultralytics import YOLO
import mlflow
from mlflow import MlflowClient
import importlib.metadata
def download_file(url, destination, description="file"):
    """Download a file from URL to destination path (skipped if the file already exists)."""
    if os.path.exists(destination):
        print(f"[INFO] {description} already exists at: {destination}")
        print(f" Skipping download")
        return True
    print(f"Downloading {description}...")
    print(f" From: {url}")
    print(f" To: {destination}")
    try:
        # Timeout keeps the cell from hanging indefinitely on an unreachable host
        response = requests.get(url, stream=True, timeout=60)
        if response.status_code == 200:
            os.makedirs(os.path.dirname(destination), exist_ok=True)
            with open(destination, 'wb') as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            print(f"[OK] Downloaded successfully")
            if destination.endswith('.pt'):
                print(f" Size: {os.path.getsize(destination) / (1024*1024):.2f} MB")
            return True
        else:
            print(f"[ERROR] Download failed with status code: {response.status_code}")
            return False
    except Exception as e:
        print(f"[ERROR] Download failed: {e}")
        # Remove any partial file so the next run does not skip the download
        if os.path.exists(destination):
            os.remove(destination)
        return False
def download_and_extract_dataset(download_url, extraction_path):
    """Download and extract a zip dataset."""
    print("Downloading dataset...")
    response = requests.get(download_url, timeout=300)
    # Fail loudly on HTTP errors instead of handing an error page to ZipFile
    response.raise_for_status()
    print("Extracting dataset...")
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        z.extractall(extraction_path)
    print(f"[OK] Dataset downloaded and extracted to: {extraction_path}")
    return True
def split_dataset(source_images_dir, source_labels_dir, base_images_dir, base_labels_dir,
                  train_ratio=0.625, val_ratio=0.1875, random_seed=42):
    """Split dataset into train/val/test sets with reproducible random seed."""
    print("=" * 60)
    print("DATASET SPLITTING")
    print("=" * 60)
    random.seed(random_seed)
    print(f"\nRandom seed: {random_seed}")
    test_ratio = 1.0 - train_ratio - val_ratio
    print(f"Split ratios: Train={train_ratio:.1%}, Val={val_ratio:.1%}, Test={test_ratio:.1%}\n")
    # Get all images
    all_images = sorted([f for f in os.listdir(source_images_dir) if f.endswith('.jpg')])
    print(f"Total images: {len(all_images)}")
    # Shuffle and split
    random.shuffle(all_images)
    train_size = int(len(all_images) * train_ratio)
    val_size = int(len(all_images) * val_ratio)
    train_images = all_images[:train_size]
    val_images = all_images[train_size:train_size + val_size]
    test_images = all_images[train_size + val_size:]
    print(f"Split sizes: Train={len(train_images)}, Val={len(val_images)}, Test={len(test_images)}\n")
    # Create directories
    for split_name in ['train', 'val', 'test']:
        os.makedirs(f"{base_images_dir}/{split_name}", exist_ok=True)
        os.makedirs(f"{base_labels_dir}/{split_name}", exist_ok=True)
    # Copy files
    print("Copying files to splits...")
    for split_name, image_list in [('train', train_images), ('val', val_images), ('test', test_images)]:
        print(f" Processing {split_name} split ({len(image_list)} images)...")
        for img_name in image_list:
            # Copy image
            src_img = os.path.join(source_images_dir, img_name)
            dst_img = os.path.join(base_images_dir, split_name, img_name)
            shutil.copy2(src_img, dst_img)
            # Copy label if exists
            label_name = img_name.replace('.jpg', '.txt')
            src_label = os.path.join(source_labels_dir, label_name)
            dst_label = os.path.join(base_labels_dir, split_name, label_name)
            if os.path.exists(src_label):
                shutil.copy2(src_label, dst_label)
        print(f" [OK] {split_name}: {len(image_list)} images copied")
    print(f"\n[OK] Dataset split complete!")
    print("=" * 60)
    return len(train_images), len(val_images), len(test_images)
def infer_model_signature(model_path, sample_image_path):
    """Infer MLflow model signature using actual model predictions."""
    import base64
    print("[INFO] Inferring model signature...\n")
    # Read and encode image as base64
    with open(sample_image_path, 'rb') as f:
        image_bytes = f.read()
    image_base64 = base64.b64encode(image_bytes).decode('utf-8')
    # Create input example
    input_example = pd.DataFrame({"image_base64": [image_base64]})
    # Run a YOLOWrapper instance on the example to get an output example.
    # The wrapper loads the weights itself, so no separate YOLO(model_path) load is needed.
    wrapper = YOLOWrapper()
    # Simulate MLflow's load_context with a minimal stand-in context
    class MockContext:
        def __init__(self, model_path):
            self.artifacts = {"yolo_model": model_path}
    wrapper.load_context(MockContext(model_path))
    # Get output example by running prediction
    output_example = wrapper.predict(None, input_example)
    if output_example.empty:
        print("[WARNING] No detections on the sample image; the inferred output schema may be "
              "incomplete. Use another image or the manual schema shown below.")
    # Use MLflow's infer_signature to automatically create signature
    signature = mlflow.models.infer_signature(input_example, output_example)
    print(f"[OK] Model signature inferred successfully!")
    print(f" Input: DataFrame with 'image_base64' column (base64 string)")
    print(f" Output: DataFrame with {len(output_example.columns)} columns")
    print(f" Columns: {', '.join(output_example.columns.tolist())}")
    # Optional: Show how to use manual schema (commented out)
    # from mlflow.types.schema import Schema, ColSpec
    # from mlflow.models.signature import ModelSignature
    # input_schema = Schema([ColSpec("string", "image_base64")])
    # output_schema = Schema([
    #     ColSpec("string", "class_name"),
    #     ColSpec("long", "class_num"),
    #     ColSpec("double", "confidence"),
    #     ColSpec("double", "bbox_x1"),
    #     ColSpec("double", "bbox_y1"),
    #     ColSpec("double", "bbox_x2"),
    #     ColSpec("double", "bbox_y2"),
    #     ColSpec("double", "bbox_center_x"),
    #     ColSpec("double", "bbox_center_y"),
    #     ColSpec("double", "bbox_width"),
    #     ColSpec("double", "bbox_height")
    # ])
    # signature = ModelSignature(inputs=input_schema, outputs=output_schema)
    return signature, input_example
def setup_mlflow_experiment(use_workspaceUsers_path=False, expt_name_suffix="Experiments_YOLO_CoCo"):
    """Setup MLflow experiment with system metrics enabled.

    Args:
        use_workspaceUsers_path: If True, derives experiment directory from the
            current notebook location under /Workspace/Users/.
            If False (default), uses '/Workspace/Shared'.
        expt_name_suffix: Name of the experiment folder.
            Defaults to 'Experiments_YOLO_CoCo'.

    Returns:
        tuple: (experiment_name, experiment_id)
    """
    if use_workspaceUsers_path:
        # Extract username from notebook path and build /Workspace/Users/{username}
        # This avoids Git folder paths where MLflow experiment creation is not permitted
        notebook_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()
        path_parts = notebook_path.strip('/').split('/')
        if 'Users' in path_parts:
            user_idx = path_parts.index('Users')
            username = path_parts[user_idx + 1]
            experiment_base_path = f"/Workspace/Users/{username}"
        else:
            # Fallback if path doesn't contain Users
            experiment_base_path = "/Workspace/Shared"
    else:
        # Default: Use /Workspace/Shared
        experiment_base_path = "/Workspace/Shared"
    experiment_name = f"{experiment_base_path}/{expt_name_suffix}"
    # Ensure the experiment directory exists
    os.makedirs(experiment_base_path, exist_ok=True)
    os.environ['MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING'] = "true"
    os.environ['MLFLOW_EXPERIMENT_NAME'] = experiment_name
    mlflow.set_experiment(experiment_name)
    experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id
    if 'MLFLOW_RUN_ID' in os.environ:
        del os.environ['MLFLOW_RUN_ID']
    print(f"[OK] MLflow experiment initialized: {experiment_name}")
    print(f" Experiment ID: {experiment_id}")
    print(f" System metrics: ENABLED")
    return experiment_name, experiment_id
class YOLOWrapper(mlflow.pyfunc.PythonModel):
    """Custom MLflow wrapper for YOLO models using base64-encoded images."""

    # Fixed output columns, so a request with no detections still returns the same 11-column schema
    OUTPUT_COLUMNS = [
        "class_name", "class_num", "confidence",
        "bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2",
        "bbox_center_x", "bbox_center_y", "bbox_width", "bbox_height",
    ]

    def load_context(self, context):
        """Load YOLO model from artifacts."""
        from ultralytics import YOLO
        model_path = context.artifacts["yolo_model"]
        self.model = YOLO(model_path, task='detect')

    def _format_predictions(self, predictions):
        """Format YOLO prediction results with bounding boxes.

        Args:
            predictions: YOLO prediction results from model.predict()

        Returns:
            pd.DataFrame with class, confidence, and bounding box coordinates
        """
        import pandas as pd
        all_results = []
        for prediction in predictions:
            if prediction.boxes is not None:
                boxes = prediction.boxes
                for i in range(len(boxes)):
                    # Get bounding box coordinates in both formats (pixel units)
                    box_xyxy = boxes.xyxy[i].cpu().numpy()
                    box_xywh = boxes.xywh[i].cpu().numpy()
                    all_results.append({
                        "class_name": prediction.names[int(boxes.cls[i])],
                        "class_num": int(boxes.cls[i]),
                        "confidence": float(boxes.conf[i]),
                        "bbox_x1": float(box_xyxy[0]),
                        "bbox_y1": float(box_xyxy[1]),
                        "bbox_x2": float(box_xyxy[2]),
                        "bbox_y2": float(box_xyxy[3]),
                        "bbox_center_x": float(box_xywh[0]),
                        "bbox_center_y": float(box_xywh[1]),
                        "bbox_width": float(box_xywh[2]),
                        "bbox_height": float(box_xywh[3])
                    })
        return pd.DataFrame(all_results, columns=self.OUTPUT_COLUMNS)

    def predict(self, context, model_input):
        """Run YOLO prediction on base64-encoded images.

        Args:
            context: MLflow context
            model_input: DataFrame with 'image_base64' column (base64-encoded images)

        Returns:
            pd.DataFrame with detection results including bounding boxes
        """
        import pandas as pd
        import base64
        from PIL import Image
        import io
        if not isinstance(model_input, pd.DataFrame):
            raise ValueError("Input must be a DataFrame with 'image_base64' column")
        if 'image_base64' not in model_input.columns:
            raise ValueError("DataFrame must contain 'image_base64' column with base64-encoded images")
        # Process base64-encoded images
        all_predictions = []
        for image_base64 in model_input['image_base64'].tolist():
            # Decode base64 to a PIL image; convert("RGB") also handles grayscale and RGBA inputs
            image_bytes = base64.b64decode(image_base64)
            image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            # Pass the PIL image directly: Ultralytics treats NumPy array inputs as BGR,
            # so np.array(image) (which is RGB) would swap the red and blue channels
            predictions = self.model.predict(image, verbose=False)
            all_predictions.extend(predictions)
        return self._format_predictions(all_predictions)
def register_yolo_model(run_id, model_path, catalog_name, schema_name, model_name,
                        signature=None, input_example=None, data_yaml_path=None):
    """Register YOLO model to Unity Catalog with custom wrapper."""
    registered_model_name = f"{catalog_name}.{schema_name}.{model_name}"
    ultralytics_version = importlib.metadata.version('ultralytics')
    cloudpickle_version = importlib.metadata.version('cloudpickle')
    print(f"\n[INFO] Registering model to Unity Catalog...")
    print(f" Model name: {registered_model_name}")
    print(f" Using custom YOLO wrapper (base64 input, bbox output)")
    print(f" Pinning CloudPickle version: {cloudpickle_version}")
    with mlflow.start_run(run_id=run_id):
        if data_yaml_path:
            mlflow.log_artifact(data_yaml_path, "input_data")
        mlflow.pyfunc.log_model(
            name="model",
            python_model=YOLOWrapper(),
            artifacts={"yolo_model": model_path},
            signature=signature,
            input_example=input_example,
            registered_model_name=registered_model_name,
            pip_requirements=[
                f"ultralytics=={ultralytics_version}",
                f"cloudpickle=={cloudpickle_version}",
                "torch",
                "torchvision",
                "pillow",
                "numpy"
            ]
        )
    print(f" [OK] Model registered: {registered_model_name}")
    return registered_model_name
def evaluate_model_on_split(model, image_dir, split_name, output_dir, run_id,
                            registered_model_name, organized_run_name, num_samples=3):
    """Evaluate model on a dataset split and save results."""
    print("=" * 60)
    print(f"{split_name.upper()} SET EVALUATION")
    print("=" * 60)
    os.makedirs(output_dir, exist_ok=True)
    # Sort so the same sample images are shown on every run (glob order is not guaranteed)
    images = sorted(glob.glob(f"{image_dir}/*.jpg"))
    if not images:
        print(f"[WARNING] No {split_name} images found")
        return
    print(f"\n{split_name.capitalize()} set: {len(images)} images\n")
    # Visualize sample predictions
    sample_images = images[:num_samples]
    fig, axes = plt.subplots(1, len(sample_images), figsize=(15, 5))
    if len(sample_images) == 1:
        axes = [axes]
    results = []
    for i, img_path in enumerate(sample_images):
        print(f"Sample {i+1}/{len(sample_images)}: {img_path.split('/')[-1]}")
        predictions = model.predict(img_path, verbose=False)
        if len(predictions) > 0:
            result = predictions[0]
            # result.plot() returns a BGR array; reverse the channels for matplotlib (RGB)
            annotated_img = result.plot()[..., ::-1]
            axes[i].imshow(annotated_img)
            axes[i].axis('off')
            if result.boxes is not None:
                num_detections = len(result.boxes)
                axes[i].set_title(f"{img_path.split('/')[-1]}\n{num_detections} objects", fontsize=10)
                print(f" [OK] Detections: {num_detections} objects")
                img_results = {
                    "image": img_path.split('/')[-1],
                    "num_detections": num_detections,
                    "detections": []
                }
                for j in range(min(num_detections, 3)):
                    class_name = result.names[int(result.boxes.cls[j])]
                    confidence = float(result.boxes.conf[j])
                    print(f" - {class_name}: {confidence:.3f}")
                    img_results["detections"].append({
                        "class_name": class_name,
                        "confidence": confidence
                    })
                results.append(img_results)
        print()
    plt.tight_layout()
    plt.suptitle(f"{split_name.capitalize()} Set Predictions - Run {run_id[:8]}", fontsize=14, y=1.02)
    plot_path = os.path.join(output_dir, f"{split_name}_predictions.png")
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f"[OK] Plot saved to: {plot_path}")
    plt.show()
    # Save results JSON
    json_path = os.path.join(output_dir, f"{split_name}_results.json")
    with open(json_path, 'w') as f:
        json.dump({
            "run_id": run_id,
            "registered_model": registered_model_name,
            "timestamp": organized_run_name.split('_run_')[0],
            "num_images": len(images),
            "sample_results": results
        }, f, indent=2)
    print(f"[OK] Results saved to: {json_path}")
    # Log to MLflow
    with mlflow.start_run(run_id=run_id):
        mlflow.log_artifact(plot_path, split_name)
        mlflow.log_artifact(json_path, split_name)
    print(f"\n[OK] {split_name.upper()} SET EVALUATION COMPLETE")
    print("=" * 60)
print("[OK] Helper functions loaded successfully")
Unity Catalog configuration
Configure catalog, schema, volume, and project paths.
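The cell below interpolates widget values into SQL inside backticks, so a value that itself contains a backtick would break the statement. The sketch below escapes each identifier before it builds the DDL and the matching /Volumes path. It assumes the Databricks SQL rule that a backtick inside a quoted identifier is written as two backticks. The function names are hypothetical.

```python
def quote_identifier(name: str) -> str:
    """Backtick-quote one identifier, doubling any embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def volume_ddl_and_path(catalog: str, schema: str, volume: str):
    """Return the CREATE VOLUME statement and the /Volumes path for the same volume."""
    fq_name = ".".join(quote_identifier(part) for part in (catalog, schema, volume))
    ddl = f"CREATE VOLUME IF NOT EXISTS {fq_name}"
    # The filesystem path uses the raw names, with a trailing slash as in project_location
    path = f"/Volumes/{catalog}/{schema}/{volume}/"
    return ddl, path
```

In the notebook, the returned statement would be passed to spark.sql(ddl). The same quoting applies to the CREATE SCHEMA statement.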
dbutils.widgets.removeAll()
# Define widgets for catalog, schema, volume, model name, and deployment approval
dbutils.widgets.text("catalog_name", "main", "Catalog Name")
dbutils.widgets.text("schema_name", "default", "Schema Name")
dbutils.widgets.text("volume_name", "yolo_sgc", "Volume Name")
dbutils.widgets.text("model_name", "yolo11n_coco128_sgc", "Model Name")
dbutils.widgets.dropdown("proceed_with_deployment", "false", ["false", "true"], "Proceed with Deployment")
# External artifact sources, pinned to fixed versions. For production or
# network-restricted workspaces, stage these in a Unity Catalog volume or an
# internal mirror and point these widgets there instead of the public internet.
dbutils.widgets.text("model_source_url", "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt", "Pretrained Model URL")
dbutils.widgets.text("dataset_config_url", "https://raw.githubusercontent.com/ultralytics/ultralytics/v8.3.204/ultralytics/cfg/datasets/coco128.yaml", "Dataset Config URL")
# Get widget values
catalog_name = dbutils.widgets.get("catalog_name")
schema_name = dbutils.widgets.get("schema_name")
volume_name = dbutils.widgets.get("volume_name")
model_name = dbutils.widgets.get("model_name")
proceed_with_deployment_str = dbutils.widgets.get("proceed_with_deployment")
model_source_url = dbutils.widgets.get("model_source_url")
dataset_config_url = dbutils.widgets.get("dataset_config_url")
print(f"[Configuration]")
print(f" Catalog: {catalog_name}")
print(f" Schema: {schema_name}")
print(f" Volume: {volume_name}")
print(f" Model: {model_name}")
print(f" Proceed with Deployment: {proceed_with_deployment_str}")
print(f"\nUsing catalog: {catalog_name} (must already exist)")
# Create schema if it doesn't exist
spark.sql(f"CREATE SCHEMA IF NOT EXISTS `{catalog_name}`.`{schema_name}`")
print(f"[OK] Schema: {catalog_name}.{schema_name}")
# Create volume for persistent storage
spark.sql(f"CREATE VOLUME IF NOT EXISTS `{catalog_name}`.`{schema_name}`.`{volume_name}`")
print(f"[OK] Volume: {catalog_name}.{schema_name}.{volume_name}")
# Construct volume path from parameters
project_location = f'/Volumes/{catalog_name}/{schema_name}/{volume_name}/'
print(f"Using Unity Catalog Volume: {catalog_name}.{schema_name}.{volume_name}")
print(f"Volume path: {project_location}")
# Create subdirectories in the volume
os.makedirs(f'{project_location}runs/', exist_ok=True) # Training runs (organized by task/model/dataset)
os.makedirs(f'{project_location}data/', exist_ok=True) # Dataset storage
os.makedirs(f'{project_location}raw_model/', exist_ok=True) # Pretrained models
# Ephemeral /tmp/ location for faster I/O during training
tmp_project_location = "/tmp/training_results/"
os.makedirs(tmp_project_location, exist_ok=True)
print(f"\n[OK] Project directories created:")
print(f" Runs: {project_location}runs/")
print(f" Data: {project_location}data/")
print(f" Raw models: {project_location}raw_model/")
print(f" Temp (training): {tmp_project_location} (ephemeral, fast I/O)")
Project folder structure
Unity Catalog volume organization:
/Volumes/{catalog}/{schema}/{volume}/
├── data/
│   ├── coco128.yaml                  # Dataset configuration
│   └── coco128/
│       ├── images/
│       │   ├── train2017/            # Original 128 images (from zip)
│       │   └── train/ val/ test/     # Custom splits (80/24/24)
│       └── labels/
│           ├── train2017/            # Original labels
│           └── train/ val/ test/     # Split labels
│
├── raw_model/
│   └── yolo11n.pt                    # Pretrained YOLO11n weights
│
└── runs/
    └── {task}_{model}_{dataset}_{timestamp}_run_{mlflow_run_id}/
        ├── train/                    # MLflow training outputs
        │   ├── weights/ (best.pt, last.pt)
        │   └── results.csv, confusion_matrix.png
        ├── validation_metrics/       # YOLO validation outputs
        ├── validation_samples/       # Custom evaluation samples
        └── test_samples/             # Test evaluation samples
Run Naming: detection_yolo11n_coco128_20260120_143052_run_{mlflow_run_id}
- Includes task, model, dataset, timestamp, and MLflow run ID for easy identification
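The run naming convention above can be composed and parsed as follows. This is a sketch with hypothetical helper names. It assumes MLflow run IDs contain no "_run_" substring (they are hex strings). Splitting on "_run_" alone, as evaluate_model_on_split does for its "timestamp" field, returns the whole task_model_dataset_timestamp prefix. The timestamp itself is the last two underscore-separated fields of that prefix.

```python
from datetime import datetime


def build_run_folder_name(task, model, dataset, run_id, when=None):
    """Compose '{task}_{model}_{dataset}_{YYYYmmdd_HHMMSS}_run_{mlflow_run_id}'."""
    when = when or datetime.now()
    return f"{task}_{model}_{dataset}_{when.strftime('%Y%m%d_%H%M%S')}_run_{run_id}"


def parse_run_folder_name(folder_name):
    """Split a run folder name into its prefix, timestamp, and MLflow run ID."""
    prefix, run_id = folder_name.rsplit("_run_", 1)
    # The timestamp is the last two underscore-separated fields of the prefix
    date_part, time_part = prefix.split("_")[-2:]
    return {"prefix": prefix, "timestamp": f"{date_part}_{time_part}", "run_id": run_id}
```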
Download pretrained YOLO model
Download YOLO11n pretrained weights to Unity Catalog volume.
🔒 Artifact sources: This notebook downloads the pretrained model and dataset config from public Ultralytics sources, pinned to fixed versions. For production or network-restricted workspaces, stage these artifacts in a Unity Catalog volume or an internal mirror and repoint the Pretrained Model URL and Dataset Config URL widgets to that location.
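download_file() handles only HTTP(S) URLs through requests. If the Pretrained Model URL widget is repointed to a file already staged in a Unity Catalog volume (a /Volumes/... path rather than a URL), that request would fail. The sketch below fetches from either kind of source using only the standard library. fetch_artifact is a hypothetical name and is not part of this notebook.

```python
import os
import shutil
import urllib.request
from urllib.parse import urlparse


def fetch_artifact(source, destination, timeout=60):
    """Copy an artifact from an HTTP(S) URL or a filesystem path to destination.

    Skips the copy if destination already exists, mirroring download_file().
    """
    if os.path.exists(destination):
        return destination
    dest_dir = os.path.dirname(destination)
    if dest_dir:
        os.makedirs(dest_dir, exist_ok=True)
    if urlparse(source).scheme in ("http", "https"):
        # Stream the response body straight to disk
        with urllib.request.urlopen(source, timeout=timeout) as resp, open(destination, "wb") as out:
            shutil.copyfileobj(resp, out)
    else:
        # Anything else (for example a /Volumes/... path) is treated as a local file
        shutil.copyfile(source, destination)
    return destination
```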
# Download pretrained YOLO11n model to Unity Catalog Volume
model_path = f"{project_location}raw_model/yolo11n.pt"
model_url = dbutils.widgets.get("model_source_url")
download_file(model_url, model_path, "YOLO11n model")
print(f"\n[OK] Pretrained model ready at: {model_path}")
Dataset preparation
Download and configure COCO128 dataset.
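COCO128 ships its labels in YOLO format. Each image has one .txt file with the same file stem, and each row of that file describes one object as class x_center y_center width height, with the four box values normalized by the image width and height. The sketch below converts one row into pixel-space corner coordinates, which correspond to the bbox_x1 to bbox_y2 columns the wrapper returns. The function name is hypothetical.

```python
def yolo_label_to_xyxy(line, img_width, img_height):
    """Convert one YOLO label row to (class_id, x1, y1, x2, y2) in pixels."""
    cls, xc, yc, w, h = line.split()
    # Scale normalized center/size values back to pixels
    xc, w = float(xc) * img_width, float(w) * img_width
    yc, h = float(yc) * img_height, float(h) * img_height
    # Corners are the center minus/plus half the box size
    return int(cls), xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2
```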
# Download COCO128 dataset configuration to UC Volume
import yaml
# Create data directory in UC Volume
os.makedirs(f'{project_location}data/coco128', exist_ok=True)
# Download config directly to UC Volume
config_url = dbutils.widgets.get("dataset_config_url")
config_path = f"{project_location}data/coco128.yaml" # UC Volume path
download_file(config_url, config_path, "COCO128 config")
# Load and update configuration
with open(config_path, 'r') as f:
data = yaml.safe_load(f)
print(f"\n[Dataset Configuration]")
print(f" Dataset: {data.get('path', 'coco128')}")
print(f" Classes: {data.get('nc', 'unknown')}")
print(f" Download URL: {data.get('download', 'N/A')}")
# Update paths for Unity Catalog Volume
data['path'] = f"{project_location}data/coco128"
# Check if dataset already exists
dataset_images_dir = f"{project_location}data/coco128/images/train2017"
if os.path.exists(dataset_images_dir) and len(os.listdir(dataset_images_dir)) > 0:
    print(f"\n[INFO] Dataset already exists at: {dataset_images_dir}")
    print(f" Found {len(os.listdir(dataset_images_dir))} images")
    print(f" Skipping download")
else:
    # Download and extract dataset
    extraction_path = f"{project_location}data"
    download_and_extract_dataset(data['download'], extraction_path)
# Save updated configuration to UC Volume
data_yaml_path = f"{project_location}data/coco128.yaml"
with open(data_yaml_path, 'w') as f:
yaml.dump(data, f, default_flow_style=False)
print(f"\n[OK] Dataset configuration saved to UC Volume: {data_yaml_path}")
print(f" All dataset files in: {project_location}data/coco128/")
Dataset splits
Split COCO128 into train (62.5%), val (18.75%), and test (18.75%) sets with reproducible random seed.
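With 128 images the default ratios divide exactly: 128 × 0.625 = 80 train images and 128 × 0.1875 = 24 val images, which leaves 24 for test (the 80/24/24 in the folder layout). split_dataset() truncates with int(), so for other dataset sizes any rounding remainder lands in the test split. The sketch below reproduces that arithmetic.

```python
def split_sizes(n_images, train_ratio=0.625, val_ratio=0.1875):
    """Same arithmetic as split_dataset(): truncate train/val, test gets the remainder."""
    train = int(n_images * train_ratio)
    val = int(n_images * val_ratio)
    return train, val, n_images - train - val


print(split_sizes(128))   # (80, 24, 24)
print(split_sizes(1000))  # (625, 187, 188): 187.5 truncates to 187, test absorbs the extra image
```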
# Split dataset into train/val/test with reproducible random seed
source_images_dir = f"{project_location}data/coco128/images/train2017"
source_labels_dir = f"{project_location}data/coco128/labels/train2017"
base_images_dir = f"{project_location}data/coco128/images"
base_labels_dir = f"{project_location}data/coco128/labels"
# Skip if splits already exist (deterministic with seed=42)
train_dir = f"{base_images_dir}/train"
val_dir = f"{base_images_dir}/val"
test_dir = f"{base_images_dir}/test"
if (os.path.isdir(train_dir) and len(os.listdir(train_dir)) > 0
        and os.path.isdir(val_dir) and len(os.listdir(val_dir)) > 0
        and os.path.isdir(test_dir) and len(os.listdir(test_dir)) > 0):
    train_size = len(os.listdir(train_dir))
    val_size = len(os.listdir(val_dir))
    test_size = len(os.listdir(test_dir))
    print(f"[INFO] Splits already exist \u2014 skipping re-split")
else:
    train_size, val_size, test_size = split_dataset(
        source_images_dir=source_images_dir,
        source_labels_dir=source_labels_dir,
        base_images_dir=base_images_dir,
        base_labels_dir=base_labels_dir,
        train_ratio=0.625,   # 62.5%
        val_ratio=0.1875,    # 18.75%
        random_seed=42
    )
print(f"\nSplit summary:")
print(f" - Train: {train_size} images (62.5%)")
print(f" - Val: {val_size} images (18.75%)")
print(f" - Test: {test_size} images (18.75%)")
print(f" - Random seed: 42")
# Update data.yaml to use train/val/test splits
with open(data_yaml_path, 'r') as f:
yaml_content = yaml.safe_load(f)
# Update paths
yaml_content['train'] = f"{project_location}data/coco128/images/train"
yaml_content['val'] = f"{project_location}data/coco128/images/val"
yaml_content['test'] = f"{project_location}data/coco128/images/test"
# Save updated configuration
with open(data_yaml_path, 'w') as f:
yaml.dump(yaml_content, f, default_flow_style=False)
print(f"[OK] data.yaml updated with train/val/test splits")
print(f" Train: {yaml_content['train']}")
print(f" Val: {yaml_content['val']}")
print(f" Test: {yaml_content['test']}")
MLflow configuration
Infer model signature and configure experiment tracking with system metrics.
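The inferred signature has 11 output columns. The corner columns (bbox_x1 to bbox_y2) and the center/size columns (bbox_center_x to bbox_height) describe the same box, so they must agree. The sketch below sanity-checks one output record before registration. It assumes each record is a plain dict, such as a row from output_example.to_dict("records"). check_detection_row is hypothetical.

```python
import math

# Column names copied from YOLOWrapper._format_predictions()
EXPECTED_COLUMNS = [
    "class_name", "class_num", "confidence",
    "bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2",
    "bbox_center_x", "bbox_center_y", "bbox_width", "bbox_height",
]


def check_detection_row(row, tol=1e-3):
    """Return a list of problems in one detection record (empty if it looks valid)."""
    problems = [f"missing column: {c}" for c in EXPECTED_COLUMNS if c not in row]
    if problems:
        return problems
    if not 0.0 <= row["confidence"] <= 1.0:
        problems.append("confidence outside [0, 1]")
    # xyxy and xywh describe the same box, so size and center must match the corners
    if not math.isclose(row["bbox_x2"] - row["bbox_x1"], row["bbox_width"], abs_tol=tol):
        problems.append("width does not match x2 - x1")
    if not math.isclose(row["bbox_y2"] - row["bbox_y1"], row["bbox_height"], abs_tol=tol):
        problems.append("height does not match y2 - y1")
    if not math.isclose((row["bbox_x1"] + row["bbox_x2"]) / 2, row["bbox_center_x"], abs_tol=tol):
        problems.append("center_x is not the midpoint of x1 and x2")
    if not math.isclose((row["bbox_y1"] + row["bbox_y2"]) / 2, row["bbox_center_y"], abs_tol=tol):
        problems.append("center_y is not the midpoint of y1 and y2")
    return problems
```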
# Infer model signature from sample prediction
# This defines the input/output schema for the serving endpoint
# Input: base64-encoded images
# Output: class, confidence, bounding boxes (11 columns)
model_path = f"{project_location}raw_model/yolo11n.pt"
# Find a sample image from training set
sample_images = glob.glob(f"{project_location}data/coco128/images/train/*.jpg")
if sample_images:
    signature, input_example = infer_model_signature(model_path, sample_images[0])
    print(f"\n[OK] Signature and input example ready for model registration")
else:
    print("[WARNING] No sample images found. Run dataset preparation first.")
    signature = None
    input_example = None
# Validate that signature and input_example are set before training/registering
if 'signature' not in dir() or signature is None:
    raise ValueError(
        "signature is None: re-run the signature inference cell (MLflow configuration) before registering."
    )
if 'input_example' not in dir() or input_example is None:
    raise ValueError(
        "input_example is None: re-run the signature inference cell (MLflow configuration) before registering."
    )
print(f"✓ signature: {signature}")
print(f"✓ input_example shape: {input_example.shape}, columns: {list(input_example.columns)}")
# Configure YOLO to use MLflow
from ultralytics import settings
settings.update({"mlflow": True})
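# Enable MLflow autologging for supported libraries. System metrics are enabled
# separately via MLFLOW_ENABLE_SYSTEM_METRICS_LOGGING in setup_mlflow_experiment().
mlflow.autolog(disable=False)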
print(f"\n[MLflow Configuration]")
print(f" YOLO MLflow integration: Enabled")
print(f" MLflow autologging: Enabled")
print(f" System metrics: Enabled")
# Setup MLflow experiment with system metrics
# Default: /Workspace/Shared/Experiments_YOLO_CoCo
# To use /Workspace/Users/ path instead: setup_mlflow_experiment(use_workspaceUsers_path=True, expt_name_suffix="NightlyJob_YOLO_CoCo")
experiment_name, experiment_id = setup_mlflow_experiment(use_workspaceUsers_path=False, expt_name_suffix="Experiments_YOLO_CoCo")
print(f"\n[Ready for Training]")
print(f" Experiment: {experiment_name}")
print(f" Experiment ID: {experiment_id}")
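setup_mlflow_experiment() chooses the experiment folder from the notebook path. The sketch below restates that rule as pure functions, so you can check it without dbutils. Unlike the helper above, this sketch also falls back to /Workspace/Shared when "Users" is the last path segment. The function names are hypothetical.

```python
def experiment_base_path(notebook_path, use_workspace_users_path=False):
    """Pick /Workspace/Users/{username} (derived from the notebook path) or /Workspace/Shared."""
    if not use_workspace_users_path:
        return "/Workspace/Shared"
    parts = notebook_path.strip("/").split("/")
    if "Users" in parts:
        idx = parts.index("Users")
        # Guard against a path that ends at 'Users' with no username after it
        if idx + 1 < len(parts):
            return f"/Workspace/Users/{parts[idx + 1]}"
    return "/Workspace/Shared"


def experiment_name(notebook_path, use_workspace_users_path=False, suffix="Experiments_YOLO_CoCo"):
    """Full experiment name, as built by setup_mlflow_experiment()."""
    return f"{experiment_base_path(notebook_path, use_workspace_users_path)}/{suffix}"
```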
Model training
Train YOLO11n with MLflow tracking and register to Unity Catalog.
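The training cell below picks batch_size from the GPU's total memory. The sketch below expresses the same thresholds as a standalone function. It assumes total memory is reported in bytes, as torch.cuda.get_device_properties(0).total_memory is, and that it is converted to decimal gigabytes with / 1e9, as the notebook does. The names are hypothetical.

```python
def vram_gb_from_bytes(total_memory_bytes):
    """Convert bytes to decimal gigabytes, matching the notebook's / 1e9."""
    return total_memory_bytes / 1e9


def pick_batch_size(vram_gb):
    """Same thresholds as the training cell: >= 70 GB -> 32, >= 20 GB -> 16, otherwise 8."""
    if vram_gb >= 70:   # H100-class (80 GB)
        return 32
    if vram_gb >= 20:   # A10-class (24 GB)
        return 16
    return 8            # smaller GPUs, or CPU where vram_gb is 0
```

On a GPU you would call pick_batch_size(vram_gb_from_bytes(torch.cuda.get_device_properties(0).total_memory)). The thresholds leave considerable headroom compared with the per-batch memory estimates in the cell's comments.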
from datetime import datetime
import uuid
# Close any active MLflow runs
mlflow.end_run()
# Create unique timestamp and temp directory
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
tmp_project_location_unique = f"/tmp/training_results_{uuid.uuid4().hex[:8]}/"
os.makedirs(tmp_project_location_unique, exist_ok=True)
# Model configuration
model_path = f"{project_location}raw_model/yolo11n.pt"
model_task = "detection" # Task type
model_arch = "yolo11n" # Model architecture
dataset_name = "coco128" # Dataset name
# ── GPU-aware batch size ──────────────────────────────────────
# Auto-detect GPU and scale batch_size to available VRAM.
# YOLO11n memory footprint at img_size=640:
# batch=8 → ~2.5 GB | batch=16 → ~4.5 GB | batch=32 → ~8 GB
_gpu_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
_vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 if torch.cuda.is_available() else 0
if _vram_gb >= 70:      # H100 (80 GB)
    batch_size = 32
elif _vram_gb >= 20:    # A10/A10G (24 GB)
    batch_size = 16
else:                   # Fallback
    batch_size = 8
# Append a GPU suffix to the registered model name for registry clarity.
# "A10" is a substring of "A100", so exclude A100 explicitly.
if "H100" in _gpu_name:
    gpu_suffix = "_h100"
elif "A10" in _gpu_name and "A100" not in _gpu_name:
    gpu_suffix = "_a10"
else:
    gpu_suffix = ""
model_name = dbutils.widgets.get("model_name") + gpu_suffix
print(f"[GPU] {_gpu_name} ({_vram_gb:.0f} GB VRAM) → batch_size={batch_size}")
print(f" Volume: {catalog_name}.{schema_name}.{volume_name}")
print(f" Model: {catalog_name}.{schema_name}.{model_name}")
# ── Training hyperparameters ──────────────────────────────────
# Adjust these based on your dataset size and training goals:
# - epochs/patience: increase for larger datasets (e.g. epochs=300, patience=50)
# - learning_rate: lower (0.0001) for fine-tuning, higher (0.01) for training from scratch.
#   Ultralytics ignores lr0 when optimizer="auto" (its default), so an explicit optimizer is set.
# - dropout: Ultralytics applies dropout only when training classification models;
#   it has no effect on this detection run
# - weight_decay: the standard 0.0005 works for most cases
epochs = 100
optimizer = "AdamW"  # Explicit optimizer so lr0 is honored (Ultralytics pairs Adam-family optimizers with lr0=1e-3)
learning_rate = 0.001  # Initial learning rate (YOLO lr0)
patience = 5  # Early stopping patience (epochs without improvement)
dropout = 0.2  # Dropout rate (used by Ultralytics for classification training only)
weight_decay = 0.0005  # L2 regularization
print(f"\nTraining configuration:")
print(f" GPU: {_gpu_name} ({_vram_gb:.0f} GB)")
print(f" Task: {model_task}")
print(f" Model: {model_arch}")
print(f" Dataset: {dataset_name}")
print(f" Weights: {model_path}")
print(f" Data config: {data_yaml_path}")
print(f" Output: {tmp_project_location_unique}")
print(f" Epochs: {epochs}")
print(f" Batch size: {batch_size} (auto-scaled to GPU)")
print(f" Optimizer: {optimizer}")
print(f" Learning rate (lr0): {learning_rate}")
print(f" Patience: {patience}")
print(f" Dropout: {dropout}")
print(f" Weight decay: {weight_decay}")
print()
# Initialize and train model
print(f"Loading YOLO model...")
model = YOLO(model_path)
print("Starting training...\n")
results = model.train(
    task="detect",
    batch=batch_size,
    device=0,  # Single GPU for Serverless
    data=data_yaml_path,
    epochs=epochs,
    optimizer=optimizer,  # With optimizer="auto", Ultralytics ignores lr0
    lr0=learning_rate,  # Initial learning rate
    project=tmp_project_location_unique,
    name=f"run_{timestamp}",
    exist_ok=True,
    fliplr=1,  # Left-right flip probability (0-1); 1 flips every image (Ultralytics default: 0.5)
    flipud=1,  # Up-down flip probability (0-1); 1 flips every image (Ultralytics default: 0.0)
    perspective=0.001,
    degrees=0.45,
    amp=True,
    patience=patience,
    dropout=dropout,
    weight_decay=weight_decay,
    save=True,
    save_period=10
)
# Get MLflow run ID
run_id = mlflow.last_active_run().info.run_id
print(f"\n[OK] Training complete! MLflow Run ID: {run_id}")
# Copy training results to Unity Catalog Volume with enhanced naming
print(f"\n[INFO] Copying training results to Unity Catalog Volume...")
training_run_dir = os.path.join(tmp_project_location_unique, f"run_{timestamp}")
# Enhanced folder naming: {task}_{model}_{dataset}_{datetime}_run_{mlflow_run_id}
organized_run_name = f"{model_task}_{model_arch}_{dataset_name}_{timestamp}_run_{run_id}"
volume_run_dir = os.path.join(project_location, "runs", organized_run_name)
# Create train subfolder for MLflow training outputs
volume_train_dir = os.path.join(volume_run_dir, "train")
if os.path.exists(training_run_dir):
    shutil.copytree(training_run_dir, volume_train_dir, dirs_exist_ok=True)
    print(f" [OK] Training outputs copied to: {volume_train_dir}")
# Validate model and save outputs to UC Volume
print("\n[INFO] Validating model...")
val_metrics_dir = os.path.join(tmp_project_location_unique, "validation_metrics")
os.makedirs(val_metrics_dir, exist_ok=True)
val_results = model.val(
    project=tmp_project_location_unique,
    name="validation_metrics",
    exist_ok=True,  # Write into the folder created above instead of an incremented name
)
# Copy validation metrics to UC Volume (at run level, not in train/)
if os.path.exists(val_metrics_dir):
    volume_val_metrics_dir = os.path.join(volume_run_dir, "validation_metrics")
    shutil.copytree(val_metrics_dir, volume_val_metrics_dir, dirs_exist_ok=True)
    print(f" [OK] Validation metrics copied to: {volume_val_metrics_dir}")
# Save best model in YOLO native format
print("\n[INFO] Saving best model...")
best_model = YOLO(str(model.trainer.best))
best_model_path = f"/tmp/best_yolo_model_{timestamp}.pt"
best_model.save(best_model_path)
print(f" Saved to: {best_model_path}")
# Register model to Unity Catalog using model_name widget
registered_model_name = register_yolo_model(
run_id=run_id,
model_path=best_model_path,
catalog_name=catalog_name,
schema_name=schema_name,
model_name=model_name, # Use widget parameter
signature=signature,
input_example=input_example,
data_yaml_path=data_yaml_path
)
print(f"\n[OK] Model registered to Unity Catalog!")
print(f"\n[Model Details]")
print(f" - Name: {registered_model_name}")
print(f" - Run ID: {run_id}")
print(f" - Location: Unity Catalog Model Registry")
print(f" - Format: Custom YOLO wrapper (base64 input, bbox output)")
print(f"\n[Training Artifacts]")
print(f" - Volume location: {volume_run_dir}")
print(f" - Run name: {organized_run_name}")
print(f" - Structure: train/, validation_metrics/")
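print(f"\n[View Results]")
# On Databricks the tracking URI is just "databricks", so build the UI link from the workspace URL
workspace_url = spark.conf.get("spark.databricks.workspaceUrl")
print(f" https://{workspace_url}/ml/experiments/{experiment_id}/runs/{run_id}")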
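The GPU-aware batch sizing in the training cell is easy to check without a GPU once it is pulled into pure functions. A minimal sketch, assuming the same VRAM thresholds and name checks as the cell above; pick_batch_size and gpu_suffix_for are hypothetical helper names, and the VRAM figures in the usage loop are illustrative, not measured.

```python
# Sketch of the training cell's GPU-aware settings as pure functions.
# pick_batch_size and gpu_suffix_for are hypothetical helpers that mirror
# the thresholds used above; they are not part of Ultralytics or Databricks.


def pick_batch_size(vram_gb: float) -> int:
    """Map total GPU memory (decimal GB) to a YOLO11n batch size at imgsz=640."""
    if vram_gb >= 70:  # H100-class GPUs
        return 32
    if vram_gb >= 20:  # A10-class GPUs
        return 16
    return 8  # Smaller GPUs or CPU fallback


def gpu_suffix_for(gpu_name: str) -> str:
    """Registry suffix so models trained on different GPUs stay distinguishable."""
    if "H100" in gpu_name:
        return "_h100"
    if "A10" in gpu_name:
        return "_a10"
    return ""


if __name__ == "__main__":
    # Illustrative (name, VRAM in GB) pairs, not real device readings
    for name, vram in [("NVIDIA H100 80GB HBM3", 80.0), ("NVIDIA A10G", 24.0), ("cpu", 0.0)]:
        print(f"{name}: batch={pick_batch_size(vram)}, suffix='{gpu_suffix_for(name)}'")
```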
Model evaluations
Split Evaluation (Native YOLO): Assesses model accuracy on the validation and test sets, reading images by file path from the UC volume.
Local Serving Test (MLflow PyFunc): Validates the production serving format by sending base64-encoded images through the registered PyFunc model, confirming endpoint compatibility before deployment.
Split evaluation
Evaluate model performance on validation and test sets before deployment.
# Load model from MLflow and evaluate on validation set
model_uri = f"runs:/{run_id}/model"
model_path = mlflow.artifacts.download_artifacts(model_uri)
# Find the .pt file
import glob as glob_module
pt_files = glob_module.glob(f"{model_path}/**/*.pt", recursive=True)
if pt_files:
    loaded_model = YOLO(pt_files[0], task='detect')
    print(f"[OK] Model loaded from MLflow\n")
    # Evaluate on validation set
    val_image_dir = f"{project_location}data/coco128/images/val"
    val_output_dir = os.path.join(volume_run_dir, "validation_samples")
    evaluate_model_on_split(
        model=loaded_model,
        image_dir=val_image_dir,
        split_name="validation",
        output_dir=val_output_dir,
        run_id=run_id,
        registered_model_name=registered_model_name,
        organized_run_name=organized_run_name,
        num_samples=3
    )
else:
    print("[ERROR] Model file not found")
# Evaluate model on test set (uses loaded_model from validation cell)
test_image_dir = f"{project_location}data/coco128/images/test"
test_output_dir = os.path.join(volume_run_dir, "test_samples")
evaluate_model_on_split(
model=loaded_model,
image_dir=test_image_dir,
split_name="test",
output_dir=test_output_dir,
run_id=run_id,
registered_model_name=registered_model_name,
organized_run_name=organized_run_name,
num_samples=3
)
Local serving test
Validate the registered model's serving format (base64 in → bbox out) before deployment.
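The serving format is one DataFrame column, image_base64, holding a base64 string per image. A minimal round-trip sketch, assuming the PyFunc wrapper decodes that column with standard base64 (its decoding code lives earlier in the notebook and is not reproduced here); the helper names are hypothetical and the byte string is a placeholder rather than a real JPEG.

```python
# Sketch of the serving input format: raw image bytes -> base64 string ->
# one record under the "image_base64" key that the PyFunc wrapper reads.
# encode_image_bytes, decode_image_base64, and to_dataframe_records are hypothetical.
import base64


def encode_image_bytes(image_bytes: bytes) -> str:
    """Base64-encode image bytes into a UTF-8 string that is safe for JSON payloads."""
    return base64.b64encode(image_bytes).decode("utf-8")


def decode_image_base64(image_b64: str) -> bytes:
    """Inverse of encode_image_bytes, roughly what a server-side wrapper would do."""
    return base64.b64decode(image_b64)


def to_dataframe_records(image_bytes: bytes) -> list[dict]:
    """Build the dataframe_records list shape used by both local and endpoint tests."""
    return [{"image_base64": encode_image_bytes(image_bytes)}]


if __name__ == "__main__":
    fake_jpeg = b"\xff\xd8\xff\xe0 placeholder bytes \xff\xd9"  # Not a real image
    records = to_dataframe_records(fake_jpeg)
    print(records[0]["image_base64"][:16], "...")
    assert decode_image_base64(records[0]["image_base64"]) == fake_jpeg
```

Locally, pd.DataFrame(records) produces the same one-row frame the test cell passes to serving_model.predict.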
import base64
import glob
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import pandas as pd
from PIL import Image
from mlflow.tracking import MlflowClient
# Test registered model locally with base64 input (serving format)
print("=" * 60)
print("LOCAL MODEL TEST - BASE64 INPUT")
print("=" * 60)
client = MlflowClient()
# Ensure we have the registered model name (use widget parameter)
if 'registered_model_name' not in dir():
    registered_model_name = f"{catalog_name}.{schema_name}.{model_name}"
print(f"\nTesting model: {registered_model_name}\n")
try:
    # Get the latest model version. search_model_versions() results are not
    # guaranteed to be ordered, so pick the highest version number explicitly.
    model_versions = client.search_model_versions(f"name='{registered_model_name}'")
    if model_versions:
        latest = max(model_versions, key=lambda mv: int(mv.version))
        latest_version = latest.version
        print(f"[OK] Found model version: {latest_version}")
        print(f" Status: {latest.status}")
        # Load model using pyfunc (this is what the serving endpoint uses)
        model_uri = f"models:/{registered_model_name}/{latest_version}"
        serving_model = mlflow.pyfunc.load_model(model_uri)
        print(f"[OK] MLflow pyfunc model loaded successfully\n")
        # Use test images at indices 10-12 so they differ from the
        # first 3 images used by the test_samples evaluation
        test_images = glob.glob(f"{project_location}data/coco128/images/test/*.jpg")
        sample_images = test_images[10:13]
        num_samples = len(sample_images)
        if sample_images:
            print(f"Testing with {num_samples} sample images (different from test_samples)\n")
            # Create color map for different classes
            colors = plt.cm.tab20(np.linspace(0, 1, 20))  # 20 distinct colors
            # Create visualization
            fig, axes = plt.subplots(1, num_samples, figsize=(15, 5))
            if num_samples == 1:
                axes = [axes]
            for i, test_image_path in enumerate(sample_images):
                print(f"Sample {i+1}/{num_samples}: {test_image_path.split('/')[-1]}")
                # Encode image as base64
                with open(test_image_path, 'rb') as f:
                    image_bytes = f.read()
                image_base64 = base64.b64encode(image_bytes).decode('utf-8')
                # Test pyfunc wrapper with base64 input
                input_df = pd.DataFrame({"image_base64": [image_base64]})
                predictions = serving_model.predict(input_df)
                # Load and display image
                image = Image.open(test_image_path)
                axes[i].imshow(image)
                axes[i].axis('off')
                # Draw bounding boxes from pyfunc predictions
                if len(predictions) > 0:
                    num_detections = len(predictions)
                    # Draw each bounding box with class-specific color
                    for idx, row in predictions.iterrows():
                        # Use xyxy coordinates
                        x1, y1, x2, y2 = row['bbox_x1'], row['bbox_y1'], row['bbox_x2'], row['bbox_y2']
                        width = x2 - x1
                        height = y2 - y1
                        # Get color based on class number
                        color = colors[int(row['class_num']) % len(colors)]
                        # Draw rectangle
                        rect = patches.Rectangle(
                            (x1, y1), width, height,
                            linewidth=2, edgecolor=color, facecolor='none'
                        )
                        axes[i].add_patch(rect)
                        # Add label with matching color
                        label = f"{row['class_name']} {row['confidence']:.2f}"
                        axes[i].text(
                            x1, y1 - 5, label,
                            color='white', fontsize=8,
                            bbox=dict(facecolor=color, alpha=0.8, pad=2)
                        )
                    axes[i].set_title(f"{test_image_path.split('/')[-1]}\n{num_detections} objects", fontsize=10)
                    print(f" [OK] Detections: {num_detections} objects")
                    for idx, row in predictions.head(3).iterrows():
                        print(f" - {row['class_name']}: {row['confidence']:.3f}")
                else:
                    axes[i].set_title(f"{test_image_path.split('/')[-1]}\nNo objects", fontsize=10)
                    print(f" [OK] No objects detected")
                print()
            plt.tight_layout()
            plt.suptitle(f"Local Serving Test - MLflow PyFunc with Base64 (v{latest_version})", fontsize=14, y=1.02)
            plt.show()
            print("=" * 60)
            print("[OK] MODEL READY FOR DEPLOYMENT")
            print("=" * 60)
            print(f"\n[Test Summary]")
            print(f" - Model: {registered_model_name} (v{latest_version})")
            print(f" - Input format: Base64-encoded images ✓")
            print(f" - MLflow pyfunc wrapper: ✓")
            print(f" - Bounding boxes: ✓ (color-coded by class)")
            print(f" - Test images: Different from test_samples evaluation")
            print(f" - Status: Validated and ready")
            print(f"\n Next: Deploy to serving endpoint")
        else:
            print(f"[WARNING] Need more than 10 test images; found {len(test_images)}")
    else:
        print(f"[ERROR] No versions found for: {registered_model_name}")
        print(f"\nPlease run the training cell to register the model first.")
except Exception as e:
    print(f"[ERROR] {e}")
    import traceback
    traceback.print_exc()
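The local test draws boxes from corner coordinates, while matplotlib's Rectangle takes an anchor plus width and height. A small sketch of that conversion; xyxy_to_rect is a hypothetical helper, and the min/max normalization is an added safeguard in case corners ever arrive swapped (the wrapper presumably already returns x1 < x2 and y1 < y2).

```python
# Sketch: convert corner-format boxes (bbox_x1, bbox_y1, bbox_x2, bbox_y2)
# into the ((x, y), width, height) arguments of matplotlib.patches.Rectangle.
# xyxy_to_rect is a hypothetical helper, not part of matplotlib or the wrapper.


def xyxy_to_rect(x1: float, y1: float, x2: float, y2: float) -> tuple[tuple[float, float], float, float]:
    """Return ((min x, min y), width, height).

    With imshow's default origin="upper", (min x, min y) is the visual
    top-left corner of the box.
    """
    left, right = min(x1, x2), max(x1, x2)
    top, bottom = min(y1, y2), max(y1, y2)
    return (left, top), right - left, bottom - top


if __name__ == "__main__":
    anchor, width, height = xyxy_to_rect(10, 20, 50, 80)
    print(anchor, width, height)
```

Usage inside the drawing loop would be patches.Rectangle(*xyxy_to_rect(x1, y1, x2, y2), linewidth=2, edgecolor=color, facecolor='none').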
Deployment checkpoint
Manual gate before creating/updating the serving endpoint. Set the Proceed with Deployment widget to true to continue.
# ============================================================
# DEPLOYMENT CHECKPOINT
# ============================================================
# This cell acts as a safety gate before deployment cells.
# Set the 'Proceed with Deployment' widget to 'true' to continue.
# ============================================================
# Get deployment approval from the 'Proceed with Deployment' widget (defined at the top of the notebook)
PROCEED_WITH_DEPLOYMENT = dbutils.widgets.get("proceed_with_deployment") == "true"
if not PROCEED_WITH_DEPLOYMENT:
    message = """
============================================================
⚠️ DEPLOYMENT PAUSED - MANUAL CONFIRMATION REQUIRED
============================================================
This checkpoint prevents accidental execution of deployment cells.
[To Proceed]
1. Review the model validation results above
2. Verify the model is ready for deployment
3. Set 'Proceed with Deployment' widget to 'true' (top of notebook)
4. Re-run this cell
[What Happens Next]
- 'Model deployment' cell: Create/update serving endpoint (AI Gateway inference tables enabled)
- 'Test Deployed Endpoint' cell: Test deployed endpoint
[Safety Note]
This checkpoint ensures you don't accidentally deploy
an unvalidated model or overwrite a production endpoint.
[For 'Run All']
Deployment cells will skip execution if not approved.
No errors will be raised.
============================================================
⏸️ DEPLOYMENT PAUSED - AWAITING APPROVAL
============================================================
"""
    dbutils.notebook.exit(message)
else:
    message = """
============================================================
✓ DEPLOYMENT CHECKPOINT PASSED
============================================================
[Confirmation]
User has manually approved deployment
Execution will stop here for manual control
[Next Steps - Run Manually]
1. Run the 'Model deployment' cell: Create/update serving endpoint (AI Gateway inference tables enabled)
2. Wait for endpoint to be ready (10-20 minutes)
3. Run the 'Test Deployed Endpoint' cell
[Why Manual Execution?]
- Endpoint provisioning takes 10-20 minutes
- You can monitor progress in the UI
- Each step requires verification before proceeding
- Prevents accidental 'Run All' through deployment
============================================================
⏸️ STOPPING HERE - RUN DEPLOYMENT CELLS MANUALLY
============================================================
"""
    dbutils.notebook.exit(message)
Model deployment
Deploy to Model Serving endpoint with AI Gateway and inference table logging.
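The deployment cell picks the newest registered version by sorting in Python, since Unity Catalog does not support order_by. The numeric key matters: compared as strings, "9" sorts after "10". A minimal sketch, with SimpleNamespace objects standing in for MLflow ModelVersion entities (only the version attribute is used); latest_version is a hypothetical helper name.

```python
# Sketch: pick the newest registered model version by numeric version.
# latest_version is a hypothetical helper; SimpleNamespace objects stand in
# for mlflow ModelVersion entities, of which only .version is read here.
from types import SimpleNamespace


def latest_version(model_versions) -> str:
    """Return the highest version (as a string) among ModelVersion-like objects."""
    if not model_versions:
        raise ValueError("No model versions found")
    return max(model_versions, key=lambda mv: int(mv.version)).version


if __name__ == "__main__":
    versions = [SimpleNamespace(version=v) for v in ["2", "10", "9"]]
    print(latest_version(versions))  # → 10
```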
# Check deployment approval from the 'Deployment checkpoint' cell
if 'PROCEED_WITH_DEPLOYMENT' not in dir() or not PROCEED_WITH_DEPLOYMENT:
    print("\n" + "="*60)
    print("⚠️ DEPLOYMENT SKIPPED - NOT APPROVED")
    print("="*60)
    print("\n[Reason]")
    print(" PROCEED_WITH_DEPLOYMENT flag is not set to True")
    print("\n[To Enable Deployment]")
    print(" 1. Set 'Proceed with Deployment' widget to 'true' (top of notebook)")
    print(" 2. Re-run the 'Deployment checkpoint' cell, then this cell")
    print("\n" + "="*60)
else:
    # Deployment approved - proceed with endpoint creation
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service.serving import (
        ServedEntityInput,
        EndpointCoreConfigInput,
        AiGatewayConfig,
        AiGatewayInferenceTableConfig
    )
    from mlflow.tracking import MlflowClient
    import time
    w = WorkspaceClient()
    client = MlflowClient()
    # Get latest model version (use widget parameter)
    if 'registered_model_name' not in dir():
        registered_model_name = f"{catalog_name}.{schema_name}.{model_name}"
    model_versions = client.search_model_versions(
        f"name='{registered_model_name}'"
    )
    if not model_versions:
        raise ValueError(f"No model versions found for {registered_model_name}. Run the training cell to register the model.")
    # Unity Catalog does not support order_by; sort versions in Python instead
    model_versions = sorted(model_versions, key=lambda mv: int(mv.version), reverse=True)
    model_version = model_versions[0].version  # latest version (sorted DESC)
    # Derive endpoint name from model name
    model_name_only = registered_model_name.split('.')[-1]
    endpoint_name = f"{model_name_only}_endpoint"
    print("=" * 60)
    print("CREATING MODEL SERVING ENDPOINT")
    print("=" * 60)
    print(f"\nEndpoint configuration:")
    print(f" Name: {endpoint_name}")
    print(f" Model: {registered_model_name}")
    print(f" Version: {model_version}")
    print(f" Workload size: Small")
    print(f" Scale to zero: Enabled")
    print(f" AI Gateway: Enabled with inference tables")
    print(f" Inference table: {catalog_name}.{schema_name}.{endpoint_name}_payload")
    print()
    try:
        # Check if endpoint already exists
        endpoint_exists = False
        needs_update = True
        needs_ai_gateway_update = False
        try:
            existing_endpoint = w.serving_endpoints.get(endpoint_name)
            endpoint_exists = True
            print(f"[INFO] Endpoint '{endpoint_name}' already exists")
            # Check if endpoint is currently being updated
            if existing_endpoint.state.config_update.value != "NOT_UPDATING":
                print(f"[INFO] Endpoint is currently being updated (status: {existing_endpoint.state.config_update.value})")
                print(f" Checking status briefly (will timeout after 2 minutes)...\n")
                # Brief check for current update status
                max_wait_time = 120  # Only wait 2 minutes
                poll_interval = 10
                elapsed_time = 0
                while elapsed_time < max_wait_time:
                    endpoint = w.serving_endpoints.get(endpoint_name)
                    if endpoint.state.config_update.value == "NOT_UPDATING":
                        print(f"\n[OK] Current update completed (took {elapsed_time}s)")
                        existing_endpoint = endpoint
                        break
                    elif endpoint.state.config_update.value == "UPDATE_FAILED":
                        print(f"\n[WARNING] Current update failed")
                        existing_endpoint = endpoint
                        break
                    else:
                        if elapsed_time % 30 == 0:
                            print(f" Status: {endpoint.state.config_update.value} ({elapsed_time}s elapsed)")
                        time.sleep(poll_interval)
                        elapsed_time += poll_interval
                if elapsed_time >= max_wait_time:
                    print(f"\n[INFO] Endpoint update still in progress after {max_wait_time}s")
                    print(f" This cell will complete now to avoid blocking")
                    print(f"\n[NEXT STEP]")
                    print(f" 1. Wait a few minutes for the update to complete")
                    print(f" 2. Re-run this cell to check status")
                    print(f" 3. Once ready, proceed to testing")
                    # Get final status and exit
                    endpoint = w.serving_endpoints.get(endpoint_name)
                    workspace_url = spark.conf.get("spark.databricks.workspaceUrl")
                    endpoint_url = f"https://{workspace_url}/ml/endpoints/{endpoint_name}"
                    print(f"\n[View Endpoint]")
                    print(f" {endpoint_url}")
                    print(f"\n[Current Status]")
                    print(f" - Config State: {endpoint.state.config_update.value}")
                    print(f" - Ready State: {endpoint.state.ready.value}")
                    # Exit early - don't proceed to update logic. SystemExit is not an
                    # Exception subclass, so it bypasses the handler below and reaches
                    # the outer 'except SystemExit'.
                    raise SystemExit("Endpoint update in progress - cell completed to avoid blocking")
            # Check if it's already serving the same model version
            current_config = existing_endpoint.config
            if current_config and current_config.served_entities:
                current_entity = current_config.served_entities[0]
                current_model = current_entity.entity_name
                current_version = current_entity.entity_version
                if current_model == registered_model_name and current_version == str(model_version):
                    print(f" Already serving {registered_model_name} version {model_version}")
                    print(f" No model update needed")
                    needs_update = False
                    # Check if AI Gateway inference tables are enabled
                    ai_gateway = existing_endpoint.ai_gateway
                    if ai_gateway and ai_gateway.inference_table_config and ai_gateway.inference_table_config.enabled:
                        print(f" AI Gateway inference tables already enabled")
                        print(f" No AI Gateway update needed\n")
                    else:
                        print(f" AI Gateway inference tables not enabled")
                        print(f" Will enable AI Gateway\n")
                        needs_ai_gateway_update = True
                else:
                    print(f" Currently serving: {current_model} v{current_version}")
                    print(f" Updating to: {registered_model_name} v{model_version}")
                    print(f" Note: AI Gateway must be configured separately for updates\n")
                    needs_ai_gateway_update = True
        except Exception as e:
            if "does not exist" in str(e).lower() or "RESOURCE_DOES_NOT_EXIST" in str(e):
                print(f"[INFO] Endpoint '{endpoint_name}' does not exist")
                print(f" Creating new endpoint with AI Gateway enabled...\n")
            else:
                raise
        # Create/update endpoint if needed
        if needs_update:
            if endpoint_exists:
                # Update existing endpoint using SDK method
                # Note: update_config() doesn't support ai_gateway parameter
                w.serving_endpoints.update_config(
                    name=endpoint_name,
                    served_entities=[
                        ServedEntityInput(
                            entity_name=registered_model_name,
                            entity_version=str(model_version),
                            workload_size="Small",
                            scale_to_zero_enabled=True
                        )
                    ]
                )
                print(f"[OK] Endpoint update submitted")
            else:
                # Create new endpoint with AI Gateway enabled using SDK method
                w.serving_endpoints.create(
                    name=endpoint_name,
                    config=EndpointCoreConfigInput(
                        served_entities=[
                            ServedEntityInput(
                                entity_name=registered_model_name,
                                entity_version=str(model_version),
                                workload_size="Small",
                                scale_to_zero_enabled=True
                            )
                        ]
                    ),
                    ai_gateway=AiGatewayConfig(
                        inference_table_config=AiGatewayInferenceTableConfig(
                            catalog_name=catalog_name,
                            schema_name=schema_name,
                            table_name_prefix=endpoint_name,
                            enabled=True
                        )
                    )
                )
                print(f"[OK] Endpoint creation submitted (with AI Gateway enabled)")
            # Brief initial wait with shorter timeout to avoid stuck state
            print(f"\n[INFO] Checking initial status (endpoint provisioning may take 10-20+ minutes)...")
            max_wait_time = 120  # Only wait 2 minutes here
            poll_interval = 10  # Check every 10 seconds
            elapsed_time = 0
            while elapsed_time < max_wait_time:
                endpoint = w.serving_endpoints.get(endpoint_name)
                if endpoint.state.config_update.value == "NOT_UPDATING" and endpoint.state.ready.value == "READY":
                    print(f"\n[OK] Endpoint is ready! (took {elapsed_time}s)")
                    break
                elif endpoint.state.config_update.value == "UPDATE_FAILED":
                    print(f"\n[ERROR] Endpoint update failed!")
                    print(f" Check the endpoint UI for error details")
                    break
                else:
                    if elapsed_time % 30 == 0:  # Print status every 30 seconds
                        print(f" Status: {endpoint.state.config_update.value} ({elapsed_time}s elapsed)")
                    time.sleep(poll_interval)
                    elapsed_time += poll_interval
            if elapsed_time >= max_wait_time:
                print(f"\n[INFO] Endpoint is still initializing (this may take several more minutes)")
                print(f" This cell will complete now to avoid blocking")
                print(f"\n[NEXT STEP]")
                print(f" 1. Wait for endpoint to finish provisioning (check UI)")
                print(f" 2. Re-run this cell to verify status")
                if needs_ai_gateway_update:
                    print(f" 3. Once the endpoint is ready, the re-run also enables AI Gateway inference tables")
                else:
                    print(f" 3. Once ready, proceed to testing")
        # Enable AI Gateway if needed (for existing endpoints that were updated)
        if needs_ai_gateway_update and endpoint_exists:
            print(f"\n[INFO] Enabling AI Gateway inference tables...")
            # First verify the endpoint is ready and no config update is still in progress
            endpoint = w.serving_endpoints.get(endpoint_name)
            if endpoint.state.ready.value != "READY" or endpoint.state.config_update.value != "NOT_UPDATING":
                print(f"[WARNING] Endpoint not ready yet (ready: {endpoint.state.ready.value}, config: {endpoint.state.config_update.value})")
                print(f" Re-run this cell once the endpoint is ready to enable AI Gateway inference tables")
            else:
                # Enable AI Gateway (table will be created automatically by AI Gateway)
                w.serving_endpoints.put_ai_gateway(
                    name=endpoint_name,
                    inference_table_config=AiGatewayInferenceTableConfig(
                        catalog_name=catalog_name,
                        schema_name=schema_name,
                        table_name_prefix=endpoint_name,
                        enabled=True
                    )
                )
                print(f"[OK] AI Gateway configuration submitted")
                # Brief wait for configuration
                time.sleep(5)
                max_wait = 60
                elapsed = 0
                while elapsed < max_wait:
                    endpoint = w.serving_endpoints.get(endpoint_name)
                    if endpoint.state.config_update.value == "NOT_UPDATING":
                        break
                    time.sleep(5)
                    elapsed += 5
        # Get final status
        endpoint = w.serving_endpoints.get(endpoint_name)
        # Get workspace URL for link
        workspace_url = spark.conf.get("spark.databricks.workspaceUrl")
        endpoint_url = f"https://{workspace_url}/ml/endpoints/{endpoint_name}"
        print("\n" + "=" * 60)
        if endpoint.state.config_update.value == "NOT_UPDATING" and endpoint.state.ready.value == "READY":
            print("[OK] SERVING ENDPOINT READY")
        else:
            print("[INFO] SERVING ENDPOINT INITIALIZING")
        print("=" * 60)
        print(f"\n[Endpoint Details]")
        print(f" - Name: {endpoint_name}")
        print(f" - Model: {registered_model_name} (v{model_version})")
        print(f" - Config State: {endpoint.state.config_update.value}")
        print(f" - Ready State: {endpoint.state.ready.value}")
        # Check AI Gateway status
        if endpoint.ai_gateway and endpoint.ai_gateway.inference_table_config:
            ai_config = endpoint.ai_gateway.inference_table_config
            if ai_config.enabled:
                print(f" - AI Gateway: Enabled")
                print(f" - Inference Table: {ai_config.catalog_name}.{ai_config.schema_name}.{ai_config.table_name_prefix}_payload")
            else:
                print(f" - AI Gateway: Disabled")
        else:
            print(f" - AI Gateway: Not configured")
        print(f"\n[View Endpoint]")
        print(f" {endpoint_url}")
        if endpoint.state.config_update.value == "NOT_UPDATING" and endpoint.state.ready.value == "READY":
            if endpoint.ai_gateway and endpoint.ai_gateway.inference_table_config and endpoint.ai_gateway.inference_table_config.enabled:
                print(f"\n[Next Step]")
                print(f" Run 'Test Deployed Endpoint' cell")
            else:
                print(f"\n[Next Step]")
                print(f" Re-run this cell to enable AI Gateway inference tables")
        else:
            print(f"\n[Next Step]")
            print(f" Wait for endpoint to be ready, then re-run this cell")
    except SystemExit:
        # Clean exit when endpoint is still updating
        print(f"\n[INFO] Cell completed (endpoint update in progress)")
    except Exception as e:
        print(f"[ERROR] Failed to create/update endpoint: {e}")
        import traceback
        traceback.print_exc()
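The deployment cell repeats the same poll-until-state loop three times. One way to factor it out is a small helper that takes the state getter as a callable, which also makes it testable without a live endpoint. This is a sketch, not a Databricks SDK API: wait_for_state is a hypothetical name, and like the cell it counts nominal sleep time rather than wall-clock time.

```python
# Sketch of a reusable poll loop. wait_for_state is a hypothetical helper,
# not a Databricks SDK function. The caller supplies get_state(), for example
# a lambda that reads endpoint.state.config_update.value.
import time
from typing import Callable, Iterable, Optional


def wait_for_state(
    get_state: Callable[[], str],
    done_states: Iterable[str],
    timeout_s: float = 120,
    poll_interval_s: float = 10,
    sleep: Callable[[float], None] = time.sleep,
) -> Optional[str]:
    """Poll until get_state() returns one of done_states.

    Returns the terminal state, or None if timeout_s of nominal sleep time
    passes first. The sleep function is injectable so tests can skip waiting.
    """
    done = set(done_states)
    elapsed = 0.0
    while elapsed < timeout_s:
        state = get_state()
        if state in done:
            return state
        sleep(poll_interval_s)
        elapsed += poll_interval_s
    return None
```

For example, the first loop in the cell could become wait_for_state(lambda: w.serving_endpoints.get(endpoint_name).state.config_update.value, {"NOT_UPDATING", "UPDATE_FAILED"}), with a None result meaning the 2-minute window elapsed.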
AI Gateway inference tables
AI Gateway is configured automatically when the deployment cell creates a new endpoint. When it updates an existing endpoint, it enables AI Gateway in a separate step after the model update, because update_config() does not accept an AI Gateway configuration.
Key behaviors:
- Table Creation: Configuring AI Gateway does not immediately populate the inference table. The table may not appear until the endpoint has received its first request, and it holds no rows until requests are logged.
- Logging Delay: Request/response data typically appears in the payload table 2-5 minutes after the request is made, because it is processed and written asynchronously. This delay is expected.
- Verification: After running the test endpoint cell below, wait a few minutes, then query the payload table (substitute your own catalog, schema, and endpoint name):
SQL
SELECT * FROM `main`.`default`.`yolo11n_coco128_sgc_endpoint_payload`
ORDER BY timestamp_ms DESC LIMIT 10
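The payload table name is table_name_prefix plus _payload, and the deployment cell sets the prefix to the endpoint name. A minimal sketch that builds the query above from the inference table settings, assuming backtick quoting for each identifier part; quote_ident, payload_table, and recent_requests_sql are hypothetical helpers.

```python
# Sketch: derive the AI Gateway payload table name from the inference table
# settings (catalog, schema, table_name_prefix) and build a quoted SQL query.
# quote_ident, payload_table, and recent_requests_sql are hypothetical helpers.


def quote_ident(name: str) -> str:
    """Backtick-quote a SQL identifier, doubling any embedded backticks."""
    return "`" + name.replace("`", "``") + "`"


def payload_table(catalog: str, schema: str, table_name_prefix: str) -> str:
    """Fully qualified, quoted name of the <prefix>_payload table."""
    parts = (catalog, schema, f"{table_name_prefix}_payload")
    return ".".join(quote_ident(p) for p in parts)


def recent_requests_sql(catalog: str, schema: str, table_name_prefix: str, limit: int = 10) -> str:
    """Query for the most recent logged requests, newest first."""
    return (
        f"SELECT * FROM {payload_table(catalog, schema, table_name_prefix)}\n"
        f"ORDER BY timestamp_ms DESC LIMIT {int(limit)}"
    )


if __name__ == "__main__":
    print(recent_requests_sql("main", "default", "yolo11n_coco128_sgc_endpoint"))
```

In the notebook, the result could be run with display(spark.sql(recent_requests_sql(catalog_name, schema_name, endpoint_name))).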
# Check deployment approval
if 'PROCEED_WITH_DEPLOYMENT' not in dir() or not PROCEED_WITH_DEPLOYMENT:
    print("\n" + "="*60)
    print("⚠️ DEPLOYMENT SKIPPED - NOT APPROVED")
    print("="*60)
    print("\n[Reason]")
    print(" PROCEED_WITH_DEPLOYMENT flag is not set to True")
    print("\n[To Enable Deployment]")
    print(" 1. Set 'Proceed with Deployment' widget to 'true' (top of notebook)")
    print(" 2. Re-run the 'Deployment checkpoint' cell")
    print(" 3. Run the 'Model deployment' cell, then this cell")
    print("\n" + "="*60)
else:
    # Deployment approved - proceed with endpoint testing
    import os
    import json
    import base64
    import glob
    import random
    print("=" * 60)
    print("TESTING DEPLOYED ENDPOINT")
    print("=" * 60)
    print(f"\nEndpoint: {endpoint_name}\n")
    try:
        test_images = glob.glob(f"{project_location}data/coco128/images/test/*.jpg")
        if test_images:
            test_image_path = random.choice(test_images)
            print(f"Test image: {os.path.basename(test_image_path)} (randomly selected from {len(test_images)} test images)")
            # Encode image as base64
            print(f"[INFO] Encoding image as base64...")
            with open(test_image_path, 'rb') as f:
                image_bytes = f.read()
            image_base64 = base64.b64encode(image_bytes).decode('utf-8')
            # Test endpoint with base64 input
            print(f"[INFO] Testing endpoint with base64 input...\n")
            input_data = {"dataframe_records": [{"image_base64": image_base64}]}
            response = w.serving_endpoints.query(
                name=endpoint_name,
                dataframe_records=input_data["dataframe_records"]
            )
            print(f"[OK] Endpoint test successful!\n")
            print(f"Response preview:")
            response_dict = response.as_dict()
            print(json.dumps(response_dict, indent=2)[:500])
            print("\n" + "=" * 60)
            print("[OK] DEPLOYMENT COMPLETE")
            print("=" * 60)
            print(f"\n[Deployment Summary]")
            print(f" - Endpoint: {endpoint_name}")
            print(f" - Model: {registered_model_name} (v{model_version})")
            print(f" - Status: Ready and tested")
            print(f" - AI Gateway: Enabled")
            print(f" - Input format: Base64-encoded images")
            print(f" - Inference table: {catalog_name}.{schema_name}.{endpoint_name}_payload")
            # Get workspace URL for links
            workspace_url = spark.conf.get("spark.databricks.workspaceUrl")
            endpoint_url = f"https://{workspace_url}/ml/endpoints/{endpoint_name}"
            table_url = f"https://{workspace_url}/explore/data/{catalog_name}/{schema_name}/{endpoint_name}_payload"
            print(f"\n[Links]")
            print(f" - Endpoint: {endpoint_url}")
            print(f" - Inference table: {table_url}")
            print(f"\n[Usage Example]")
            print(f" import base64")
            print(f" with open('image.jpg', 'rb') as f:")
            print(f"     img_b64 = base64.b64encode(f.read()).decode('utf-8')")
            print(f" ")
            print(f" w.serving_endpoints.query(")
            print(f"     name='{endpoint_name}',")
            print(f"     dataframe_records=[{{'image_base64': img_b64}}]")
            print(f" )")
            print(f"\n[Monitor Inference]")
            print(f" SELECT * FROM {catalog_name}.{schema_name}.{endpoint_name}_payload")
            print(f" ORDER BY timestamp_ms DESC LIMIT 10")
        else:
            print("[WARNING] No test images found")
    except Exception as e:
        print(f"[ERROR] Endpoint test failed: {e}")
        import traceback
        traceback.print_exc()
        print(f"\n[INFO] Verify endpoint is ready and AI Gateway is configured")
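Clients outside the notebook can call the same endpoint over REST instead of through the SDK's query() method. A minimal standard-library sketch, assuming a workspace host name without the scheme and a personal access token for bearer authentication; build_invocation_request and invoke are hypothetical helpers, and the body mirrors the dataframe_records payload used in the test cell.

```python
# Sketch: query a Model Serving endpoint over REST with the standard library.
# Assumes POST https://<host>/serving-endpoints/<name>/invocations with a
# bearer token. build_invocation_request and invoke are hypothetical helpers.
import base64
import json
import urllib.request


def build_invocation_request(host: str, endpoint_name: str, image_bytes: bytes, token: str) -> urllib.request.Request:
    """Return an unsent POST request carrying a one-row dataframe_records body."""
    body = {"dataframe_records": [{"image_base64": base64.b64encode(image_bytes).decode("utf-8")}]}
    return urllib.request.Request(
        url=f"https://{host}/serving-endpoints/{endpoint_name}/invocations",
        data=json.dumps(body).encode("utf-8"),
        headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"},
        method="POST",
    )


def invoke(req: urllib.request.Request, timeout_s: float = 60) -> dict:
    """Send the request and parse the JSON response (detections under "predictions")."""
    with urllib.request.urlopen(req, timeout=timeout_s) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Inside the workspace, the SDK call above is simpler because WorkspaceClient handles authentication; the REST form mainly helps external clients.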
# ============================================================
# VISUALIZE ENDPOINT PREDICTIONS
# ============================================================
# Draws bounding boxes returned by the serving endpoint onto
# the test image used in the previous cell.
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
import numpy as np
if 'response_dict' not in dir() or 'test_image_path' not in dir():
    print("[SKIP] Run the 'Test Deployed Endpoint' cell first.")
else:
    img = Image.open(test_image_path)
    fig, (ax_orig, ax) = plt.subplots(1, 2, figsize=(20, 9))
    # Left panel: original image
    ax_orig.imshow(img)
    ax_orig.set_title("Original", fontsize=13)
    ax_orig.axis("off")
    # Right panel: image with endpoint detections drawn
    ax.imshow(img)
    # Parse predictions from the endpoint response
    predictions = response_dict.get("predictions", [])
    if isinstance(predictions, list) and len(predictions) > 0:
        # predictions is a list of dicts (one per detection)
        # or a list with one element that is a list of dicts
        if isinstance(predictions[0], dict) and "class_name" in predictions[0]:
            detections = predictions
        elif isinstance(predictions[0], list):
            detections = predictions[0]
        else:
            detections = predictions
    else:
        detections = []
    # Color map for distinct classes
    unique_classes = list({d.get("class_name", "?") for d in detections})
    cmap = plt.get_cmap("tab10", max(len(unique_classes), 1))
    class_colors = {cls: cmap(i) for i, cls in enumerate(unique_classes)}
    for det in detections:
        x1 = det.get("bbox_x1", 0)
        y1 = det.get("bbox_y1", 0)
        x2 = det.get("bbox_x2", 0)
        y2 = det.get("bbox_y2", 0)
        cls_name = det.get("class_name", "?")
        conf = det.get("confidence", 0)
        color = class_colors.get(cls_name, "lime")
        rect = patches.Rectangle(
            (x1, y1), x2 - x1, y2 - y1,
            linewidth=2, edgecolor=color, facecolor="none"
        )
        ax.add_patch(rect)
        ax.text(
            x1, y1 - 4,
            f"{cls_name} {conf:.2f}",
            fontsize=9, color="white",
            bbox=dict(facecolor=color, alpha=0.7, pad=1, edgecolor="none")
        )
    ax.set_title(f"Endpoint Detections — {len(detections)} objects", fontsize=13)
    ax.axis("off")
    plt.tight_layout()
    plt.show()
    print(f"\n[INFO] {len(detections)} detections drawn from endpoint response")
    print(f" Classes: {', '.join(sorted(unique_classes))}")
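The visualization cell accepts two shapes for predictions: a flat list of detection dicts, or a list whose first element is that list. A sketch of the same normalization as a standalone function; normalize_detections is a hypothetical name, and dropping non-dict entries is an extra defensive choice not present in the cell.

```python
# Sketch: flatten the endpoint's "predictions" payload into a list of
# detection dicts. normalize_detections is a hypothetical helper; it accepts
# a flat list of dicts or a list whose first element is a list of dicts.
from typing import Any


def normalize_detections(response: dict[str, Any]) -> list[dict[str, Any]]:
    predictions = response.get("predictions", [])
    if not isinstance(predictions, list) or not predictions:
        return []
    # Nested shape: [[{...}, {...}]]
    if isinstance(predictions[0], list):
        return [d for d in predictions[0] if isinstance(d, dict)]
    # Flat shape: [{...}, {...}]; non-dict entries are dropped defensively
    return [d for d in predictions if isinstance(d, dict)]
```

With this helper, the parsing block in the cell reduces to detections = normalize_detections(response_dict).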
# ============================================================
# DETECTION SUMMARY TABLE
# ============================================================
# Tabular summary of the endpoint detections drawn in the cell above.
import pandas as pd
if 'detections' not in dir():
    print("[SKIP] Run the 'Visualize Endpoint Predictions' cell first.")
elif len(detections) == 0:
    print("[INFO] No detections returned by the endpoint for this image.")
else:
    det_df = pd.DataFrame(detections)
    # Per-class summary: count + confidence stats
    summary_df = (
        det_df.groupby("class_name")
        .agg(
            count=("confidence", "size"),
            avg_confidence=("confidence", "mean"),
            min_confidence=("confidence", "min"),
            max_confidence=("confidence", "max"),
        )
        .round(3)
        .reset_index()
        .sort_values("count", ascending=False)
        .reset_index(drop=True)
    )
    # Per-detection detail, sorted by confidence
    detail_cols = [c for c in ["class_name", "confidence", "bbox_x1", "bbox_y1", "bbox_x2", "bbox_y2"] if c in det_df.columns]
    detail_df = det_df[detail_cols].copy()
    for c in detail_df.columns:
        if c != "class_name":
            detail_df[c] = detail_df[c].round(2)
    detail_df = detail_df.sort_values("confidence", ascending=False).reset_index(drop=True)
    print(f"[Per-class summary] {len(detections)} detections across {summary_df.shape[0]} class(es)")
    print(summary_df.to_string(index=False))
    print("\n[Per-detection detail]")
    print(detail_df.to_string(index=False))
Next steps
- Learn more about Databricks AI Runtime and the AI Runtime environment versions.
- Explore Model Serving to deploy and manage your model endpoint.
- Scale this workflow to a larger dataset (1K+ images) by updating the data paths. See NuInsSeg for a real-world example.