How to Extract Feature Information for Tree-based Apache SparkML Pipeline Models

When you fit a tree-based model, such as a decision tree, random forest, or gradient-boosted tree, it helps to review the feature importance levels alongside the feature names. In SparkML, the model is typically fit as the last stage of a pipeline. To extract the relevant feature information from a fitted pipeline, you must select the correct pipeline stages: the tree model is the last stage, and the VectorAssembler, which holds the feature names, is the stage just before it:

from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml import Pipeline

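pipeline = Pipeline(stages=[indexer, assembler, decision_tree])
DTmodel = pipeline.fit(train) # train is a placeholder name for your training DataFrame
va = DTmodel.stages[-2] # the VectorAssembler stage
tree = DTmodel.stages[-1] # the fitted decision tree model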

display(tree) # visualize the decision tree model
print(tree.toDebugString) # print the nodes of the decision tree model

list(zip(va.getInputCols(), tree.featureImportances))
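The zipped list is returned in the assembler's input order, so the most important features are not necessarily first. A minimal sketch of ranking the pairs follows. `rank_features` is a hypothetical helper, not part of Spark, and it assumes that every VectorAssembler input column is a single scalar column, so that names and importance values line up one to one. The sample names and scores are illustrative only.

```python
def rank_features(names, importances):
    """Pair feature names with importance scores, highest score first.

    `importances` may be a plain sequence or a Spark ML vector; a vector is
    converted with toArray() when that method is present.
    """
    values = importances.toArray() if hasattr(importances, "toArray") else importances
    values = [float(v) for v in values]
    names = list(names)
    # Assumption: one name per vector slot. This does not hold if an input
    # column is itself a vector (for example, a one-hot encoded column).
    if len(names) != len(values):
        raise ValueError(
            f"{len(names)} feature names but {len(values)} importance values"
        )
    return sorted(zip(names, values), key=lambda pair: pair[1], reverse=True)


# Illustrative values, not output from a real model.
print(rank_features(["age", "income", "region_idx"], [0.2, 0.7, 0.1]))
```

With the objects from the example, the call would be `rank_features(va.getInputCols(), tree.featureImportances)`.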

You can also tune a tree-based model by using a cross validator as the last stage of the pipeline. After fitting, that stage is a CrossValidatorModel rather than the tree itself, so to visualize the decision tree and print the feature importance levels, you extract bestModel from it:

from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

cv = CrossValidator(estimator=decision_tree, estimatorParamMaps=paramGrid, evaluator=evaluator, numFolds=3)
pipelineCV = Pipeline(stages=[indexer, assembler, cv])
DTmodelCV = pipelineCV.fit(train) # train is a placeholder name for your training DataFrame
va = DTmodelCV.stages[-2]
treeCV = DTmodelCV.stages[-1].bestModel

display(treeCV) # visualize the best decision tree model
print(treeCV.toDebugString) # print the nodes of the decision tree model

list(zip(va.getInputCols(), treeCV.featureImportances))
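Both examples select stages by position (`stages[-2]` and `stages[-1]`), which silently picks the wrong object if the stage order changes. One possible alternative is to look up a stage by its type. In the sketch below, `find_stage` is a hypothetical helper, not a Spark API, and the stand-in classes exist only so the snippet runs without a Spark session.

```python
def find_stage(stages, stage_type):
    """Return the last stage in `stages` that is an instance of `stage_type`."""
    for stage in reversed(stages):
        if isinstance(stage, stage_type):
            return stage
    raise LookupError(f"no stage of type {stage_type.__name__} found")


# Stand-in classes (hypothetical) that mimic a fitted pipeline's stage list.
class FakeIndexerModel:
    pass


class FakeAssembler:
    pass


class FakeTreeModel:
    pass


stages = [FakeIndexerModel(), FakeAssembler(), FakeTreeModel()]
assembler_stage = find_stage(stages, FakeAssembler)
print(type(assembler_stage).__name__)
```

With a fitted pipeline, the equivalent calls would be `find_stage(DTmodelCV.stages, VectorAssembler)` for the feature names and `find_stage(DTmodelCV.stages, CrossValidatorModel).bestModel` for the tuned tree, with `CrossValidatorModel` imported from `pyspark.ml.tuning`.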

The display function visualizes decision tree models only. See Machine learning visualizations.