Use custom metrics with Databricks Lakehouse Monitoring

Preview

This feature is in Public Preview.

This page describes how to create a custom metric in Databricks Lakehouse Monitoring. In addition to the analysis and drift statistics that are automatically calculated, you can create custom metrics. For example, you might want to track a weighted mean that captures some aspect of business logic or use a custom model quality score. You can also create custom drift metrics that track changes to the values in the primary table (compared to the baseline or the previous time window).

For more details on how to use the databricks.lakehouse_monitoring.Metric API, see the API reference.

Types of custom metrics

Databricks Lakehouse Monitoring includes the following types of custom metrics:

  • Aggregate metrics, which are calculated based on columns in the primary table. Aggregate metrics are stored in the profile metrics table.

  • Derived metrics, which are calculated based on previously computed aggregate metrics and do not directly use data from the primary table. Derived metrics are stored in the profile metrics table.

  • Drift metrics, which compare previously computed aggregate or derived metrics from two different time windows, or between the primary table and the baseline table. Drift metrics are stored in the drift metrics table.

Using derived and drift metrics where possible minimizes recomputation over the full primary table. Only aggregate metrics access data from the primary table. Derived and drift metrics can then be computed directly from the aggregate metric values.
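The relationship between the three metric types can be sketched in plain Python. This only illustrates the data flow described above, not how Lakehouse Monitoring computes metrics internally. The window values and the function names are hypothetical.

```python
import math

# Hypothetical raw values of one column in two consecutive time windows.
previous_window = [1.0, 2.0, 3.0]
current_window = [2.0, 4.0, 6.0]


def squared_avg(values):
    """Aggregate metric: the only step that reads the raw column values."""
    return sum(v * v for v in values) / len(values)


def root_mean_square(squared_avg_value):
    """Derived metric: computed from the stored aggregate value alone."""
    return math.sqrt(squared_avg_value)


def rms_delta(current_rms, base_rms):
    """Drift metric: compares stored metric values from two windows."""
    return current_rms - base_rms


# The raw data is scanned once per window, by the aggregate metric only.
agg_prev = squared_avg(previous_window)
agg_curr = squared_avg(current_window)

# Derived and drift values reuse the aggregate results.
rms_prev = root_mean_square(agg_prev)
rms_curr = root_mean_square(agg_curr)
drift = rms_delta(rms_curr, rms_prev)
```

In this sketch, only squared_avg touches the raw values. The other two functions operate on single numbers, which is why defining metrics as derived or drift where possible avoids additional passes over the primary table.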

Custom metrics parameters

To define a custom metric, you create a Jinja template for a SQL column expression. The tables in this section describe the parameters that define the metric, and the parameters that are used in the Jinja template.

| Parameter | Description |
| --- | --- |
| type | One of aggregate, derived, or drift. |
| name | Column name for the custom metric in the metric tables. |
| input_columns | List of column names in the input table that the metric should be computed for. To indicate that more than one column is used in the calculation, use :table. See the examples in this article. |
| definition | Jinja template for a SQL expression that specifies how to compute the metric. See Create definition. |
| output_data_type | Spark data type of the metric output. |

Create definition

The definition parameter must be a single string expression in the form of a Jinja template. It cannot contain joins or subqueries. To construct complex definitions, you can use Python helper functions.
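As one possible sketch of the helper-function approach, the following plain Python function assembles a CASE expression from a list of conditions and penalty values. The function name build_penalty_definition and the critical column are hypothetical. The returned string is what you would pass as the definition argument.

```python
def build_penalty_definition(penalties, default=1):
    """Build an avg(CASE ...) definition string from (condition, value) pairs.

    Hypothetical helper: `penalties` is an ordered list of
    (SQL condition, numeric penalty) tuples. Conditions may contain Jinja
    parameters such as {{prediction_col}}, which are left untouched here.
    """
    when_clauses = " ".join(
        f"WHEN {condition} THEN {value}" for condition, value in penalties
    )
    return f"avg(CASE {when_clauses} ELSE {default} END)"


definition = build_penalty_definition(
    [
        ("{{prediction_col}} = {{label_col}}", 0),
        ("{{prediction_col}} != {{label_col}} AND critical = TRUE", 2),
    ]
)
print(definition)
```

The helper still produces a single string expression with no joins or subqueries, so the result stays within the constraints on definition.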

The following table lists the parameters you can use in the Jinja template for the SQL expression that calculates the metric.

| Parameter | Description |
| --- | --- |
| {{input_column}} | Column used to compute the custom metric. |
| {{prediction_col}} | Column holding ML model predictions. Used with InferenceLog analysis. |
| {{label_col}} | Column holding ML model ground truth labels. Used with InferenceLog analysis. |
| {{current_df}} | For drift metrics. Data from the current time window. |
| {{base_df}} | For drift metrics. Data from the comparison window: the baseline table for drift compared to the baseline, or the previous time window for drift compared to the previous window. |
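To make the substitution concrete, the sketch below mimics how a definition template could be expanded once per input column. It uses plain string replacement rather than the Jinja engine, and the render_definition helper and the columns f1 and f2 are hypothetical. The actual rendering is performed by Lakehouse Monitoring.

```python
def render_definition(template, **params):
    """Replace each {{name}} placeholder with the supplied value.

    Simplified stand-in for Jinja rendering, for illustration only.
    """
    rendered = template
    for name, value in params.items():
        rendered = rendered.replace("{{" + name + "}}", value)
    return rendered


template = "avg(`{{input_column}}`*`{{input_column}}`)"

# One rendered SQL expression per entry in input_columns.
expressions = {
    column: render_definition(template, input_column=column)
    for column in ["f1", "f2"]
}

for column, expression in expressions.items():
    print(column, "->", expression)
```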

Aggregate metric example

The following example computes the average of the square of the values in a column, and is applied to columns f1 and f2. The output is saved as a new column in the profile metrics table and is shown in the analysis rows corresponding to the columns f1 and f2. The applicable column names are substituted for the Jinja parameter {{input_column}}.

from databricks import lakehouse_monitoring as lm
from pyspark.sql import types as T

# Computed once per input column: {{input_column}} is replaced by f1, then by f2.
lm.Metric(
  type="aggregate",
  name="squared_avg",
  input_columns=["f1", "f2"],
  definition="avg(`{{input_column}}`*`{{input_column}}`)",
  output_data_type=T.DoubleType()
)

The following code defines a custom metric that computes the average of the difference between columns f1 and f2. This example shows the use of [":table"] in the input_columns parameter to indicate that more than one column from the table is used in the calculation.

from databricks import lakehouse_monitoring as lm
from pyspark.sql import types as T

lm.Metric(
  type="aggregate",
  name="avg_diff_f1_f2",
  input_columns=[":table"],
  definition="avg(f1 - f2)",
  output_data_type=T.DoubleType()
)

This example computes a weighted model quality score. For observations where the critical column is True, a heavier penalty is assigned when the predicted value for that row does not match the ground truth. Because the score is computed from the raw columns (prediction and label), it is an aggregate metric. The :table value in input_columns indicates that this metric is calculated from multiple columns. The Jinja parameters {{prediction_col}} and {{label_col}} are replaced with the names of the prediction and ground truth label columns for the monitor.

from databricks import lakehouse_monitoring as lm
from pyspark.sql import types as T

lm.Metric(
  type="aggregate",
  name="weighted_error",
  input_columns=[":table"],
  definition="""avg(CASE
    WHEN {{prediction_col}} = {{label_col}} THEN 0
    WHEN {{prediction_col}} != {{label_col}} AND critical=TRUE THEN 2
    ELSE 1 END)""",
  output_data_type=T.DoubleType()
)
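To sanity-check the value you would expect from such a metric, the same CASE logic can be reproduced in plain Python on a handful of rows. The rows below are made up for illustration, and the helper name penalty is hypothetical. This sketch does not model SQL NULL handling.

```python
# Hypothetical inference-log rows: (prediction, label, critical)
rows = [
    (1, 1, False),  # correct             -> penalty 0
    (0, 1, True),   # wrong and critical  -> penalty 2
    (1, 0, False),  # wrong, not critical -> penalty 1
    (0, 0, True),   # correct             -> penalty 0
]


def penalty(prediction, label, critical):
    """Mirror the branches of the CASE expression in weighted_error."""
    if prediction == label:
        return 0
    if critical:
        return 2
    return 1


# avg(...) over the window: (0 + 2 + 1 + 0) / 4
weighted_error = sum(penalty(*row) for row in rows) / len(rows)
print(weighted_error)  # 0.75
```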

Derived metric example

The following code defines a custom metric that computes the square root of the squared_avg metric defined in the aggregate metric example. Because this is a derived metric, it does not reference the primary table data. Instead, it is defined in terms of the squared_avg aggregate metric. The output is saved as a new column in the profile metrics table.

from databricks import lakehouse_monitoring as lm
from pyspark.sql import types as T

# References the stored squared_avg value, not the raw f1 and f2 data.
lm.Metric(
  type="derived",
  name="root_mean_square",
  input_columns=["f1", "f2"],
  definition="sqrt(squared_avg)",
  output_data_type=T.DoubleType()
)

Drift metric example

The following code defines a drift metric that tracks the change in the weighted_error metric defined in the aggregate metric example. The {{current_df}} and {{base_df}} parameters allow the metric to reference the weighted_error values from the current window and the comparison window. The comparison window can be either the baseline data or the data from the previous time window. Drift metrics are saved in the drift metrics table.

from databricks import lakehouse_monitoring as lm
from pyspark.sql import types as T

lm.Metric(
  type="drift",
  name="error_rate_delta",
  input_columns=[":table"],
  definition="{{current_df}}.weighted_error - {{base_df}}.weighted_error",
  output_data_type=T.DoubleType()
)