decision-trees(Scala)

Decision Trees for handwritten digit recognition

This notebook demonstrates learning a Decision Tree using Spark's distributed implementation. It gives the reader a better understanding of some critical hyperparameters for the tree learning algorithm, using examples to demonstrate how tuning them can improve accuracy.

This notebook uses the classic MNIST handwritten digit recognition dataset.

Load MNIST training and test datasets

The datasets consist of vectors of pixels representing images of handwritten digits. These datasets are stored in the popular LibSVM dataset format. This notebook uses MLlib's LibSVM dataset reader utility to load the datasets.
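Below is a minimal sketch of the loading step; the file paths are placeholders, not the notebook's actual data locations, so substitute the locations of the MNIST LibSVM files in your environment.

```scala
// Load the training and test sets, which are stored in LibSVM format.
// NOTE: these paths are placeholders for illustration only.
val training = spark.read.format("libsvm").load("/tmp/mnist/mnist-train.txt")
val test = spark.read.format("libsvm").load("/tmp/mnist/mnist-test.txt")

// Cache both datasets since several models will be trained on them.
training.cache()
test.cache()
```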

Display the data. Each image has the true label (the label column) and a vector of features which represent pixel intensities.
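In a Databricks notebook, this can be done with the built-in display() function:

```scala
// Render the training DataFrame as a table: one row per image,
// with its true label and the vector of pixel intensities.
display(training)
```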

Train a Decision Tree

Begin by training a decision tree using the default settings. Before training, you use the StringIndexer class to tell the algorithm that the labels are categories 0-9, instead of continuous values. Then, you use a Pipeline to connect the feature preprocessing with the decision tree algorithm. ML Pipelines are Spark tools for linking Machine Learning algorithms into workflows. To learn more about Pipelines, see the other ML example notebooks in Databricks and the ML Pipelines user guide.
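A sketch of this setup (the indexed-label column name is chosen here for illustration):

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.feature.StringIndexer

// StringIndexer: treat the digit labels as 10 categories, not continuous values.
val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")

// Decision tree classifier that predicts the indexed (categorical) label.
val dtc = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")

// Chain label indexing and tree learning into a single Pipeline.
val pipeline = new Pipeline().setStages(Array(indexer, dtc))
```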

Now fit a model to the data.
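Using the pipeline defined above:

```scala
// Fit the full Pipeline (label indexing + decision tree) to the training data.
val model = pipeline.fit(training)
```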

Use Databricks ML visualization to inspect the learned tree. Visualization is only available for some models.
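For example, assuming the fitted Pipeline from the previous cell (display() on fitted models is a Databricks-specific feature):

```scala
import org.apache.spark.ml.classification.DecisionTreeClassificationModel

// The tree is the last stage of the fitted PipelineModel.
val tree = model.stages.last.asInstanceOf[DecisionTreeClassificationModel]

// Databricks can render some fitted models, including decision trees.
display(tree)
```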

You can see above how the tree makes predictions. When classifying a new example, the tree starts at the "root" node (at the top). Each tree node tests a pixel value and goes either left or right. At the bottom "leaf" nodes, the tree predicts a digit as the image's label.

      Exploring "maxDepth": training trees of different sizes

This section tunes a single hyperparameter, maxDepth, which determines how deep (and large) the tree can be. It trains trees of varying depths and shows how depth affects accuracy on the held-out test set.

Note: The next cell can take about 1 minute to run since it is training several trees which get deeper and deeper.
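A sketch of that loop, reusing the pipeline and dtc defined above (the depth range here is illustrative):

```scala
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// Accuracy on the held-out test set, measured against the indexed label.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setMetricName("accuracy")

// Train one tree per depth and record its test accuracy.
val accuracies = (1 to 8).map { depth =>
  dtc.setMaxDepth(depth)
  val model = pipeline.fit(training)
  (depth, evaluator.evaluate(model.transform(test)))
}
```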

You can display accuracy results and see immediately that deeper, larger trees are more powerful classifiers, achieving higher accuracies.
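For example:

```scala
import spark.implicits._

// Turn the (depth, accuracy) pairs into a DataFrame so display() can plot them.
display(accuracies.toDF("maxDepth", "accuracy"))
```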

Note: When you run display(), you will get a table. Click on the plot icon below the table to create a plot, and use "Plot Options" to adjust what is displayed.

Even though deeper trees are more powerful, they are not always better. As tree depth increases, training takes longer. You also risk overfitting (fitting the training data so well that performance decreases on test data); it is important to tune parameters based on held-out data to prevent overfitting.

Exploring maxBins: discretization for efficient distributed computing

This section explores a more expert-level setting, maxBins. For efficient distributed training of Decision Trees, Spark and most other libraries discretize (or "bin") continuous features (such as pixel values) into a finite number of values. This is an important step for the distributed implementation, but it introduces a tradeoff: a larger maxBins represents your data more accurately, but it also means more communication (and slower training).

The default value of maxBins generally works well, but it is interesting to explore on the handwritten digit dataset. The images are grayscale, so setting maxBins = 2 effectively turns each image black-and-white. Does that affect the accuracy of the model?
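One way to try it, reusing the pipeline and evaluator from above (the fixed depth of 6 is an arbitrary choice for illustration):

```scala
// Fix the depth and restrict every continuous feature to 2 bins,
// which effectively thresholds each grayscale pixel to black or white.
dtc.setMaxDepth(6).setMaxBins(2)
val binaryModel = pipeline.fit(training)
val binaryAccuracy = evaluator.evaluate(binaryModel.transform(test))
```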

The extreme discretization (black and white) decreases accuracy, but only a bit. Using more bins increases the accuracy, but also makes learning more costly.

What's next?

• Explore: Try tuning other tree parameters, or even ensembles such as Random Forests or Gradient-Boosted Trees.
• Automated tuning: This type of tuning does not have to be done manually. (It is done manually here to show the effects of tuning in detail.) MLlib provides automated tuning functionality using CrossValidator; a minimal sketch follows below. See the other Databricks ML Pipeline guides or the Spark ML user guide for details.
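A minimal sketch of automated tuning with CrossValidator, reusing the pipeline and evaluator defined above (the parameter grids are illustrative):

```scala
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Candidate hyperparameter combinations to search over.
val paramGrid = new ParamGridBuilder()
  .addGrid(dtc.maxDepth, Array(4, 6, 8))
  .addGrid(dtc.maxBins, Array(16, 32))
  .build()

// CrossValidator picks the combination with the best cross-validated accuracy.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

val cvModel = cv.fit(training)
```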