Train a Decision Tree
Begin by training a decision tree using the default settings. Before training, use the StringIndexer class to tell the algorithm that the labels are categories 0-9 rather than continuous values. Then use a Pipeline to connect the feature preprocessing with the decision tree algorithm. ML Pipelines are Spark tools for linking machine learning algorithms into workflows. To learn more about Pipelines, see the other ML example notebooks in Databricks and the ML Pipelines user guide.
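The steps above can be sketched in PySpark roughly as follows. This is a minimal sketch rather than the notebook's exact cell: the column names "label", "features", and "indexedLabel" and the helper name build_pipeline are assumptions made for illustration. The defaults shown (maxDepth=5, maxBins=32) are the library defaults for DecisionTreeClassifier.

```python
def build_pipeline(max_depth=5, max_bins=32):
    """Return an unfitted Pipeline: label indexing followed by a decision tree."""
    # Imported inside the function so the definition loads without a Spark
    # runtime; calling it does require an active Spark session.
    from pyspark.ml import Pipeline
    from pyspark.ml.classification import DecisionTreeClassifier
    from pyspark.ml.feature import StringIndexer

    # Mark the digit labels as categorical rather than continuous.
    # Column names here are assumed, not taken from the notebook.
    indexer = StringIndexer(inputCol="label", outputCol="indexedLabel")

    # The tree reads the indexed label produced by the previous stage.
    tree = DecisionTreeClassifier(
        labelCol="indexedLabel",
        featuresCol="features",
        maxDepth=max_depth,
        maxBins=max_bins,
    )

    # The Pipeline runs the indexer and the tree as one workflow.
    return Pipeline(stages=[indexer, tree])
```

In a notebook with Spark running, `build_pipeline().fit(training)` would fit both stages in one call, where `training` is a hypothetical DataFrame holding the training split.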
Exploring "maxDepth": training trees of different sizes
This section tunes a single hyperparameter, maxDepth, which determines how deep (and large) the tree can be. It trains trees at varying depths and shows how depth affects accuracy on the held-out test set.
Note: The next cell can take about 1 minute to run since it is training several trees which get deeper and deeper.
Even though deeper trees are more powerful, they are not always better. As tree depth increases, training takes longer. You also risk overfitting (fitting the training data so well that performance decreases on test data); it is important to tune parameters based on held-out data to prevent overfitting.
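One way to organize such a sweep is sketched below, under stated assumptions. sweep_depths is a hypothetical helper that works with any scoring function. spark_fit_and_score shows what a Spark-backed scorer might look like, assuming DataFrames with "label" and "features" columns and using accuracy as the metric.

```python
def sweep_depths(fit_and_score, depths):
    """Score each depth with fit_and_score(depth); return (results, best_depth)."""
    results = {depth: fit_and_score(depth) for depth in depths}
    # On ties, the first depth reaching the top score wins.
    best_depth = max(results, key=results.get)
    return results, best_depth


def spark_fit_and_score(training, held_out):
    """Build a scorer that trains a tree at a given depth on `training`
    and returns its accuracy on `held_out` (both assumed Spark DataFrames)."""
    from pyspark.ml import Pipeline
    from pyspark.ml.classification import DecisionTreeClassifier
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator
    from pyspark.ml.feature import StringIndexer

    evaluator = MulticlassClassificationEvaluator(
        labelCol="indexedLabel", predictionCol="prediction", metricName="accuracy"
    )

    def fit_and_score(depth):
        indexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
        tree = DecisionTreeClassifier(labelCol="indexedLabel", maxDepth=depth)
        model = Pipeline(stages=[indexer, tree]).fit(training)
        return evaluator.evaluate(model.transform(held_out))

    return fit_and_score
```

With Spark running, `sweep_depths(spark_fit_and_score(training, test), [2, 4, 8, 16])` would return the accuracy per depth along with the best-scoring depth. If the best depth sits in the middle of the range rather than at the deepest setting, that is a sign the deeper trees are overfitting.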
Exploring "maxBins": discretization for efficient distributed computing
This section explores a more expert-level setting, maxBins. For efficient distributed training of decision trees, Spark and most other libraries discretize (or "bin") continuous features (such as pixel values) into a finite number of values. This is an important step for the distributed implementation, but it introduces a tradeoff: a larger maxBins value represents your data more accurately, but it also means more communication (and slower training).
The default value of maxBins generally works, but it is interesting to explore on the handwritten digit dataset. The images are in grayscale. Setting maxBins = 2 effectively makes each one a black-and-white image. Does that affect the accuracy of the model?
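To make the effect of binning concrete, here is a small pure-Python illustration. It uses equal-width bins over the 0-255 grayscale range, which is a simplifying assumption: Spark picks its split thresholds from the data itself, so this shows the idea of discretization, not Spark's actual procedure. The name bin_pixel is hypothetical.

```python
def bin_pixel(value, n_bins, max_value=255):
    """Map a grayscale value in [0, max_value] to an equal-width bin index."""
    index = int(value * n_bins / (max_value + 1))
    # Guard the top edge so max_value always lands in the last bin.
    return min(index, n_bins - 1)


pixels = [0, 100, 200, 255]

# With 2 bins, every pixel collapses to "dark" (0) or "light" (1).
two_bins = [bin_pixel(p, 2) for p in pixels]

# With 32 bins, much more of the grayscale detail survives.
thirty_two_bins = [bin_pixel(p, 32) for p in pixels]

print(two_bins)         # [0, 0, 1, 1]
print(thirty_two_bins)  # [0, 12, 25, 31]
```

With two bins the tree can only ask whether a pixel is dark or light. With 32 bins it can place thresholds at many more gray levels, at the cost of more candidate splits to evaluate and communicate.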
What's next?
- Explore: Try tuning other parameters of trees, or even ensembles like Random Forests or Gradient-Boosted Trees.
- Automated tuning: This type of tuning does not have to be done manually. (It is done manually here to show the effects of tuning in detail.) MLlib provides automated tuning functionality using CrossValidator. See the other Databricks ML Pipeline guides or the Spark ML user guide for details.
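As a rough sketch of the automated route, the function below wraps a pipeline in a CrossValidator that searches over maxDepth and maxBins. The grid values, fold count, column names, and the helper name build_cross_validator are illustrative assumptions, not settings taken from this notebook.

```python
def build_cross_validator(depths=(2, 4, 8), bins=(2, 32), num_folds=3):
    """Return an unfitted CrossValidator over maxDepth and maxBins."""
    # Imported inside the function so the definition loads without a Spark
    # runtime; calling it does require an active Spark session.
    from pyspark.ml import Pipeline
    from pyspark.ml.classification import DecisionTreeClassifier
    from pyspark.ml.evaluation import MulticlassClassificationEvaluator
    from pyspark.ml.feature import StringIndexer
    from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

    indexer = StringIndexer(inputCol="label", outputCol="indexedLabel")
    tree = DecisionTreeClassifier(labelCol="indexedLabel")
    pipeline = Pipeline(stages=[indexer, tree])

    # Every combination of the listed values is tried.
    grid = (
        ParamGridBuilder()
        .addGrid(tree.maxDepth, list(depths))
        .addGrid(tree.maxBins, list(bins))
        .build()
    )

    evaluator = MulticlassClassificationEvaluator(
        labelCol="indexedLabel", metricName="accuracy"
    )

    return CrossValidator(
        estimator=pipeline,
        estimatorParamMaps=grid,
        evaluator=evaluator,
        numFolds=num_folds,
    )
```

With Spark running, `build_cross_validator().fit(training)` would return a fitted model whose `bestModel` attribute holds the pipeline trained with the best-scoring combination. Note that the cost grows with the grid size times the number of folds.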
Decision Trees for handwritten digit recognition
This notebook demonstrates learning a Decision Tree using Spark's distributed implementation. It gives the reader a better understanding of some critical hyperparameters for the tree learning algorithm, using examples to demonstrate how tuning the hyperparameters can improve accuracy.
This notebook uses the classic MNIST handwritten digit recognition dataset.