
    XGBoost MLflow tutorial

    This tutorial covers the full lifecycle of experimentation, training, tuning, registration, evaluation, and deployment for a classic ML modeling project. It shows you how to use MLflow to keep track of every aspect of the model development and deployment processes.

    In this step-by-step tutorial, you'll discover how to:

    • Generate and visualize data: Create synthetic data to simulate real-world scenarios, and visualize feature relationships with Seaborn.
    • Train and log models: Train an XGBoost model, and log important metrics, parameters, and artifacts using MLflow, including visualizations.
    • Register models: Register your model with Unity Catalog, preparing it for review and future deployment to managed serving endpoints.
    • Load and evaluate models: Load your registered model, make predictions, and perform error analysis to validate model performance.

    This tutorial leverages features from MLflow 3.0. For more details, see "Get started with MLflow 3.0" (AWS | Azure | GCP).

    Install the latest version of MLflow
    %pip install -Uqqq "mlflow>=3.0" xgboost optuna uv
    %restart_python
    from typing import Tuple
    
    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    from sklearn.model_selection import train_test_split
    from pandas.api.types import CategoricalDtype
    from statsmodels.graphics.mosaicplot import mosaic
    
    import xgboost as xgb
    
    import mlflow
    from mlflow.models import infer_signature
    
    

    0. Configure the Model Registry with Unity Catalog

    One of the key advantages of using MLflow on Databricks is the seamless integration with Unity Catalog. This integration simplifies model management and governance, ensuring that every model you develop is tracked, versioned, and secure. For more information about Unity Catalog, see (AWS | Azure | GCP).

    Set the registry URI

    The following cell configures MLflow to use Unity Catalog for model registration.

    mlflow.set_registry_uri("databricks-uc")

    1. Create a synthetic regression dataset

    The next cell defines the create_regression_data function. This function generates synthetic data for regression. The resulting DataFrame includes correlated data, cyclical patterns, and outliers. These features are designed to mimic real-world data scenarios.

    def create_regression_data(
        n_samples: int, 
        n_features: int,
        seed: int = 1994,
        noise_level: float = 0.3,
        nonlinear: bool = True
    ) -> Tuple[pd.DataFrame, pd.Series]:
        """Generates synthetic regression data with interesting correlations for MLflow and XGBoost demonstrations.
    
        This function creates a DataFrame of continuous features and computes a target variable with nonlinear
        relationships and interactions between features. The data is designed to be complex enough to demonstrate
        the capabilities of XGBoost, but not so complex that a reasonable model can't be learned.
    
        Args:
            n_samples (int): Number of samples (rows) to generate.
            n_features (int): Number of feature columns.
            seed (int, optional): Random seed for reproducibility. Defaults to 1994.
            noise_level (float, optional): Level of Gaussian noise to add to the target. Defaults to 0.3.
            nonlinear (bool, optional): Whether to add nonlinear feature transformations. Defaults to True.
    
        Returns:
            Tuple[pd.DataFrame, pd.Series]:
                - pd.DataFrame: DataFrame containing the synthetic features.
                - pd.Series: Series containing the target labels.
    
        Example:
            >>> df, target = create_regression_data(n_samples=1000, n_features=10)
        """
        rng = np.random.RandomState(seed)
        
        # Generate random continuous features
        X = rng.uniform(-5, 5, size=(n_samples, n_features))
        
        # Create feature DataFrame with meaningful names
        columns = [f"feature_{i}" for i in range(n_features)]
        df = pd.DataFrame(X, columns=columns)
        
        # Generate base target variable with linear relationship to a subset of features
        # Use only the first n_features//2 features to create some irrelevant features
        weights = rng.uniform(-2, 2, size=n_features//2)
        target = np.dot(X[:, :n_features//2], weights)
        
        # Add some nonlinear transformations if requested
        if nonlinear:
            # Add square term for first feature
            target += 0.5 * X[:, 0]**2
            
            # Add interaction between the second and third features
            if n_features >= 3:
                target += 1.5 * X[:, 1] * X[:, 2]
            
            # Add sine transformation of fourth feature
            if n_features >= 4:
                target += 2 * np.sin(X[:, 3])
            
            # Add exponential of fifth feature, scaled down
            if n_features >= 5:
                target += 0.1 * np.exp(X[:, 4] / 2)
                
            # Add threshold effect for sixth feature
            if n_features >= 6:
                target += 3 * (X[:, 5] > 1.5).astype(float)
        
        # Add Gaussian noise
        noise = rng.normal(0, noise_level * target.std(), size=n_samples)
        target += noise
        
        # Add a few more interesting features to the DataFrame
        
        # Add a correlated feature (but not used in target calculation)
        if n_features >= 7:
            df['feature_correlated'] = df['feature_0'] * 0.8 + rng.normal(0, 0.2, size=n_samples)
        
        # Add a cyclical feature
        df['feature_cyclical'] = np.sin(np.linspace(0, 4*np.pi, n_samples))
        
        # Add a feature with outliers
        df['feature_with_outliers'] = rng.normal(0, 1, size=n_samples)
        # Add outliers to ~1% of samples
        outlier_idx = rng.choice(n_samples, size=n_samples//100, replace=False)
        df.loc[outlier_idx, 'feature_with_outliers'] = rng.uniform(10, 15, size=len(outlier_idx))
        
        return df, pd.Series(target, name='target')

    2. Exploratory data analysis (EDA) visualizations

    Before training your model, it’s essential to examine your data. Visualizations help you validate that the data is as expected, spot unexpected anomalies, and drive feature selection. As you move forward with model development, these visualizations serve as a record of your work that can help with troubleshooting, reproducibility, and collaboration.

    You can use MLflow to log visualizations, making your experimentation fully reproducible.

    The code in the following cell defines six functions, each of which generates a different plot to help you visually inspect your dataset.

    def plot_feature_distributions(X: pd.DataFrame, y: pd.Series, n_cols: int = 3) -> plt.Figure:
        """
        Creates a grid of histograms for each feature in the dataset.
    
        Args:
            X (pd.DataFrame): DataFrame containing synthetic features.
            y (pd.Series): Series containing the target variable.
            n_cols (int): Number of columns in the grid layout.
    
        Returns:
            plt.Figure: The matplotlib Figure object containing the distribution plots.
        """
        features = X.columns
        n_features = len(features)
        n_rows = (n_features + n_cols - 1) // n_cols
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
        axes = axes.flatten() if n_rows * n_cols > 1 else [axes]
        
        for i, feature in enumerate(features):
            if i < len(axes):
                ax = axes[i]
                sns.histplot(X[feature], ax=ax, kde=True, color='skyblue')
                ax.set_title(f'Distribution of {feature}')
        
        # Hide any unused subplots
        for i in range(n_features, len(axes)):
            axes[i].set_visible(False)
        
        plt.tight_layout()
        fig.suptitle('Feature Distributions', y=1.02, fontsize=16)
        plt.close(fig)
        return fig
    
    def plot_correlation_heatmap(X: pd.DataFrame, y: pd.Series) -> plt.Figure:
        """
        Creates a correlation heatmap of all features and the target variable.
    
        Args:
            X (pd.DataFrame): DataFrame containing features.
            y (pd.Series): Series containing the target variable.
    
        Returns:
            plt.Figure: The matplotlib Figure object containing the heatmap.
        """
        # Combine features and target into one DataFrame
        data = X.copy()
        data['target'] = y
        
        # Calculate correlation matrix
        corr_matrix = data.corr()
        
        # Set up the figure
        fig, ax = plt.subplots(figsize=(12, 10))
        
        # Draw the heatmap with a color bar
        cmap = sns.diverging_palette(220, 10, as_cmap=True)
        sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap=cmap,
                    center=0, square=True, linewidths=0.5, ax=ax)
        
        ax.set_title('Feature Correlation Heatmap', fontsize=16)
        plt.close(fig)
        return fig
    
    def plot_feature_target_relationships(X: pd.DataFrame, y: pd.Series, n_cols: int = 3) -> plt.Figure:
        """
        Creates a grid of scatter plots showing the relationship between each feature and the target.
    
        Args:
            X (pd.DataFrame): DataFrame containing features.
            y (pd.Series): Series containing the target variable.
            n_cols (int): Number of columns in the grid layout.
    
        Returns:
            plt.Figure: The matplotlib Figure object containing the relationship plots.
        """
        features = X.columns
        n_features = len(features)
        n_rows = (n_features + n_cols - 1) // n_cols
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
        axes = axes.flatten() if n_rows * n_cols > 1 else [axes]
        
        for i, feature in enumerate(features):
            if i < len(axes):
                ax = axes[i]
                # Scatter plot with regression line
                sns.regplot(x=X[feature], y=y, ax=ax, 
                           scatter_kws={'alpha': 0.5, 'color': 'blue'}, 
                           line_kws={'color': 'red'})
                ax.set_title(f'{feature} vs Target')
        
        for i in range(n_features, len(axes)):
            axes[i].set_visible(False)
        
        plt.tight_layout()
        fig.suptitle('Feature vs Target Relationships', y=1.02, fontsize=16)
        plt.close(fig)
        return fig
    
    def plot_pairwise_relationships(X: pd.DataFrame, y: pd.Series, features: list[str]) -> plt.Figure:
        """
        Creates a pairplot showing relationships between selected features and the target.
    
        Args:
            X (pd.DataFrame): DataFrame containing features.
            y (pd.Series): Series containing the target variable.
            features (List[str]): List of feature names to include in the plot.
    
        Returns:
            plt.Figure: The matplotlib Figure object containing the pairplot.
        """
        # Ensure features exist in the DataFrame
        valid_features = [f for f in features if f in X.columns]
        
        if not valid_features:
            fig, ax = plt.subplots()
            ax.text(0.5, 0.5, "No valid features provided", ha='center', va='center')
            return fig
        
        # Combine selected features and target
        data = X[valid_features].copy()
        data['target'] = y
        
        # Create pairplot
        pairgrid = sns.pairplot(data, diag_kind="kde", 
                              plot_kws={"alpha": 0.6, "s": 50},
                              corner=True)
        
        pairgrid.fig.suptitle("Pairwise Feature Relationships", y=1.02, fontsize=16)
        plt.close(pairgrid.fig)
        return pairgrid.fig
    
    def plot_boxplots(X: pd.DataFrame, y: pd.Series, n_cols: int = 3) -> plt.Figure:
        """
        Creates a grid of box plots for each feature, with points colored by target value.
    
        Args:
            X (pd.DataFrame): DataFrame containing features.
            y (pd.Series): Series containing the target variable.
            n_cols (int): Number of columns in the grid layout.
    
        Returns:
            plt.Figure: The matplotlib Figure object containing the box plots.
        """
        features = X.columns
        n_features = len(features)
        n_rows = (n_features + n_cols - 1) // n_cols
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
        axes = axes.flatten() if n_rows * n_cols > 1 else [axes]
        
        # Create target bins for coloring
        y_binned = pd.qcut(y, 3, labels=['Low', 'Medium', 'High'])
        
        for i, feature in enumerate(features):
            if i < len(axes):
                ax = axes[i]
                # Box plot for each feature
                sns.boxplot(x=y_binned, y=X[feature], ax=ax)
                ax.set_title(f'Distribution of {feature} by Target Range')
                ax.set_xlabel('Target Range')
        
        # Hide any unused subplots
        for i in range(n_features, len(axes)):
            axes[i].set_visible(False)
        
        plt.tight_layout()
        fig.suptitle('Feature Distributions by Target Range', y=1.02, fontsize=16)
        plt.close(fig)
        return fig
    
    def plot_outliers(X: pd.DataFrame, n_cols: int = 3) -> plt.Figure:
        """
        Creates a grid of box plots to detect outliers in each feature.
    
        Args:
            X (pd.DataFrame): DataFrame containing features.
            n_cols (int): Number of columns in the grid layout.
    
        Returns:
            plt.Figure: The matplotlib Figure object containing the outlier plots.
        """
        features = X.columns
        n_features = len(features)
        n_rows = (n_features + n_cols - 1) // n_cols
        
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows))
        axes = axes.flatten() if n_rows * n_cols > 1 else [axes]
        
        for i, feature in enumerate(features):
            if i < len(axes):
                ax = axes[i]
                # Box plot to detect outliers
                sns.boxplot(x=X[feature], ax=ax, color='skyblue')
                ax.set_title(f'Outlier Detection for {feature}')
                ax.set_xlabel(feature)
        
        # Hide any unused subplots
        for i in range(n_features, len(axes)):
            axes[i].set_visible(False)
        
        plt.tight_layout()
        fig.suptitle('Outlier Detection for Features', y=1.02, fontsize=16)
        plt.close(fig)
        return fig

    3. Standard modeling workflow

    The code in the next cell does the following:

    1. Uses the function you created, create_regression_data, to create a dataset.
    2. Uses the visualization functions you created to create EDA plots.
    3. Configures and trains an XGBoost model.
    4. Uses the trained model to make predictions on the test dataset.

    # Create the regression dataset
    n_samples = 1000
    n_features = 10
    X, y = create_regression_data(n_samples=n_samples, n_features=n_features, nonlinear=True)
    
    # Create EDA plots
    dist_plot = plot_feature_distributions(X, y)
    corr_plot = plot_correlation_heatmap(X, y)
    scatter_plot = plot_feature_target_relationships(X, y)
    corr_with_target = X.corrwith(y).abs().sort_values(ascending=False)
    top_features = corr_with_target.head(4).index.tolist()
    pairwise_plot = plot_pairwise_relationships(X, y, top_features)
    outlier_plot = plot_outliers(X)
    
    # Configure the XGBoost model
    reg = xgb.XGBRegressor(
        tree_method="hist",
        n_estimators=100,
        learning_rate=0.1,
        max_depth=6,
        subsample=0.8,
        colsample_bytree=0.8,
        eval_metric='rmse',
    )
    
    # Create train/test split to properly evaluate the model
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=7722)
    
    # Train the model with evaluation
    reg.fit(
        X_train, y_train,
        eval_set=[(X_train, y_train), (X_test, y_test)],
        verbose=False
    )
    
    # Generate predictions on the test set
    y_pred = reg.predict(X_test)
    # Create box plots of each feature by target range (logged later as an MLflow artifact)
    residual_plot = plot_boxplots(X, y)
    
    
    

    4. Log the model using MLflow

    When you log a model using MLflow on Databricks, important artifacts and metadata are captured. This ensures that your model is not only reproducible but also ready for deployment with all necessary dependencies and clear API contracts. For details on what is logged, see the MLflow documentation.

    The code in the next cell starts an MLflow run using with mlflow.start_run():. This context manager scopes the run to the enclosed code block: when the block ends, all logged metrics, parameters, and artifacts are saved, and the MLflow run is automatically terminated.

    In MLflow 3.0, calling mlflow.xgboost.log_model() creates a logged model object with all associated metrics and parameters that can be accessed across runs.

    # Incorporate MLflow evaluation
    evaluation_data = X_test.copy()
    evaluation_data["label"] = y_test
    
    # Log the model and training metadata results
    with mlflow.start_run() as run:
        # Extract metrics
        final_train_rmse = np.array(reg.evals_result()["validation_0"]["rmse"])[-1]
        final_test_rmse = np.array(reg.evals_result()["validation_1"]["rmse"])[-1]
        
        # Extract parameters for logging
        params_to_log = {key: value for key, value in reg.get_xgb_params().items() if value is not None}
    
        # Generate a model signature using the infer_signature utility in MLflow
        # A signature is required to register the model to Unity Catalog 
        # so that the model can be used in SQL queries
        signature = infer_signature(X, reg.predict(X))
    
        # Log the training parameters
        mlflow.log_params(params_to_log)
        
        # Log the model to MLflow and register the model to Unity Catalog
        # All model metrics and parameters will be available in Unity Catalog
        model_info = mlflow.xgboost.log_model(
            xgb_model=reg,
            name="xgboost_regression_model",
            input_example=X.iloc[[0]],
            signature=signature,
            registered_model_name="main.default.xgboost_regression_model",
        )
    
        # Log metrics to the run and model
        mlflow.log_metric("train_rmse", final_train_rmse)
        mlflow.log_metric("test_rmse", final_test_rmse)
        
        # Log feature analysis plots
        # Plots are saved as artifacts in MLflow
        mlflow.log_figure(dist_plot, "feature_distributions.png")
        mlflow.log_figure(corr_plot, "correlation_heatmap.png")
        mlflow.log_figure(scatter_plot, "feature_target_relationships.png")
        mlflow.log_figure(pairwise_plot, "pairwise_relationships.png")
        mlflow.log_figure(outlier_plot, "outlier_detection.png")
        mlflow.log_figure(residual_plot, "feature_boxplots_by_target.png")
            
        # Plot feature importance
        fig, ax = plt.subplots(figsize=(10, 8))
        xgb.plot_importance(reg, ax=ax, importance_type='gain')
        plt.title('Feature Importance')
        plt.tight_layout()
        plt.close(fig)
    
        mlflow.log_figure(fig, "feature_importance.png")
    
        # Run MLflow evaluation to generate additional metrics without having to implement them
        mlflow.models.evaluate(
            model=model_info.model_uri, 
            data=evaluation_data, 
            targets="label", 
            model_type="regressor", 
            evaluator_config={"metric_prefix": "mlflow_evaluation_"},
        )
        
        print(f"Model logged: {model_info.model_uri}")
        print(f"Train RMSE: {final_train_rmse:.4f}")
        print(f"Test RMSE: {final_test_rmse:.4f}")

    5. Hyperparameter tuning

    This section shows how to automate hyperparameter tuning using Optuna and nested runs in MLflow. This approach lets you explore a range of parameter configurations while capturing all of the experimental details.

    The code in the next cell does the following:

    1. Uses the create_regression_data function defined previously to generate a synthetic regression dataset.

    2. Splits the dataset into separate training and test datasets, and saves a copy of the test dataset for evaluation.

    3. Trains the XGBoost regression model.

    4. Creates an objective function for the hyperparameter tuning process. The objective function defines the search space for hyperparameters of the XGBoost regressor, such as the maximum tree depth, number of estimators, learning rate, and sampling ratios. Optuna dynamically samples these values, ensuring that each trial tests a different combination of parameters.

    5. Initiates a nested MLflow run inside the objective function. This nested run automatically captures and logs all details specific to the current hyperparameter trial. By isolating each trial in its own nested run, you can keep a well-organized record of each configuration and its corresponding performance metrics. The nested run logs the following:

      • The specific hyperparameters used for that trial.
      • The performance metric (in this case, RMSE) computed on the test set.
      • The trained model instance, stored in the trial’s metadata, which allows easy retrieval of the best-performing model later.

      Note that the code does not log each trial's model artifact to MLflow. During hyperparameter tuning, most trials are not expected to produce a strong model, so there is no reason to record a model artifact for every one.

    6. Creates a parent MLflow run. This run initiates an Optuna study designed to identify the optimal set of hyperparameters (the set that produces the minimum RMSE). Optuna runs a series of trials, each of which uses a unique combination of hyperparameters. During each trial, the nested MLflow run captures all the experiment details, so you can later track and compare the performance of each model configuration.

    7. Identifies the best trial based on the lowest RMSE. The code logs the metrics, parameters, and model from the best trial, and registers the best model in Unity Catalog. It uses infer_signature to save a model signature, which specifies the expected input and output schemas and is important for consistent deployment and integration with systems like Unity Catalog. Finally, it records additional artifacts such as EDA plots and feature importance charts.

    import optuna
    from sklearn.metrics import mean_squared_error
    import numpy as np
    import mlflow
    import xgboost as xgb
    
    # Generate training and validation data
    n_samples = 2000
    n_features = 10
    
    X, y = create_regression_data(n_samples=n_samples, n_features=n_features, nonlinear=True)
    
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    
    # Prepare the evaluation data
    evaluation_data = X_test.copy()
    evaluation_data["label"] = y_test
    
    # Configure a fresh XGBoost model for the new dataset, then train it
    reg = xgb.XGBRegressor(
        tree_method="hist",
        n_estimators=100,
        learning_rate=0.1,
        max_depth=6,
        subsample=0.8,
        colsample_bytree=0.8,
        eval_metric='rmse',
    )
    reg.fit(
        X_train, y_train,
        eval_set=[(X_train, y_train), (X_test, y_test)],
        verbose=False
    )
    
    # The objective function defines the search space for the key hyperparameters of the XGBRegressor algorithm.
    # Optuna dynamically samples these values, so that each trial tests a different combination of parameters.
    def objective(trial):
        param = {
            "max_depth": trial.suggest_int("max_depth", 3, 10),
            "n_estimators": trial.suggest_int("n_estimators", 50, 500),
            "eta": trial.suggest_float("eta", 0.01, 0.3, log=True),
            "subsample": trial.suggest_float("subsample", 0.5, 1.0),
            "colsample_bytree": trial.suggest_float("colsample_bytree", 0.5, 1.0),
            "tree_method": "hist",
            "objective": "reg:squarederror",
            "eval_metric": "rmse"
        }
    
        with mlflow.start_run(nested=True):
            mlflow.log_params(param)
            regressor = xgb.XGBRegressor(**param)
            regressor.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False)
    
            preds = regressor.predict(X_test)
            rmse = np.sqrt(mean_squared_error(y_test, preds))
            mlflow.log_metric("rmse", rmse)
        
        # Store the model in the trial's `user attributes`
        trial.set_user_attr("model", regressor)
        return rmse
    
    # In the parent run, save the best iteration from the hyperparameter tuning execution
    with mlflow.start_run():
        study = optuna.create_study(direction="minimize")
        study.optimize(objective, n_trials=50)
    
        best_trial = study.best_trial
        best_model = best_trial.user_attrs["model"]
    
        mlflow.log_metric("best_rmse", best_trial.value)
        mlflow.log_params(best_trial.params)
    
        signature = infer_signature(X_train, best_model.predict(X_train))
    
        model_info = mlflow.xgboost.log_model(
            xgb_model=best_model,     
            name="xgboostoptuna",
            input_example=X_train.iloc[[0]],
            signature=signature,
            model_format="ubj",
            registered_model_name="main.default.xgboostoptuna",
        )
    
        mlflow.models.evaluate(
            model=model_info.model_uri, 
            data=evaluation_data, 
            targets="label", 
            model_type="regressor", 
            evaluator_config={"metric_prefix": "mlflow_evaluation_"},
        )
    
        dist_plot = plot_feature_distributions(X_train, y_train)
        corr_plot = plot_correlation_heatmap(X_train, y_train)
        scatter_plot = plot_feature_target_relationships(X_train, y_train)
    
        # Select a few interesting features for the pairwise plot
        # Choose features with highest correlation with target
        corr_with_target = X_train.corrwith(y_train).abs().sort_values(ascending=False)
        top_features = corr_with_target.head(4).index.tolist()
        pairwise_plot = plot_pairwise_relationships(X_train, y_train, top_features)
    
        # Regenerate the outlier and box plots on this tuning dataset, and log all plots with the parent run only
        outlier_plot = plot_outliers(X_train)
        residual_plot = plot_boxplots(X_train, y_train)
        mlflow.log_figure(dist_plot, "feature_distributions.png")
        mlflow.log_figure(corr_plot, "correlation_heatmap.png")
        mlflow.log_figure(scatter_plot, "feature_target_relationships.png")
        mlflow.log_figure(pairwise_plot, "pairwise_relationships.png")
        mlflow.log_figure(outlier_plot, "outlier_detection.png")
        mlflow.log_figure(residual_plot, "feature_boxplots_by_target.png")
            
        # Plot feature importance of the best model only
        fig, ax = plt.subplots(figsize=(10, 8))
        xgb.plot_importance(best_model, ax=ax, importance_type='gain')
        plt.title('Feature Importance')
        plt.tight_layout()
        plt.close(fig)
    
        mlflow.log_figure(fig, "feature_importance.png")
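
    To compare individual trials after the study completes, you can query the nested runs recorded above. The following is a minimal sketch using mlflow.search_runs; it assumes the notebook's active experiment contains the runs from this study.

    # Query the best nested trial runs, ordered by their logged RMSE
    # (the parent run logs `best_rmse`, so it does not compete in this sort).
    trials_df = mlflow.search_runs(order_by=["metrics.rmse ASC"], max_results=5)
    display(trials_df[["run_id", "metrics.rmse", "params.max_depth", "params.eta"]])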
    
    

    6. Assign a human-readable alias

    MLflow provides the capability to assign human-readable aliases to registered models.

    When you set an alias on a model, you attach a meaningful label, such as "best" or "production", to a specific version of the model. This label makes it easy for team members and automated systems to identify which version is intended for deployment or further evaluation.

    Many of the model registry APIs recognize and work with these aliases.

    # Use the `MlflowClient` to access metadata, artifacts, and information about models that are tracked or registered to the model registry.
    from mlflow import MlflowClient
    client = MlflowClient()
    
    # Set the alias on the desired version. This example uses version 1.
    client.set_registered_model_alias("main.default.xgboostoptuna", "best", 1)
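
    Many registry APIs accept the alias directly. For example, the following sketch resolves the alias back to a concrete model version using MlflowClient.get_model_version_by_alias, assuming the alias set in the cell above.

    # Resolve the "best" alias to the model version that currently holds it
    mv = client.get_model_version_by_alias("main.default.xgboostoptuna", "best")
    print(f"Alias 'best' points to version {mv.version}")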

    7. Pre-deployment validation

    MLflow provides the mlflow.models.predict utility to simulate a production-like environment and validate that your model is configured correctly.

    mlflow.models.predict fetches your logged model from the specified model URI, validates any dependencies, and builds a virtual execution environment within a subprocess. This simulates deploying your model on a virtual machine, closely mirroring a production model serving scenario.

    The utility supports multiple environment managers to build the execution environment. For more information, see the MLflow documentation. This example uses uv, which is recommended for best performance; the first cell of this notebook installed it.

    You can supply data for inference in two ways:

    • In-memory data:
      Pass an in-memory object to the input_data argument. This allows for immediate validation within your notebook.
    • External data location:
      Alternatively, use the input_path argument to specify a location (such as a volume in Unity Catalog) from which to read the data.

    To keep a record of the model’s predictions, specify an output_path to save the inference results. If you don’t specify an output path, the prediction results are displayed in the cell's output.

    model_uri = "models:/main.default.xgboostoptuna@best"
    
    mlflow.models.predict(model_uri=model_uri, input_data=X_train, env_manager="uv")
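
    If your inference data lives in storage rather than in memory, a sketch of the file-based variant follows. The volume paths here are hypothetical placeholders; substitute a location you can access.

    # File-based variant: read input records from a (hypothetical) Unity Catalog volume
    # and write predictions to an output file instead of printing them to the cell output.
    mlflow.models.predict(
        model_uri=model_uri,
        input_path="/Volumes/main/default/my_volume/inference_input.json",  # hypothetical path
        content_type="json",
        output_path="/Volumes/main/default/my_volume/predictions.json",  # hypothetical path
        env_manager="uv",
    )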

    8. Load the registered model and make predictions

    The code in this section shows how to load the registered model from MLflow and use it to make predictions locally. The model URI is based on the alias defined previously. By using the alias in the model URI, you ensure that inference always uses whichever model version currently holds that alias.

    After the model is loaded, you can generate predictions on your data. This step confirms that the model works as expected before you use it in larger-scale applications such as batch prediction.

    Should I use PyFunc or native XGBoost?

    When working with MLflow, you have two primary methods for loading your logged models: the generic pyfunc interface and the native XGBoost object. The best choice depends on your use case.

    • The pyfunc interface provides a standardized, framework-agnostic way to interact with your model. This makes it the best choice for real-time model serving and production environments, where a consistent API is required. By using pyfunc, your model is encapsulated in a generic wrapper that exposes a simple predict() method, ensuring seamless integration and consistent behavior across various deployment scenarios.

    • Alternatively, you can load the model using mlflow.xgboost.load_model(), which returns a native XGBoost object. This method preserves the full functionality of the XGBoost library, allowing you to take advantage of its specialized methods and optimizations. For local validation or batch inference tasks, the native object can offer performance benefits and more granular control during evaluation. However, this approach is less suitable for deployment to production environments that require a standardized interface.

    # Load the model and use the model to make predictions locally
    loaded_registered_model = mlflow.pyfunc.load_model(model_uri=model_uri)
    
    loaded_registered_model.predict(X_train)
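
    For contrast, the following sketch loads the same model as a native XGBoost object with mlflow.xgboost.load_model. Because the model above was logged from an XGBRegressor, the loaded object exposes the full scikit-learn-style XGBoost API.

    # Load the registered model as a native XGBoost object instead of a pyfunc wrapper
    native_model = mlflow.xgboost.load_model(model_uri=model_uri)

    # Native-only functionality, such as feature importances, is available here
    print(native_model.feature_importances_)
    native_model.predict(X_train)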

    9. Batch prediction using Spark UDF in MLflow

    The code in this section shows how to perform distributed batch predictions using the Spark UDF integration in MLflow. This approach allows you to leverage the scalable data processing capabilities of Spark to apply your model across large datasets efficiently.

    # Convert the training data into a Spark DataFrame.
    X_spark = spark.createDataFrame(X_train)
    
    # Create a Spark UDF to apply the model to the Spark DataFrame.
    # Note that `model_uri` is defined based on a model alias, ensuring that you always load the current, approved version.
    udf = mlflow.pyfunc.spark_udf(
        spark,
        model_uri=model_uri,
    )
    
    # Apply the Spark UDF to the DataFrame. This performs batch predictions across all rows in a distributed manner. 
    X_spark = X_spark.withColumn("prediction", udf(*X_train.columns))
    
    # Display the resulting DataFrame. 
    display(X_spark)