With the Python `xgboost` package, you can train only single node workloads. To perform distributed training, you must use XGBoost’s Scala and Java packages.
XGBoost versions 1.2.0 and lower have a bug that can cause the shared Spark context to be killed if XGBoost model training fails. The only way to recover is to restart the cluster. Databricks Runtime 7.5 ML and lower include a version of XGBoost affected by this bug. To install a different version of XGBoost, see Install XGBoost on Databricks.
You can train models using the Python `xgboost` package. To train a PySpark ML pipeline with `xgboost`, see Integration with Spark MLlib (Python).
This feature is in Public Preview.
Databricks Runtime 7.6 ML and above include PySpark estimators based on the Python `xgboost` package, such as `sparkdl.xgboost.XgboostClassifier`. You can create an ML pipeline based on these estimators. For more information, see Xgboost for PySpark Pipeline.
- These estimators train the model on a single Spark worker.
- GPU clusters are not supported.
- The following parameters from the `xgboost` package are not supported:
  - The parameter `sample_weight_eval_set` is not supported. Instead, use the parameter `validationIndicatorCol`. See Xgboost for PySpark Pipeline for details.
  - The parameter `missing` has different semantics from the `xgboost` package. In the `xgboost` package, the zero values in a SciPy sparse matrix are treated as missing values regardless of the value of `missing`. For the PySpark estimators in the `sparkdl` package, zero values in a Spark sparse vector are not treated as missing values unless you set `missing=0`. If you have a sparse training dataset (most feature values are missing), Databricks recommends setting `missing=0` to reduce memory consumption and achieve better performance.