
mlflow-classic-ml-e2e


XGBoost MLflow tutorial

This tutorial covers the full lifecycle of experimentation, training, tuning, registration, evaluation, and deployment for a classic ML modeling project. It shows you how to use MLflow to keep track of every aspect of the model development and deployment processes.

In this step-by-step tutorial, you'll discover how to:

  • Generate and visualize data: Create synthetic data to simulate real-world scenarios, and visualize feature relationships with Seaborn.
  • Train and log models: Train an XGBoost model, and log important metrics, parameters, and artifacts using MLflow, including visualizations.
  • Register models: Register your model with Unity Catalog, preparing it for review and future deployment to managed serving endpoints.
  • Load and evaluate models: Load your registered model, make predictions, and perform error analysis to validate model performance.
Install the latest version of MLflow and the other packages this notebook uses (XGBoost, Optuna, and uv)
%pip install -Uqqq mlflow xgboost optuna uv
%restart_python
from typing import Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

import xgboost as xgb

import mlflow
from mlflow.models import infer_signature

0. Configure the Model Registry with Unity Catalog

One of the key advantages of using MLflow on Databricks is the seamless integration with Unity Catalog. This integration simplifies model management and governance, ensuring that every model you develop is tracked, versioned, and secure. For more information, see the Unity Catalog documentation (AWS | Azure | GCP).

Set the registry URI

The following cell configures MLflow to use Unity Catalog for model registration.

mlflow.set_registry_uri("databricks-uc")
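Models registered in Unity Catalog use a three-level name, `catalog.schema.model`, and later cells refer to them with `models:/` URIs, either by version number or by alias (for example `models:/main.default.xgboostoptuna@best`). The helper below is a hypothetical convenience function, not part of MLflow; it is only a sketch of how these URI strings are composed.

```python
from typing import Optional


def uc_model_uri(catalog: str, schema: str, model: str,
                 version: Optional[int] = None,
                 alias: Optional[str] = None) -> str:
    """Build a models:/ URI for a Unity Catalog registered model (hypothetical helper)."""
    # Exactly one of version or alias must identify the model version
    if (version is None) == (alias is None):
        raise ValueError("Specify exactly one of version or alias")
    full_name = f"{catalog}.{schema}.{model}"
    if alias is not None:
        return f"models:/{full_name}@{alias}"  # alias form: name@alias
    return f"models:/{full_name}/{version}"    # version form: name/version


print(uc_model_uri("main", "default", "xgboostoptuna", alias="best"))
print(uc_model_uri("main", "default", "xgboostoptuna", version=1))
```

The alias form is what sections 7 to 9 use, so downstream code keeps working when the alias is moved to a newer version.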

1. Create a synthetic regression dataset

The next cell defines the create_regression_data function. This function generates synthetic data for regression. The resulting DataFrame includes correlated data, cyclical patterns, and outliers. These features are designed to mimic real-world data scenarios.

def create_regression_data(
    n_samples: int, 
    n_features: int,
    seed: int = 1994,
    noise_level: float = 0.3,
    nonlinear: bool = True
) -> Tuple[pd.DataFrame, pd.Series]:
    """Generates synthetic regression data with interesting correlations for MLflow and XGBoost demonstrations.

    This function creates a DataFrame of continuous features and computes a target variable with nonlinear
    relationships and interactions between features. The data is designed to be complex enough to demonstrate
    the capabilities of XGBoost, but not so complex that a reasonable model can't be learned.

    Args:
        n_samples (int): Number of samples (rows) to generate.
        n_features (int): Number of feature columns.
        seed (int, optional): Random seed for reproducibility. Defaults to 1994.
        noise_level (float, optional): Level of Gaussian noise to add to the target. Defaults to 0.3.
        nonlinear (bool, optional): Whether to add nonlinear feature transformations. Defaults to True.

    Returns:
        Tuple[pd.DataFrame, pd.Series]:
            - pd.DataFrame: DataFrame containing the synthetic features.
            - pd.Series: Series containing the target labels.

    Example:
        >>> df, target = create_regression_data(n_samples=1000, n_features=10)
    """
    rng = np.random.RandomState(seed)
    
    # Generate random continuous features
    X = rng.uniform(-5, 5, size=(n_samples, n_features))
    
    # Create feature DataFrame with meaningful names
    columns = [f"feature_{i}" for i in range(n_features)]
    df = pd.DataFrame(X, columns=columns)
    
    # Generate base target variable with linear relationship to a subset of features
    # Use only the first n_features//2 features to create some irrelevant features
    weights = rng.uniform(-2, 2, size=n_features//2)
    target = np.dot(X[:, :n_features//2], weights)
    
    # Add some nonlinear transformations if requested
    if nonlinear:
        # Add square term for first feature
        target += 0.5 * X[:, 0]**2
        
        # Add interaction between the second and third features
        if n_features >= 3:
            target += 1.5 * X[:, 1] * X[:, 2]
        
        # Add sine transformation of fourth feature
        if n_features >= 4:
            target += 2 * np.sin(X[:, 3])
        
        # Add exponential of fifth feature, scaled down
        if n_features >= 5:
            target += 0.1 * np.exp(X[:, 4] / 2)
            
        # Add threshold effect for sixth feature
        if n_features >= 6:
            target += 3 * (X[:, 5] > 1.5).astype(float)
    
    # Add Gaussian noise
    noise = rng.normal(0, noise_level * target.std(), size=n_samples)
    target += noise
    
    # Add a few more interesting features to the DataFrame
    
    # Add a correlated feature (but not used in target calculation)
    if n_features >= 7:
        df['feature_correlated'] = df['feature_0'] * 0.8 + rng.normal(0, 0.2, size=n_samples)
    
    # Add a cyclical feature
    df['feature_cyclical'] = np.sin(np.linspace(0, 4*np.pi, n_samples))
    
    # Add a feature with outliers
    df['feature_with_outliers'] = rng.normal(0, 1, size=n_samples)
    # Add outliers to ~1% of samples
    outlier_idx = rng.choice(n_samples, size=n_samples//100, replace=False)
    df.loc[outlier_idx, 'feature_with_outliers'] = rng.uniform(10, 15, size=len(outlier_idx))
    
    return df, pd.Series(target, name='target')
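The returned DataFrame has more columns than `n_features`: the function appends `feature_correlated` (only when `n_features >= 7`), `feature_cyclical`, and `feature_with_outliers`, and `n_samples // 100` rows receive an injected outlier. The sketch below mirrors only that bookkeeping, so you can predict the shape of the data before plotting it; `expected_columns` and `expected_outlier_count` are hypothetical helpers, not part of the notebook.

```python
from typing import List


def expected_columns(n_features: int) -> List[str]:
    """Mirror the column-naming logic of create_regression_data (hypothetical helper)."""
    columns = [f"feature_{i}" for i in range(n_features)]
    if n_features >= 7:
        columns.append("feature_correlated")  # noisy copy of feature_0
    columns += ["feature_cyclical", "feature_with_outliers"]
    return columns


def expected_outlier_count(n_samples: int) -> int:
    """Number of rows whose feature_with_outliers value is replaced by an outlier."""
    return n_samples // 100


cols = expected_columns(10)
print(len(cols))                     # 13
print(expected_outlier_count(1000))  # 10
```

With `n_features=10`, the linear part of the target uses `feature_0` to `feature_4`, and the nonlinear terms additionally use `feature_5`, so `feature_6` to `feature_9` carry no signal about the target. The EDA plots should reflect this.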

2. Exploratory data analysis (EDA) visualizations

Before training your model, it’s essential to examine your data. Visualizations help you validate that the data is as expected, spot unexpected anomalies, and drive feature selection. As you move forward with model development, these visualizations serve as a record of your work that can help with troubleshooting, reproducibility, and collaboration.

You can use MLflow to log visualizations, making your experimentation fully reproducible.

The code in the following cell defines six functions, each of which generates a different plot to help you visually inspect your dataset.

def plot_feature_distributions(X: pd.DataFrame, y: pd.Series, n_cols: int = 3) -> plt.Figure:
    """
    Creates a grid of histograms for each feature in the dataset.

    Args:
        X (pd.DataFrame): DataFrame containing synthetic features.
        y (pd.Series): Series containing the target variable.
        n_cols (int): Number of columns in the grid layout.

    Returns:
        plt.Figure: The matplotlib Figure object containing the distribution plots.
    """
    features = X.columns
    n_features = len(features)
    n_rows = (n_features + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
    axes = axes.flatten() if n_rows * n_cols > 1 else [axes]
    
    for i, feature in enumerate(features):
        if i < len(axes):
            ax = axes[i]
            sns.histplot(X[feature], ax=ax, kde=True, color='skyblue')
            ax.set_title(f'Distribution of {feature}')
    
    # Hide any unused subplots
    for i in range(n_features, len(axes)):
        axes[i].set_visible(False)
    
    plt.tight_layout()
    fig.suptitle('Feature Distributions', y=1.02, fontsize=16)
    plt.close(fig)
    return fig

def plot_correlation_heatmap(X: pd.DataFrame, y: pd.Series) -> plt.Figure:
    """
    Creates a correlation heatmap of all features and the target variable.

    Args:
        X (pd.DataFrame): DataFrame containing features.
        y (pd.Series): Series containing the target variable.

    Returns:
        plt.Figure: The matplotlib Figure object containing the heatmap.
    """
    # Combine features and target into one DataFrame
    data = X.copy()
    data['target'] = y
    
    # Calculate correlation matrix
    corr_matrix = data.corr()
    
    # Set up the figure
    fig, ax = plt.subplots(figsize=(12, 10))
    
    # Draw the heatmap with a color bar
    cmap = sns.diverging_palette(220, 10, as_cmap=True)
    sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap=cmap,
                center=0, square=True, linewidths=0.5, ax=ax)
    
    ax.set_title('Feature Correlation Heatmap', fontsize=16)
    plt.close(fig)
    return fig

def plot_feature_target_relationships(X: pd.DataFrame, y: pd.Series, n_cols: int = 3) -> plt.Figure:
    """
    Creates a grid of scatter plots showing the relationship between each feature and the target.

    Args:
        X (pd.DataFrame): DataFrame containing features.
        y (pd.Series): Series containing the target variable.
        n_cols (int): Number of columns in the grid layout.

    Returns:
        plt.Figure: The matplotlib Figure object containing the relationship plots.
    """
    features = X.columns
    n_features = len(features)
    n_rows = (n_features + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
    axes = axes.flatten() if n_rows * n_cols > 1 else [axes]
    
    for i, feature in enumerate(features):
        if i < len(axes):
            ax = axes[i]
            # Scatter plot with regression line
            sns.regplot(x=X[feature], y=y, ax=ax, 
                       scatter_kws={'alpha': 0.5, 'color': 'blue'}, 
                       line_kws={'color': 'red'})
            ax.set_title(f'{feature} vs Target')
    
    for i in range(n_features, len(axes)):
        axes[i].set_visible(False)
    
    plt.tight_layout()
    fig.suptitle('Feature vs Target Relationships', y=1.02, fontsize=16)
    plt.close(fig)
    return fig

def plot_pairwise_relationships(X: pd.DataFrame, y: pd.Series, features: list[str]) -> plt.Figure:
    """
    Creates a pairplot showing relationships between selected features and the target.

    Args:
        X (pd.DataFrame): DataFrame containing features.
        y (pd.Series): Series containing the target variable.
        features (List[str]): List of feature names to include in the plot.

    Returns:
        plt.Figure: The matplotlib Figure object containing the pairplot.
    """
    # Ensure features exist in the DataFrame
    valid_features = [f for f in features if f in X.columns]
    
    if not valid_features:
        fig, ax = plt.subplots()
        ax.text(0.5, 0.5, "No valid features provided", ha='center', va='center')
        plt.close(fig)
        return fig
    
    # Combine selected features and target
    data = X[valid_features].copy()
    data['target'] = y
    
    # Create pairplot
    pairgrid = sns.pairplot(data, diag_kind="kde", 
                          plot_kws={"alpha": 0.6, "s": 50},
                          corner=True)
    
    pairgrid.fig.suptitle("Pairwise Feature Relationships", y=1.02, fontsize=16)
    plt.close(pairgrid.fig)
    return pairgrid.fig

def plot_boxplots(X: pd.DataFrame, y: pd.Series, n_cols: int = 3) -> plt.Figure:
    """
    Creates a grid of box plots for each feature, with points colored by target value.

    Args:
        X (pd.DataFrame): DataFrame containing features.
        y (pd.Series): Series containing the target variable.
        n_cols (int): Number of columns in the grid layout.

    Returns:
        plt.Figure: The matplotlib Figure object containing the box plots.
    """
    features = X.columns
    n_features = len(features)
    n_rows = (n_features + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
    axes = axes.flatten() if n_rows * n_cols > 1 else [axes]
    
    # Create target bins for coloring
    y_binned = pd.qcut(y, 3, labels=['Low', 'Medium', 'High'])
    
    for i, feature in enumerate(features):
        if i < len(axes):
            ax = axes[i]
            # Box plot for each feature
            sns.boxplot(x=y_binned, y=X[feature], ax=ax)
            ax.set_title(f'Distribution of {feature} by Target Range')
            ax.set_xlabel('Target Range')
    
    # Hide any unused subplots
    for i in range(n_features, len(axes)):
        axes[i].set_visible(False)
    
    plt.tight_layout()
    fig.suptitle('Feature Distributions by Target Range', y=1.02, fontsize=16)
    plt.close(fig)
    return fig

def plot_outliers(X: pd.DataFrame, n_cols: int = 3) -> plt.Figure:
    """
    Creates a grid of box plots to detect outliers in each feature.

    Args:
        X (pd.DataFrame): DataFrame containing features.
        n_cols (int): Number of columns in the grid layout.

    Returns:
        plt.Figure: The matplotlib Figure object containing the outlier plots.
    """
    features = X.columns
    n_features = len(features)
    n_rows = (n_features + n_cols - 1) // n_cols
    
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
    axes = axes.flatten() if n_rows * n_cols > 1 else [axes]
    
    for i, feature in enumerate(features):
        if i < len(axes):
            ax = axes[i]
            # Box plot to detect outliers
            sns.boxplot(x=X[feature], ax=ax, color='skyblue')
            ax.set_title(f'Outlier Detection for {feature}')
            ax.set_xlabel(feature)
    
    # Hide any unused subplots
    for i in range(n_features, len(axes)):
        axes[i].set_visible(False)
    
    plt.tight_layout()
    fig.suptitle('Outlier Detection for Features', y=1.02, fontsize=16)
    plt.close(fig)
    return fig
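Most of these functions share two small calculations: the grid uses integer ceiling division so that every feature gets a subplot, and `plot_boxplots` splits the target into three equal-frequency bins with `pd.qcut`. The sketch below isolates both so you can check them without drawing anything; `grid_shape` is a hypothetical name.

```python
import math
from typing import Tuple

import pandas as pd


def grid_shape(n_features: int, n_cols: int = 3) -> Tuple[int, int]:
    """Return (rows needed, subplots left over and hidden)."""
    n_rows = (n_features + n_cols - 1) // n_cols  # integer ceiling division
    assert n_rows == math.ceil(n_features / n_cols)
    n_hidden = n_rows * n_cols - n_features
    return n_rows, n_hidden


print(grid_shape(13))  # (5, 2): 13 columns in the synthetic data, 2 empty slots

# Equal-frequency (tertile) binning, as used for coloring in plot_boxplots
y = pd.Series([1, 2, 3, 4, 5, 6])
y_binned = pd.qcut(y, 3, labels=["Low", "Medium", "High"])
print(list(y_binned))  # ['Low', 'Low', 'Medium', 'Medium', 'High', 'High']
```

Because `qcut` uses quantiles rather than fixed-width intervals, each target range holds roughly the same number of rows, which keeps the box plots comparable.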

3. Standard modeling workflow

The code in the next cell does the following:

  1. Uses the function you created, create_regression_data, to create a dataset.
  2. Uses the visualization functions you created to create EDA plots.
  3. Configures and trains an XGBoost model.
  4. Uses the trained model to make predictions on the test dataset.
# Create the regression dataset
n_samples = 1000
n_features = 10
X, y = create_regression_data(n_samples=n_samples, n_features=n_features, nonlinear=True)

# Create EDA plots
dist_plot = plot_feature_distributions(X, y)
corr_plot = plot_correlation_heatmap(X, y)
scatter_plot = plot_feature_target_relationships(X, y)
corr_with_target = X.corrwith(y).abs().sort_values(ascending=False)
top_features = corr_with_target.head(4).index.tolist()
pairwise_plot = plot_pairwise_relationships(X, y, top_features)
outlier_plot = plot_outliers(X)

# Configure the XGBoost model
reg = xgb.XGBRegressor(
    tree_method="hist",
    n_estimators=100,
    learning_rate=0.1,
    max_depth=6,
    subsample=0.8,
    colsample_bytree=0.8,
    eval_metric='rmse',
)

# Create train/test split to properly evaluate the model
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7722)

# Train the model with evaluation
reg.fit(
    X_train, y_train,
    eval_set=[(X_train, y_train), (X_test, y_test)],
    verbose=False
)

# Generate predictions on the held-out test set
y_pred = reg.predict(X_test)

# Create box plots of each feature grouped by target range
residual_plot = plot_boxplots(X, y)

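The `eval_metric='rmse'` setting makes XGBoost record the root mean squared error on each `eval_set` entry after every boosting round. The following sketch computes the same metric directly with NumPy, which is handy for checking `y_pred` against `y_test`; the numbers are toy values, not results from this notebook.

```python
import numpy as np


def rmse(y_true, y_pred) -> float:
    """Root mean squared error: sqrt(mean((y_true - y_pred)^2))."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))


# Toy values for illustration only: squared errors 0.25, 0.25, 0, 1 -> mean 0.375
print(rmse([3.0, -0.5, 2.0, 7.0], [2.5, 0.0, 2.0, 8.0]))  # about 0.6124
```

In the notebook, `rmse(y_test, y_pred)` should closely match the last entry of `reg.evals_result()["validation_1"]["rmse"]`, since both measure the fully trained model on the test split.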

4. Log the model using MLflow

When you log a model using MLflow on Databricks, important artifacts and metadata are captured. This ensures that your model is not only reproducible but also ready for deployment with all necessary dependencies and clear API contracts. For details on what is logged, see the MLflow documentation.

The code in the next cell starts an MLflow run using with mlflow.start_run():. This initializes the MLflow context manager for the run and encloses the run in a code block. When the code block ends, all logged metrics, parameters, and artifacts are saved, and the MLflow run is automatically terminated.

# Prepare the test data (features plus a label column) for mlflow.evaluate
evaluation_data = X_test.copy()
evaluation_data["label"] = y_test

# Log the model and training metadata results
with mlflow.start_run() as run:
    # Extract metrics
    final_train_rmse = np.array(reg.evals_result()["validation_0"]["rmse"])[-1]
    final_test_rmse = np.array(reg.evals_result()["validation_1"]["rmse"])[-1]
    
    # Extract parameters for logging
    feature_map = {key: value for key, value in reg.get_xgb_params().items() if value is not None}

    # Generate a model signature using the infer_signature utility in MLflow
    # A signature is required to register the model to Unity Catalog 
    # so that the model can be used in SQL queries
    signature = infer_signature(X, reg.predict(X))
    
    # Log the model to MLflow and register the model to Unity Catalog
    model_info = mlflow.xgboost.log_model(
        xgb_model=reg,
        artifact_path="model",
        input_example=X.iloc[[0]],
        signature=signature,
        registered_model_name="main.default.xgboost_regression_model",
    )

    # Log parameters
    mlflow.log_params(feature_map)

    # Log metrics
    mlflow.log_metric("train_rmse", final_train_rmse)
    mlflow.log_metric("test_rmse", final_test_rmse)
    
    # Log feature analysis plots
    # Plots are saved as artifacts in MLflow
    mlflow.log_figure(dist_plot, "feature_distributions.png")
    mlflow.log_figure(corr_plot, "correlation_heatmap.png")
    mlflow.log_figure(scatter_plot, "feature_target_relationships.png")
    mlflow.log_figure(pairwise_plot, "pairwise_relationships.png")
    mlflow.log_figure(outlier_plot, "outlier_detection.png")
    mlflow.log_figure(residual_plot, "feature_boxplots_by_target.png")
        
    # Plot feature importance
    fig, ax = plt.subplots(figsize=(10, 8))
    xgb.plot_importance(reg, ax=ax, importance_type='gain')
    plt.title('Feature Importance')
    plt.tight_layout()
    plt.close(fig)

    mlflow.log_figure(fig, "feature_importance.png")

    # Run MLflow evaluation to generate additional metrics without having to implement them
    mlflow.evaluate(
        model=model_info.model_uri, 
        data=evaluation_data, 
        targets="label", 
        model_type="regressor", 
        evaluator_config={"metric_prefix": "mlflow_evaluation_"},
    )
    
    print(f"Model logged: {model_info.model_uri}")
    print(f"Train RMSE: {final_train_rmse:.4f}")
    print(f"Test RMSE: {final_test_rmse:.4f}")
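The run logs the RMSE from the last boosting round (`[-1]`). `reg.evals_result()` returns a nested dictionary keyed first by eval-set name (`validation_0` for the training set and `validation_1` for the test set, in the order they were passed to `eval_set`) and then by metric name, so the whole learning curve is available. The sketch below, using a hypothetical `summarize_curve` helper and made-up numbers, contrasts the final value with the best round, which can differ once the model starts to overfit.

```python
from typing import Dict, List, Tuple


def summarize_curve(evals_result: Dict[str, Dict[str, List[float]]],
                    dataset: str = "validation_1",
                    metric: str = "rmse") -> Tuple[float, float, int]:
    """Return (final value, best value, index of best round) for one learning curve."""
    curve = evals_result[dataset][metric]
    best_round = min(range(len(curve)), key=curve.__getitem__)
    return curve[-1], curve[best_round], best_round


# Toy dictionary shaped like XGBRegressor.evals_result(); the numbers are illustrative only
toy = {
    "validation_0": {"rmse": [5.0, 3.1, 2.0, 1.4, 1.1]},
    "validation_1": {"rmse": [5.2, 3.5, 2.6, 2.4, 2.5]},
}
final, best, best_round = summarize_curve(toy)
print(final, best, best_round)  # 2.5 2.4 3
```

Inside the run, `summarize_curve(reg.evals_result())` would return values you could log with `mlflow.log_metric`, for example under a hypothetical `best_test_rmse` key.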

5. Hyperparameter tuning

This section shows how to automate hyperparameter tuning using Optuna and nested runs in MLflow. This approach lets you explore a range of parameter configurations while capturing all of the experimental details.

The code in the next cell does the following:

  1. Uses the create_regression_data function defined previously to generate a synthetic regression dataset.

  2. Splits the dataset into separate training and test datasets, and saves a copy of the test dataset for evaluation.

  3. Configures and trains a baseline XGBoost regression model on the new training data.

  4. Creates an objective function for the hyperparameter tuning process. The objective function defines the search space for hyperparameters of the XGBoost regressor, such as the maximum tree depth, number of estimators, learning rate, and sampling ratios. Optuna dynamically samples these values, ensuring that each trial tests a different combination of parameters.

  5. Initiates a nested MLflow run inside the objective function. The nested run records the details specific to the current hyperparameter trial. By isolating each trial in its own nested run, you keep a well-organized record of each configuration and its corresponding performance metric. The nested run logs the following:

    • The specific hyperparameters used for that trial.
    • The performance metric (in this case, RMSE) computed on the test set.

    The trained model itself is not logged to the nested run. Instead, it is stored in the Optuna trial's user attributes, so the best-performing model can be retrieved later. Individual tuning iterations are not guaranteed to produce good models, so there is no reason to record a model artifact for each one.

  6. Creates a parent MLflow run. This run initiates an Optuna study designed to identify the optimal set of hyperparameters (the set that produces the minimum RMSE). Optuna runs a series of trials where each trial uses a unique combination of hyperparameters. During each trial, the nested MLflow run captures all the experiment details, so you can later track and compare the performance of each model configuration.

  7. The study identifies the best trial based on the lowest RMSE. The code logs the metrics, parameters, and model from the best trial, and registers the best model in Unity Catalog. The code uses infer_signature to save a model signature, which specifies the expected input and output schemas and is important for consistent deployment and integration with systems like Unity Catalog. Finally, additional artifacts such as EDA plots and feature importance charts are recorded.
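Structurally, the study is a loop: each trial samples a configuration, trains and scores a model in its own child run, and the parent keeps only the winner. The sketch below reproduces that control flow in plain Python with a toy objective. It substitutes simple random search for Optuna's default sampler (TPE), and `run_study`, `toy_sample`, and `toy_objective` are hypothetical names used only for illustration.

```python
import random
from typing import Any, Callable, Dict, List, Tuple

Params = Dict[str, Any]


def run_study(objective: Callable[[Params], float],
              sample_params: Callable[[random.Random], Params],
              n_trials: int,
              seed: int = 0) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Random-search loop mirroring the parent/child structure (not Optuna)."""
    rng = random.Random(seed)
    trials: List[Dict[str, Any]] = []
    for number in range(n_trials):
        params = sample_params(rng)  # child run: sample one configuration
        value = objective(params)    # child run: train and score it
        trials.append({"number": number, "params": params, "value": value})
    # Parent run: keep only the best trial for logging and registration
    best = min(trials, key=lambda t: t["value"])
    return best, trials


def toy_sample(rng: random.Random) -> Params:
    return {"max_depth": rng.randint(3, 10), "subsample": rng.uniform(0.5, 1.0)}


def toy_objective(params: Params) -> float:
    # Stand-in for "train a model and return test RMSE"; lower is better
    return (params["max_depth"] - 6) ** 2 + (params["subsample"] - 0.8) ** 2


best, trials = run_study(toy_objective, toy_sample, n_trials=50)
print(best["number"], best["params"], best["value"])
```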

import optuna
from sklearn.metrics import mean_squared_error
import numpy as np
import mlflow
import xgboost as xgb

# Generate training and validation data
n_samples = 2000
n_features = 10

X, y = create_regression_data(n_samples=n_samples, n_features=n_features, nonlinear=True)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Prepare the evaluation data
evaluation_data = X_test.copy()
evaluation_data["label"] = y_test

# Configure and train a baseline model on the new training data
reg = xgb.XGBRegressor(
    tree_method="hist",
    n_estimators=100,
    learning_rate=0.1,
    max_depth=6,
    subsample=0.8,
    colsample_bytree=0.8,
    eval_metric='rmse',
)
reg.fit(
    X_train, y_train,
    eval_set=[(X_train, y_train), (X_test, y_test)],
    verbose=False
)

# The objective function defines the search space for the key hyperparameters of the XGBRegressor algorithm.
# Optuna dynamically samples these values, so that each trial tests a different combination of parameters.
def objective(trial):
    param = {
        "max_depth": trial.suggest_int("max_depth", 3, 10),
        "n_estimators": trial.suggest_int("n_estimators", 50, 500),
        "eta": trial.suggest_float("eta", 0.01, 0.3, log=True),
        "subsample": trial.suggest_float("subsample", 0.5, 1.0),
        "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
        "tree_method": "hist",
        "objective": "reg:squarederror",
        "eval_metric": "rmse"
    }

    with mlflow.start_run(nested=True):
        mlflow.log_params(param)
        regressor = xgb.XGBRegressor(**param)
        regressor.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)

        preds = regressor.predict(X_test)
        rmse = np.sqrt(mean_squared_error(y_test, preds))
        mlflow.log_metric("rmse", rmse)
    
    # Store the model in the trial's `user attributes`
    trial.set_user_attr("model", regressor)
    return rmse

# In the parent run, save the best iteration from the hyperparameter tuning execution
with mlflow.start_run():
    study = optuna.create_study(direction="minimize")
    study.optimize(objective, n_trials=50)

    best_trial = study.best_trial
    best_model = best_trial.user_attrs["model"]

    mlflow.log_metric("best_rmse", best_trial.value)
    mlflow.log_params(best_trial.params)

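    # Infer the signature from matching training inputs and predictions
    signature = infer_signature(X_train, best_model.predict(X_train))

    # Keep the returned ModelInfo so that mlflow.evaluate below uses the tuned model
    model_info = mlflow.xgboost.log_model(
        xgb_model=best_model,
        artifact_path="model",
        input_example=X_train.iloc[[0]],
        signature=signature,
        model_format="ubj",
        registered_model_name="main.default.xgboostoptuna",
    )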

    mlflow.evaluate(
        model=model_info.model_uri, 
        data=evaluation_data, 
        targets="label", 
        model_type="regressor", 
        evaluator_config={"metric_prefix": "mlflow_evaluation_"},
    )

    dist_plot = plot_feature_distributions(X_train, y_train)
    corr_plot = plot_correlation_heatmap(X_train, y_train)
    scatter_plot = plot_feature_target_relationships(X_train, y_train)

    # Select a few interesting features for the pairwise plot
    # Choose features with highest correlation with target
    corr_with_target = X_train.corrwith(y_train).abs().sort_values(ascending=False)
    top_features = corr_with_target.head(4).index.tolist()
    pairwise_plot = plot_pairwise_relationships(X_train, y_train, top_features)

    # Regenerate the outlier and box plots so they reflect this run's data
    outlier_plot = plot_outliers(X_train)
    residual_plot = plot_boxplots(X_train, y_train)

    # Log the plots associated with the parent run only
    mlflow.log_figure(dist_plot, "feature_distributions.png")
    mlflow.log_figure(corr_plot, "correlation_heatmap.png")
    mlflow.log_figure(scatter_plot, "feature_target_relationships.png")
    mlflow.log_figure(pairwise_plot, "pairwise_relationships.png")
    mlflow.log_figure(outlier_plot, "outlier_detection.png")
    mlflow.log_figure(residual_plot, "feature_boxplots_by_target.png")
        
    # Plot feature importance of the best model only
    fig, ax = plt.subplots(figsize=(10, 8))
    xgb.plot_importance(best_model, ax=ax, importance_type='gain')
    plt.title('Feature Importance')
    plt.tight_layout()
    plt.close(fig)

    mlflow.log_figure(fig, "feature_importance.png")
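The `eta` search space uses `log=True`, which tells Optuna to sample on a logarithmic scale. Conceptually, small learning rates are then explored as thoroughly as large ones: half of the probability mass lies below the geometric mean of the bounds. The sketch below shows the idea with the standard library; it illustrates log-uniform sampling in general and is not Optuna's internal implementation.

```python
import math
import random


def sample_log_uniform(rng: random.Random, low: float, high: float) -> float:
    """Draw a value whose logarithm is uniform on [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(42)
draws = [sample_log_uniform(rng, 0.01, 0.3) for _ in range(10_000)]

# The geometric mean of the bounds, sqrt(0.01 * 0.3), is about 0.0548
geo_mean = math.sqrt(0.01 * 0.3)
below_geo_mean = sum(d < geo_mean for d in draws) / len(draws)
print(round(below_geo_mean, 2))  # roughly 0.5
```

With a plain uniform draw over the same range, only about 15% of samples would fall below 0.0548, so the search would spend most trials on comparatively large learning rates.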

6. Assign a human-readable alias

MLflow provides the capability to assign human-readable aliases to registered models.

When you set an alias for a model, you assign a meaningful label, such as "best" or "production", to a specific version of the model. This label makes it easy for team members and automated systems to identify which version is intended for deployment or further evaluation.

Many of the model registry APIs recognize and work with these aliases.

# Use the `MlflowClient` to access metadata, artifacts, and information about models that are tracked or registered to the model registry.
from mlflow import MlflowClient
client = MlflowClient()

# Set the alias on the desired version. This example uses version 1.
client.set_registered_model_alias("main.default.xgboostoptuna", "best", 1)
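Hard-coding version `1` works on a fresh registry, but each rerun of the tuning cell registers a new version of `main.default.xgboostoptuna`. If you want the alias to follow the newest version instead, you could collect the version numbers (for example, from the `version` attribute of the objects returned by `client.search_model_versions("name='main.default.xgboostoptuna'")`) and pick the largest. The pure-Python sketch below, with a hypothetical `latest_version` helper, shows the one detail that matters: compare versions as integers, not strings.

```python
from typing import Iterable, Union


def latest_version(versions: Iterable[Union[str, int]]) -> int:
    """Return the highest model version number.

    Version numbers may be reported as strings, and as strings "9" sorts
    after "10", so convert each one to int before comparing.
    """
    return max(int(v) for v in versions)


print(max(["1", "9", "10"]))             # '9' (string comparison, wrong)
print(latest_version(["1", "9", "10"]))  # 10
```

The result can then be passed as the version argument to `client.set_registered_model_alias`.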

7. Pre-deployment validation

MLflow provides the mlflow.models.predict utility to simulate a production-like environment and validate that your model is configured correctly.

mlflow.models.predict fetches your logged model from the specified model URI, validates any dependencies, and builds a virtual execution environment within a subprocess. This simulates deploying your model on a virtual machine, closely mirroring a production model serving scenario.

The utility supports multiple environment managers to build the execution environment. For more information, see the MLflow documentation. This example uses uv, which is recommended for best performance. The first cell of this notebook installs uv.

You can supply data for inference in two ways:

  • In-memory data:
    Pass an in-memory object to the input_data argument. This allows for immediate validation within your notebook.
  • External data location:
    Alternatively, use the input_path argument to specify a location (such as a volume in Unity Catalog) from which to read the data.

To keep a record of the model’s predictions, specify an output_path to save the inference results. If you don’t specify an output path, the prediction results are displayed in the cell's output.

model_uri = "models:/main.default.xgboostoptuna@best"

mlflow.models.predict(model_uri=model_uri, input_data=X_train, env_manager="uv")

8. Load the registered model and make predictions

The code in this section shows how to load the registered model from MLflow and use it to make predictions locally. The model URI is based on the alias defined previously. An alias points to exactly one model version, so the URI always resolves to the version that currently holds the alias. To promote a different version, reassign the alias; this code does not need to change.

After the model is loaded, you can generate predictions on your data. This step confirms that the model works as expected before you use it for larger-scale applications such as batch prediction.

Should I use PyFunc or native XGBoost?

When working with MLflow, you have two primary methods for loading your logged models: the generic pyfunc interface and the native XGBoost object. The best choice depends on your use case.

  • The pyfunc interface provides a standardized, framework-agnostic way to interact with your model. This makes it the best choice for real-time model serving and production environments, where a consistent API is required. By using pyfunc, your model is encapsulated in a generic wrapper that exposes a simple predict() method, ensuring seamless integration and consistent behavior across various deployment scenarios.

  • Alternatively, you can load the model using mlflow.xgboost.load_model(), which returns a native XGBoost object. This method preserves the full functionality of the XGBoost library, allowing you to take advantage of its specialized methods and optimizations. For local validation or batch inference tasks, the native object can offer performance benefits and more granular control during evaluation. However, this approach is less suitable for deployment to production environments that require a standardized interface.
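The difference between the two loading paths is mostly about the interface. The sketch below is a conceptual, hypothetical wrapper (not MLflow's actual `PyFuncModel` class): callers see only a uniform `predict()` contract, while the wrapped framework object remains reachable for framework-specific work, which is roughly the role `mlflow.xgboost.load_model()` plays in the notebook. `GenericModel` and `SumModel` are illustrative names.

```python
from typing import Any, List, Protocol


class SupportsPredict(Protocol):
    def predict(self, data: Any) -> Any: ...


class GenericModel:
    """Hypothetical framework-agnostic wrapper, similar in spirit to a pyfunc model."""

    def __init__(self, native_model: SupportsPredict) -> None:
        self.native = native_model  # framework-specific object, kept for special use

    def predict(self, data: Any) -> List[float]:
        # Uniform contract: always return a plain list of floats
        return [float(p) for p in self.native.predict(data)]


class SumModel:
    """Toy stand-in for a native model: predicts the sum of each row."""

    def predict(self, rows: List[List[float]]) -> List[float]:
        return [sum(row) for row in rows]


wrapped = GenericModel(SumModel())
print(wrapped.predict([[1, 2], [3, 4]]))  # [3.0, 7.0]
```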

# Load the model and use the model to make predictions locally
loaded_registered_model = mlflow.pyfunc.load_model(model_uri=model_uri)

loaded_registered_model.predict(X_train)

9. Batch prediction using Spark UDF in MLflow

The code in this section shows how to perform distributed batch predictions using the Spark UDF integration in MLflow. This approach allows you to leverage the scalable data processing capabilities of Spark to apply your model across large datasets efficiently.

# Convert the training data into a Spark DataFrame.
X_spark = spark.createDataFrame(X_train)

# Create a Spark UDF to apply the model to the Spark DataFrame.
# Note that `model_uri` is defined based on a model alias, ensuring that you always load the current, approved version.
udf = mlflow.pyfunc.spark_udf(
    spark,
    model_uri=model_uri,
)

# Apply the Spark UDF to the DataFrame. This performs batch predictions across all rows in a distributed manner. 
X_spark = X_spark.withColumn("prediction", udf(*X_train.columns))

# Display the resulting DataFrame. 
display(X_spark)
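Conceptually, `mlflow.pyfunc.spark_udf` lets Spark hand batches of rows to the model's `predict` method on each executor and collects the results into a new column. The single-process pandas sketch below imitates only that batching-and-reassembly idea, with a toy `predict` function standing in for the loaded model; it does not use Spark, and `predict_in_batches` is a hypothetical helper.

```python
from typing import Any, Callable

import numpy as np
import pandas as pd


def predict_in_batches(df: pd.DataFrame,
                       predict: Callable[[pd.DataFrame], Any],
                       batch_size: int = 1000) -> pd.Series:
    """Apply `predict` to fixed-size row batches and stitch the results together."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    parts = []
    for start in range(0, len(df), batch_size):
        batch = df.iloc[start:start + batch_size]
        values = np.asarray(predict(batch))
        parts.append(pd.Series(values, index=batch.index))  # keep row alignment
    if not parts:
        return pd.Series(dtype=float, name="prediction")
    return pd.concat(parts).rename("prediction")


toy_df = pd.DataFrame({"a": range(5), "b": range(5)})
preds = predict_in_batches(toy_df, lambda b: b.sum(axis=1), batch_size=2)
print(preds.tolist())  # [0, 2, 4, 6, 8]
```

Keeping the original index on every batch is what allows the predictions to be attached back to the right rows, which Spark handles for you when it adds the `prediction` column.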