xgboost-pyspark(Python)


Regression with XGBoost and MLlib pipelines

This notebook uses a bike-sharing dataset to illustrate MLlib pipelines and the XGBoost machine learning algorithm. The challenge is to predict the number of bicycle rentals per hour based on the features available in the dataset, such as day of the week, weather, and season. Demand prediction is a common problem across businesses; good predictions allow a business or service to optimize inventory and to match supply with demand, keeping customers happy and maximizing profitability.

For more information about the PySpark ML XgboostRegressor estimator used in this notebook, see XgboostRegressor.

Requirements

Databricks Runtime 12.2 LTS ML or below. sparkdl.xgboost is removed in Databricks Runtime 13.0 ML and above.

    /databricks/python_shell/dbruntime/PostImportHook.py:184: FutureWarning: `sparkdl.xgboost` is deprecated and will be removed in a future Databricks Runtime release. Use `xgboost.spark` instead. See https://docs.databricks.com/machine-learning/train-model/xgboost-spark.html#xgboost-migration for migration. hook(module)

    Load the dataset

The dataset is from the UCI Machine Learning Repository and is provided with Databricks Runtime. The dataset includes information about bicycle rentals from the Capital Bikeshare system in 2011 and 2012.

    Load the data using the CSV datasource for Spark, which creates a Spark DataFrame.
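A minimal sketch of the load, assuming the dataset lives at its usual /databricks-datasets path (verify the path in your workspace):

    df = spark.read.csv(
        "/databricks-datasets/bikeSharing/data-001/hour.csv",
        header="true",
        inferSchema="true",
    )
    # Cache the DataFrame, since it is accessed repeatedly below.
    df.cache()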

    Out[2]: DataFrame[instant: int, dteday: date, season: int, yr: int, mnth: int, hr: int, holiday: int, weekday: int, workingday: int, weathersit: int, temp: double, atemp: double, hum: double, windspeed: double, casual: int, registered: int, cnt: int]

    Data description

    The following columns are included in the dataset:

    Index column:

    • instant: record index

    Feature columns:

    • dteday: date
    • season: season (1:spring, 2:summer, 3:fall, 4:winter)
    • yr: year (0:2011, 1:2012)
    • mnth: month (1 to 12)
    • hr: hour (0 to 23)
    • holiday: 1 if holiday, 0 otherwise
    • weekday: day of the week (0 to 6)
    • workingday: 0 if weekend or holiday, 1 otherwise
    • weathersit: (1:clear, 2:mist or clouds, 3:light rain or snow, 4:heavy rain or snow)
    • temp: normalized temperature in Celsius
    • atemp: normalized feeling temperature in Celsius
    • hum: normalized humidity
    • windspeed: normalized wind speed

    Label columns:

    • casual: count of casual users
    • registered: count of registered users
    • cnt: count of total rental bikes including both casual and registered

    Call display() on a DataFrame to see a sample of the data. The first row shows that 16 people rented bikes between midnight and 1am on January 1, 2011.
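For example, assuming the DataFrame from the previous step is named df:

    display(df)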

instant  dteday      season  yr  mnth  hr  holiday  weekday  workingday  weathersit  temp  atemp
1        2011-01-01  1       0   1     0   0        6        0           1           0.24  0.2879
2        2011-01-01  1       0   1     1   0        6        0           1           0.22  0.2727
3        2011-01-01  1       0   1     2   0        6        0           1           0.22  0.2727
4        2011-01-01  1       0   1     3   0        6        0           1           0.24  0.2879
5        2011-01-01  1       0   1     4   0        6        0           1           0.24  0.2879
6        2011-01-01  1       0   1     5   0        6        0           2           0.24  0.2576
7        2011-01-01  1       0   1     6   0        6        0           1           0.22  0.2727
8        2011-01-01  1       0   1     7   0        6        0           1           0.2   0.2576
9        2011-01-01  1       0   1     8   0        6        0           1           0.24  0.2879
10       2011-01-01  1       0   1     9   0        6        0           1           0.32  0.3485
11       2011-01-01  1       0   1     10  0        6        0           1           0.38  0.3939
12       2011-01-01  1       0   1     11  0        6        0           1           0.36  0.3333
13       2011-01-01  1       0   1     12  0        6        0           1           0.42  0.4242
14       2011-01-01  1       0   1     13  0        6        0           2           0.46  0.4545
15       2011-01-01  1       0   1     14  0        6        0           2           0.46  0.4545
16       2011-01-01  1       0   1     15  0        6        0           2           0.44  0.4394

10,000 rows | Truncated data

        The dataset has 17379 rows.

        Preprocess data

This dataset is well prepared for machine learning algorithms. The numeric input columns (temp, atemp, hum, and windspeed) are normalized, categorical values (season, yr, mnth, hr, holiday, weekday, workingday, weathersit) are converted to indices, and all of the columns except for the date (dteday) are numeric.

        The goal is to predict the count of bike rentals (the cnt column). Reviewing the dataset, you can see that some columns contain duplicate information. For example, the cnt column equals the sum of the casual and registered columns. You should remove the casual and registered columns from the dataset. The index column instant is also not useful as a predictor.

        You can also delete the column dteday, as this information is already included in the other date-related columns yr, mnth, and weekday.
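A sketch of this preprocessing, assuming the loaded DataFrame is named df:

    # Drop the index column, the redundant date column, and the two
    # columns that would leak the label (cnt = casual + registered).
    df = df.drop("instant").drop("dteday").drop("casual").drop("registered")
    display(df)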

season  yr  mnth  hr  holiday  weekday  workingday  weathersit  temp  atemp   hum   windspeed
1       0   1     0   0        6        0           1           0.24  0.2879  0.81  0
1       0   1     1   0        6        0           1           0.22  0.2727  0.8   0
1       0   1     2   0        6        0           1           0.22  0.2727  0.8   0
1       0   1     3   0        6        0           1           0.24  0.2879  0.75  0
1       0   1     4   0        6        0           1           0.24  0.2879  0.75  0
1       0   1     5   0        6        0           2           0.24  0.2576  0.75  0.0896
1       0   1     6   0        6        0           1           0.22  0.2727  0.8   0
1       0   1     7   0        6        0           1           0.2   0.2576  0.86  0
1       0   1     8   0        6        0           1           0.24  0.2879  0.75  0
1       0   1     9   0        6        0           1           0.32  0.3485  0.76  0
1       0   1     10  0        6        0           1           0.38  0.3939  0.76  0.2537
1       0   1     11  0        6        0           1           0.36  0.3333  0.81  0.2836
1       0   1     12  0        6        0           1           0.42  0.4242  0.77  0.2836
1       0   1     13  0        6        0           2           0.46  0.4545  0.72  0.2985
1       0   1     14  0        6        0           2           0.46  0.4545  0.72  0.2836
1       0   1     15  0        6        0           2           0.44  0.4394  0.77  0.2985

10,000 rows | Truncated data

        Print the dataset schema to see the type of each column.
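For example:

    df.printSchema()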

root
 |-- season: integer (nullable = true)
 |-- yr: integer (nullable = true)
 |-- mnth: integer (nullable = true)
 |-- hr: integer (nullable = true)
 |-- holiday: integer (nullable = true)
 |-- weekday: integer (nullable = true)
 |-- workingday: integer (nullable = true)
 |-- weathersit: integer (nullable = true)
 |-- temp: double (nullable = true)
 |-- atemp: double (nullable = true)
 |-- hum: double (nullable = true)
 |-- windspeed: double (nullable = true)
 |-- cnt: integer (nullable = true)

          Split data into training and test sets

          Randomly split data into training and test sets. By doing this, you can train and tune the model using only the training subset, and then evaluate the model's performance on the test set to get a sense of how the model will perform on new data.
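A sketch using randomSplit; the 70/30 proportions and the seed are assumptions chosen to roughly match the counts reported below:

    # Hold out 30% of the data for evaluation; fix the seed for reproducibility.
    train, test = df.randomSplit([0.7, 0.3], seed=0)
    print("There are %d training examples and %d test examples." % (train.count(), test.count()))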

          There are 12081 training examples and 5298 test examples.

          Visualize the data

          You can plot the data to explore it visually. The following plot shows the number of bicycle rentals during each hour of the day. As you might expect, rentals are low during the night, and peak at commute hours.

          To create plots, call display() on a DataFrame in Databricks and click the plot icon below the table.

To create the plot shown, run the command in the following cell. The results appear in a table. From the drop-down menu below the table, select "Line". Click Plot Options.... In the dialog, drag hr to the Keys field, and drag cnt to the Values field. Also in the Keys field, click the "x" next to <id> to remove it. In the Aggregation drop-down, select "AVG".
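The cell itself is just a display() of the relevant columns, for example:

    display(train.select("hr", "cnt"))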

            Aggregated (by avg) in the backend.

            24 rows

            Train the machine learning pipeline

            Now that you have reviewed the data and prepared it as a DataFrame with numeric values, you're ready to train a model to predict future bike sharing rentals.

            Most MLlib algorithms require a single input column containing a vector of features and a single target column. The DataFrame currently has one column for each feature. MLlib provides functions to help you prepare the dataset in the required format.

            MLlib pipelines combine multiple steps into a single workflow, making it easier to iterate as you develop the model.

            In this example, you create a pipeline using the following functions:

            • VectorAssembler: Assembles the feature columns into a feature vector.
            • VectorIndexer: Identifies columns that should be treated as categorical. This is done heuristically, identifying any column with a small number of distinct values as categorical. In this example, the following columns are considered categorical: yr (2 values), season (4 values), holiday (2 values), workingday (2 values), and weathersit (4 values).
            • XgboostRegressor: Uses the XgboostRegressor estimator to learn how to predict rental counts from the feature vectors.
            • CrossValidator: The XGBoost regression algorithm has several hyperparameters. This notebook illustrates how to use hyperparameter tuning in Spark. This capability automatically tests a grid of hyperparameters and chooses the best resulting model.

            For more information:
            VectorAssembler
            VectorIndexer

            The first step is to create the VectorAssembler and VectorIndexer steps.
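A sketch of these two stages; the intermediate column names rawFeatures and features and the maxCategories threshold are illustrative choices:

    from pyspark.ml.feature import VectorAssembler, VectorIndexer

    # Assemble every column except the label into a single feature vector.
    featuresCols = df.columns
    featuresCols.remove("cnt")
    vectorAssembler = VectorAssembler(inputCols=featuresCols, outputCol="rawFeatures")

    # Treat any column with at most 4 distinct values as categorical.
    vectorIndexer = VectorIndexer(inputCol="rawFeatures", outputCol="features", maxCategories=4)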

Next, define the model. To use distributed training (requires Databricks Runtime 9.1 LTS ML or above), set num_workers to the number of available workers in the cluster. To use all Spark task slots, set num_workers=sc.defaultParallelism.
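A sketch of the estimator definition, following the note above about num_workers:

    from sparkdl.xgboost import XgboostRegressor

    # Distribute training across all Spark task slots; predict the cnt column.
    xgb_regressor = XgboostRegressor(num_workers=sc.defaultParallelism, labelCol="cnt")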

            The third step is to wrap the model you just defined in a CrossValidator stage. CrossValidator calls the XgboostRegressor estimator with different hyperparameter settings. It trains multiple models and selects the best one, based on minimizing a specified metric. In this example, the metric is root mean squared error (RMSE).
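A sketch of the tuning stage; the particular grid values for max_depth and n_estimators are illustrative assumptions:

    from pyspark.ml.evaluation import RegressionEvaluator
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    # Hyperparameter grid to search over.
    paramGrid = ParamGridBuilder() \
        .addGrid(xgb_regressor.max_depth, [2, 5]) \
        .addGrid(xgb_regressor.n_estimators, [10, 100]) \
        .build()

    # Select the model that minimizes root mean squared error.
    evaluator = RegressionEvaluator(
        metricName="rmse",
        labelCol=xgb_regressor.getLabelCol(),
        predictionCol=xgb_regressor.getPredictionCol(),
    )

    cv = CrossValidator(estimator=xgb_regressor, evaluator=evaluator, estimatorParamMaps=paramGrid)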

            Create the pipeline.
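For example:

    from pyspark.ml import Pipeline

    pipeline = Pipeline(stages=[vectorAssembler, vectorIndexer, cv])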

            Train the pipeline.

            Now that you have set up the workflow, you can train the pipeline with a single call.
            When you call fit(), the pipeline runs feature processing, model tuning, and training and returns a fitted pipeline with the best model it found. This step takes several minutes.
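Assuming the names used in the sketches above:

    pipelineModel = pipeline.fit(train)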

              Make predictions and evaluate results

              The final step is to use the fitted model to make predictions on the test dataset and evaluate the model's performance. The model's performance on the test dataset provides an approximation of how it is likely to perform on new data. For example, if you had weather predictions for the next week, you could predict bike rentals expected during the next week.

              Computing evaluation metrics is important for understanding the quality of predictions, as well as for comparing models and tuning parameters.

The transform() method of the pipeline model applies the full pipeline to the input dataset. The pipeline applies the feature processing steps to the dataset and then uses the fitted XgboostRegressor model to make predictions. The pipeline returns a DataFrame with a new column, prediction.
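A sketch, reusing featuresCols from the feature-processing step:

    predictions = pipelineModel.transform(test)
    display(predictions.select("cnt", "prediction", *featuresCols))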

cnt  prediction          season  yr  mnth  hr  holiday  weekday  workingday  weathersit  temp  atemp
22   43.09138488769531   1       0   1     0   0        0        0           1           0.04  0.0758
17   41.32030487060547   1       0   1     0   0        0        0           2           0.46  0.4545
7    15.841835975646973  1       0   1     0   0        1        1           2           0.24  0.2273
17   13.308489799499512  1       0   1     0   0        5        1           2           0.2   0.197
9    7.332301616668701   1       0   1     0   0        5        1           2           0.2   0.2121
17   33.36613845825195   1       0   1     0   1        1        0           2           0.2   0.197
13   33.566200256347656  1       0   1     1   0        0        0           1           0.04  0.0758
12   26.840267181396484  1       0   1     1   0        0        0           1           0.1   0.0606
17   19.8973388671875    1       0   1     1   0        0        0           2           0.44  0.4394
2    5.518710136413574   1       0   1     1   0        1        1           1           0.2   0.1667
7    9.750831604003906   1       0   1     1   0        1        1           1           0.22  0.2121
2    7.6698527336120605  1       0   1     1   0        2        1           1           0.16  0.1818
6    10.415045738220215  1       0   1     1   0        3        1           2           0.16  0.1818
7    4.704164505004883   1       0   1     1   0        5        1           2           0.2   0.197
3    4.3393144607543945  1       0   1     1   0        5        1           2           0.2   0.2121
40   33.21678924560547   1       0   1     1   0        6        0           1           0.22  0.2727

5,298 rows

A common way to evaluate the performance of a regression model is to calculate the root mean squared error (RMSE). The value is not very informative on its own, but you can use it to compare different models. CrossValidator determines the best model by selecting the one that minimizes RMSE.
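Using the evaluator defined for the CrossValidator stage, the computation might look like this:

    rmse = evaluator.evaluate(predictions)
    print("RMSE on our test set: %g" % rmse)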

                  RMSE on our test set: 41.8323

You can also plot the results, as you did with the original dataset. In this case, the hourly count of rentals shows a similar shape.
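For example:

    display(predictions.select("hr", "prediction"))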

                    5,298 rows

                    Save and reload the model
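You can save the fitted pipeline for later use and reload it to make predictions on new data. A sketch, with a hypothetical DBFS path:

    from pyspark.ml import PipelineModel

    # Hypothetical save location; adjust for your workspace.
    model_path = "dbfs:/tmp/xgboost-bike-sharing-pipeline"
    pipelineModel.write().overwrite().save(model_path)

    # Reload the pipeline and confirm it still produces predictions.
    loaded_pipelineModel = PipelineModel.load(model_path)
    display(loaded_pipelineModel.transform(test).select("cnt", "prediction", *featuresCols).limit(3))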

                    The xgboost training will use single worker and set nthread=1 (equal to `spark.task.cpus` config), If you need to increase threads number used in training, you can set `nthread` param.

cnt  prediction          season  yr  mnth  hr  holiday  weekday  workingday  weathersit  temp  atemp
22   43.09138488769531   1       0   1     0   0        0        0           1           0.04  0.0758
17   41.32030487060547   1       0   1     0   0        0        0           2           0.46  0.4545
7    15.841835975646973  1       0   1     0   0        1        1           2           0.24  0.2273

3 rows