xgboost-regression (Scala)


XGBoost regression with Spark DataFrames

Prepare data

import java.util.UUID.randomUUID
temp_path: String = /tmp/power_plant/23c86fb4-2744-4582-bdfe-f4bf7668199a
defined class PowerPlantTable
powerPlantData: Unit = ()
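The output above comes from a cell that builds a unique scratch path, defines a `PowerPlantTable` case class, and writes the parsed rows out. A write action returns `Unit`, which matches `powerPlantData: Unit = ()`. The cell's source is not shown, so the following is only a sketch under stated assumptions: the input is tab-separated text with a header row, the columns are `AT`, `V`, `AP`, `RH`, `PE` (only `AT` and `V` are visible in the output; the rest are assumed names), and a few hypothetical inline rows stand in for the real data files.

```scala
import java.util.UUID.randomUUID
import org.apache.spark.sql.SparkSession

// Reuses the active session in a notebook or spark-shell; otherwise starts a local one.
val spark = SparkSession.builder().master("local[2]").appName("power-plant-prep").getOrCreate()
import spark.implicits._

// Unique scratch directory, mirroring temp_path in the output above.
val temp_path = "/tmp/power_plant/" + randomUUID().toString

// Assumed schema: AT and V appear in the output; AP, RH and PE are assumptions.
case class PowerPlantTable(AT: Double, V: Double, AP: Double, RH: Double, PE: Double)

// Hypothetical tab-separated rows with a header line, standing in for the real files.
val rawLines = Seq(
  "AT\tV\tAP\tRH\tPE",
  "10.0\t40.0\t1010.0\t80.0\t470.0",
  "20.0\t50.0\t1012.0\t70.0\t455.0",
  "30.0\t65.0\t1008.0\t50.0\t435.0"
)

// Split each line, drop the header, and convert the fields to the case class.
val parsed: Seq[PowerPlantTable] = rawLines
  .map(_.split("\t"))
  .filter(fields => fields(0) != "AT")
  .map(f => PowerPlantTable(f(0).toDouble, f(1).toDouble, f(2).toDouble, f(3).toDouble, f(4).toDouble))

// Writing to Parquet returns Unit, like powerPlantData in the notebook output.
val powerPlantData: Unit = parsed.toDF().write.mode("overwrite").parquet(temp_path)
```

Writing to a UUID-named directory keeps repeated runs of the notebook from overwriting each other's data.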

import org.apache.spark.ml.feature.VectorAssembler
assembler: org.apache.spark.ml.feature.VectorAssembler = VectorAssembler: uid=vecAssembler_e3255d0f3101, handleInvalid=error, numInputCols=4
xgbInput: org.apache.spark.sql.DataFrame = [AT: double, V: double ... 4 more fields]
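`VectorAssembler` packs the predictor columns into the single vector column that XGBoost4J-Spark trains on; `numInputCols=4` in the output means four columns go in. A minimal sketch, assuming the inputs are `AT`, `V`, `AP`, `RH` and that the target `PE` is renamed to `label`, the label column the regressor uses according to `labelCol: label` in the tuned parameters later in this notebook. The column names beyond `AT` and `V`, and the two inline rows, are hypothetical.

```scala
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("power-plant-assemble").getOrCreate()
import spark.implicits._

// Hypothetical rows; in the notebook these would be read back from temp_path.
val raw = Seq(
  (10.0, 40.0, 1010.0, 80.0, 470.0),
  (20.0, 50.0, 1012.0, 70.0, 455.0)
).toDF("AT", "V", "AP", "RH", "PE")

// Four input columns combined into one "features" vector (numInputCols=4).
val assembler = new VectorAssembler()
  .setInputCols(Array("AT", "V", "AP", "RH"))
  .setOutputCol("features")

// transform appends "features"; renaming the assumed target PE to "label"
// gives six columns, matching "[AT: double, V: double ... 4 more fields]".
val xgbInput = assembler.transform(raw).withColumnRenamed("PE", "label")
xgbInput.printSchema()
```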

split20: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [AT: double, V: double ... 4 more fields]
split80: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [AT: double, V: double ... 4 more fields]
testSet: split20.type = [AT: double, V: double ... 4 more fields]
trainingSet: split80.type = [AT: double, V: double ... 4 more fields]
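The 20/80 split above can be produced with `Dataset.randomSplit`. The types `split20.type` and `split80.type` come from `cache()`, which returns `this.type`. A sketch using a hypothetical 1,000-row stand-in for `xgbInput` and an assumed seed (the notebook's seed, if any, is not shown):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("power-plant-split").getOrCreate()

// Hypothetical stand-in for xgbInput: 1,000 rows with a single id column.
val xgbInput = spark.range(1000).toDF("id")

// 20% test / 80% training; the seed value is an assumption for reproducibility.
val Array(split20, split80) = xgbInput.randomSplit(Array(0.2, 0.8), seed = 42L)

// cache() returns this.type, so the vals keep the singleton types seen in the output.
val testSet = split20.cache()
val trainingSet = split80.cache()

println(s"test rows: ${testSet.count()}, training rows: ${trainingSet.count()}")
```

Caching both splits avoids recomputing the random split each time cross-validation or evaluation scans the data.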

Train the XGBoost model

import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}
xgbParam: scala.collection.immutable.Map[String,Any] = Map(num_workers -> 2, max_depth -> 6, objective -> reg:squarederror, num_round -> 10, eta -> 0.3)
xgbRegressor: ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor = xgbr_8fbae9c09bc0

Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.168.5.140, DMLC_TRACKER_PORT=51469, DMLC_NUM_WORKER=2}
xgbRegressionModel: ml.dmlc.xgboost4j.scala.spark.XGBoostRegressionModel = xgbr_8fbae9c09bc0
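The regressor is built from the parameter map shown in the output and fitted on the training split. A sketch, assuming the XGBoost4J-Spark 1.x-era API that this output reflects (the `XGBoostRegressor(Map[String, Any])` constructor, whose generated uid starts with `xgbr`) with the package on the classpath, and a hypothetical synthetic training set whose features `x1` and `x2` are illustrative names only:

```scala
import ml.dmlc.xgboost4j.scala.spark.{XGBoostRegressionModel, XGBoostRegressor}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

// local[2] provides the two concurrent task slots that num_workers -> 2 needs.
val spark = SparkSession.builder().master("local[2]").appName("xgb-train").getOrCreate()
import spark.implicits._

// Hypothetical training data: label is a deterministic function of x1 and x2.
val raw = (1 to 200)
  .map(i => (i.toDouble, (i % 7).toDouble, 2.0 * i + 3.0 * (i % 7)))
  .toDF("x1", "x2", "label")
val trainingSet = new VectorAssembler()
  .setInputCols(Array("x1", "x2"))
  .setOutputCol("features")
  .transform(raw)

// Same values as the xgbParam map in the output above.
val xgbParam: Map[String, Any] = Map(
  "eta" -> 0.3,
  "max_depth" -> 6,
  "objective" -> "reg:squarederror",
  "num_round" -> 10,
  "num_workers" -> 2
)

// "features" and "label" are the defaults; they are set here only for clarity.
val xgbRegressor = new XGBoostRegressor(xgbParam)
  .setFeaturesCol("features")
  .setLabelCol("label")

// Distributed training across num_workers tasks.
val xgbRegressionModel: XGBoostRegressionModel = xgbRegressor.fit(trainingSet)
```

In the notebook, this single fit corresponds to the single "Tracker started" log line above.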

Evaluate the model

Use RegressionEvaluator from MLlib to evaluate the XGBoost model.

predictions: org.apache.spark.sql.DataFrame = [AT: double, V: double ... 5 more fields]

import org.apache.spark.ml.evaluation.RegressionEvaluator
evaluator: org.apache.spark.ml.evaluation.RegressionEvaluator = RegressionEvaluator: uid=regEval_83b65b551a4e, metricName=rmse, throughOrigin=false
rmse: Double = 13.450027850150057

Root mean squared error: 13.450027850150057
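With `metricName=rmse`, `RegressionEvaluator` computes sqrt(mean((label - prediction)^2)) over the label and prediction columns. The sketch below cross-checks the evaluator against that formula written directly with Spark SQL functions, on four hypothetical hand-made predictions (errors 1, 0, -2, 0, so RMSE = sqrt(5 / 4)):

```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{avg, col, pow, sqrt}

val spark = SparkSession.builder().master("local[2]").appName("rmse-check").getOrCreate()
import spark.implicits._

// Hypothetical (label, prediction) pairs using the default column names.
val predictions = Seq((3.0, 2.0), (5.0, 5.0), (7.0, 9.0), (1.0, 1.0)).toDF("label", "prediction")

val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")

val rmse = evaluator.evaluate(predictions)
println(s"Root mean squared error: $rmse")

// The same metric by hand: square root of the mean squared residual.
val manualRmse = predictions
  .select(sqrt(avg(pow(col("label") - col("prediction"), 2))))
  .first()
  .getDouble(0)
```

RMSE is in the label's own units, which makes it easy to read against the target's scale.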

Tune the model

Use CrossValidator from MLlib to tune the XGBoost model.

Cross-validation logs 17 "Tracker started" lines, one per XGBoost training run: 4 parameter combinations × 4 folds = 16 fits, plus a final refit of the best combination on the full training set. The first line reads as follows; the other 16 differ only in DMLC_TRACKER_PORT:

Tracker started, with env={DMLC_NUM_SERVER=0, DMLC_TRACKER_URI=10.168.5.140, DMLC_TRACKER_PORT=55173, DMLC_NUM_WORKER=2}

import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
paramGrid: Array[org.apache.spark.ml.param.ParamMap] = Array({ xgbr_8fbae9c09bc0-eta: 0.1, xgbr_8fbae9c09bc0-maxDepth: 4 }, { xgbr_8fbae9c09bc0-eta: 0.1, xgbr_8fbae9c09bc0-maxDepth: 7 }, { xgbr_8fbae9c09bc0-eta: 0.6, xgbr_8fbae9c09bc0-maxDepth: 4 }, { xgbr_8fbae9c09bc0-eta: 0.6, xgbr_8fbae9c09bc0-maxDepth: 7 })
cv: org.apache.spark.ml.tuning.CrossValidator = cv_04628ab83845
cvModel: org.apache.spark.ml.tuning.CrossValidatorModel = CrossValidatorModel: uid=cv_04628ab83845, bestModel=xgbr_8fbae9c09bc0, numFolds=4

Parameters of the best model selected by cross-validation (eta = 0.6, maxDepth = 7):

res8: org.apache.spark.ml.param.ParamMap = {
  xgbr_8fbae9c09bc0-allowNonZeroForMissing: false,
  xgbr_8fbae9c09bc0-batchSize: 32768,
  xgbr_8fbae9c09bc0-dmlcWorkerConnectRetry: 5,
  xgbr_8fbae9c09bc0-eta: 0.6,
  xgbr_8fbae9c09bc0-evalMetric: rmse,
  xgbr_8fbae9c09bc0-featuresCol: features,
  xgbr_8fbae9c09bc0-handleInvalid: error,
  xgbr_8fbae9c09bc0-labelCol: label,
  xgbr_8fbae9c09bc0-maxDepth: 7,
  xgbr_8fbae9c09bc0-missing: NaN,
  xgbr_8fbae9c09bc0-nthread: 1,
  xgbr_8fbae9c09bc0-numRound: 10,
  xgbr_8fbae9c09bc0-numWorkers: 2,
  xgbr_8fbae9c09bc0-objective: reg:squarederror,
  xgbr_8fbae9c09bc0-predictionCol: prediction,
  xgbr_8fbae9c09bc0-rabitRingReduceThreshold: 32768,
  xgbr_8fbae9c09bc0-rabitTimeout: -1,
  xgbr_8fbae9c09bc0-seed: 0,
  xgbr_8fbae9c09bc0-trackerConf: TrackerConf(0,python,,),
  xgbr_8fbae9c09bc0-trainTestRatio: 1.0,
  xgbr_8fbae9c09bc0-treeLimit: 0,
  xgbr_8fbae9c09bc0-useExternalMemory: false
}
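Cross-validation searches a 2 × 2 grid over `eta` and `maxDepth` (the values in `paramGrid` above) with 4 folds, and keeps the combination with the lowest mean RMSE, because `isLargerBetter` is false for that metric. A sketch, assuming the XGBoost4J-Spark 1.x-era `XGBoostRegressor` API (public `maxDepth` and `eta` params) and a hypothetical synthetic training set with illustrative features `x1` and `x2`. Unlike the notebook output, it also prints each grid point's average metric:

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{CrossValidator, CrossValidatorModel, ParamGridBuilder}
import org.apache.spark.sql.SparkSession

// local[2] provides the two task slots that num_workers -> 2 needs.
val spark = SparkSession.builder().master("local[2]").appName("xgb-cv").getOrCreate()
import spark.implicits._

// Hypothetical training data: label is a deterministic function of x1 and x2.
val trainingSet = new VectorAssembler()
  .setInputCols(Array("x1", "x2"))
  .setOutputCol("features")
  .transform((1 to 200).map(i => (i.toDouble, (i % 7).toDouble, 2.0 * i + 3.0 * (i % 7))).toDF("x1", "x2", "label"))

// Base estimator; eta and maxDepth are supplied by the grid below.
val xgbRegressor = new XGBoostRegressor(Map[String, Any](
  "objective" -> "reg:squarederror",
  "num_round" -> 10,
  "num_workers" -> 2
))

val evaluator = new RegressionEvaluator().setMetricName("rmse")

// 2 x 2 grid with the same values as paramGrid in the notebook output.
val paramGrid = new ParamGridBuilder()
  .addGrid(xgbRegressor.maxDepth, Array(4, 7))
  .addGrid(xgbRegressor.eta, Array(0.1, 0.6))
  .build()

val cv = new CrossValidator()
  .setEstimator(xgbRegressor)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(4)

// 4 grid points x 4 folds = 16 fits, then one refit of the winner on all rows.
val cvModel: CrossValidatorModel = cv.fit(trainingSet)

// Mean RMSE per grid point, then the winning model's full parameter map.
paramGrid.zip(cvModel.avgMetrics).foreach { case (params, metric) => println(s"$params -> $metric") }
println(cvModel.bestModel.extractParamMap())
```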

predictions2: org.apache.spark.sql.DataFrame = [AT: double, V: double ... 5 more fields]
rmse2: Double = 3.1257189881608602

Root mean squared error: 3.1257189881608602

Tuning improved the RMSE from 13.45 to 3.13, a reduction of about 77%.
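`CrossValidatorModel.transform` delegates to its `bestModel`, so the tuned model is scored on the test split exactly like the baseline. A sketch comparing the two, assuming the XGBoost4J-Spark 1.x-era API and a hypothetical synthetic dataset split 20/80; the data, the seed, and the `x1`/`x2` feature names are illustrative assumptions, and on toy data like this the tuned model is not guaranteed to win:

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("xgb-compare").getOrCreate()
import spark.implicits._

// Hypothetical data; i % 5 adds a small deterministic perturbation to the label.
val data = new VectorAssembler()
  .setInputCols(Array("x1", "x2"))
  .setOutputCol("features")
  .transform((1 to 300).map(i => (i.toDouble, (i % 7).toDouble, 2.0 * i + 3.0 * (i % 7) + (i % 5))).toDF("x1", "x2", "label"))

// 20% test / 80% training with an assumed seed.
val Array(testSet, trainingSet) = data.randomSplit(Array(0.2, 0.8), seed = 7L)

// Baseline parameters, as in the xgbParam map earlier in the notebook.
val xgbRegressor = new XGBoostRegressor(Map[String, Any](
  "eta" -> 0.3, "max_depth" -> 6, "objective" -> "reg:squarederror",
  "num_round" -> 10, "num_workers" -> 2))
val evaluator = new RegressionEvaluator().setMetricName("rmse")

// Baseline: fixed parameters, scored on the held-out split.
val rmse = evaluator.evaluate(xgbRegressor.fit(trainingSet).transform(testSet))

// Tuned: 4-fold cross-validation over eta and maxDepth, scored on the same split.
val cvModel = new CrossValidator()
  .setEstimator(xgbRegressor)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(new ParamGridBuilder()
    .addGrid(xgbRegressor.maxDepth, Array(4, 7))
    .addGrid(xgbRegressor.eta, Array(0.1, 0.6))
    .build())
  .setNumFolds(4)
  .fit(trainingSet)
val rmse2 = evaluator.evaluate(cvModel.transform(testSet))

println(f"baseline RMSE: $rmse%.4f, tuned RMSE: $rmse2%.4f")
```

Scoring both models on the same untouched test split keeps the comparison fair, since cross-validation only ever saw the training rows.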

res12: Boolean = true