Decision trees
Decision trees are a popular family of classification and regression methods. More information about the spark.ml implementation can be found in the section on decision trees.

Examples

The following example loads a dataset in LIBSVM format, splits it into training and test sets, trains on the training set, and then evaluates on the held-out test set. We use two feature transformers to prepare the data; these index the categories of the label and of the categorical features, adding metadata to the DataFrame that the decision tree algorithm can recognize.
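As a rough illustration of what the two indexers compute, before the full Spark example, here is a plain-Scala sketch that does not use Spark at all. The object `IndexingSketch` and the helpers `indexLabels` and `isCategorical` are hypothetical names introduced only for this illustration. The sketch assumes the default `StringIndexer` ordering (most frequent label gets index 0, with ties broken here by string order) and the `VectorIndexer` rule that a feature with at most `maxCategories` distinct values is treated as categorical.

```scala
object IndexingSketch {
  // Hypothetical helper, not a Spark API: assign each distinct label an index,
  // most frequent label first (assumed tie-break: string order).
  def indexLabels(labels: Seq[String]): Map[String, Double] =
    labels
      .groupBy(identity)
      .map { case (label, occurrences) => (label, occurrences.size) }
      .toSeq
      .sortBy { case (label, count) => (-count, label) }
      .zipWithIndex
      .map { case ((label, _), idx) => label -> idx.toDouble }
      .toMap

  // Hypothetical helper, not a Spark API: a feature column counts as
  // categorical when it has at most maxCategories distinct values.
  def isCategorical(values: Seq[Double], maxCategories: Int): Boolean =
    values.distinct.size <= maxCategories
}
```

For instance, `IndexingSketch.indexLabels(Seq("1.0", "0.0", "1.0"))` maps the more frequent label `"1.0"` to index `0.0`, and `IndexingSketch.isCategorical(Seq(0.0, 1.0, 2.0, 1.0), 4)` is true because the column has only three distinct values. The real transformers additionally record this information as DataFrame column metadata, which this sketch leaves out.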
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Load the data stored in LIBSVM format as a DataFrame.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
// Automatically identify categorical features, and index them.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4) // features with > 4 distinct values are treated as continuous.
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a DecisionTree model.
val dt = new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and tree in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${(1.0 - accuracy)}")

// The decision tree is the third pipeline stage (index 2).
val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
println(s"Learned classification tree model:\n ${treeModel.toDebugString}")

/*
Output:
+--------------+-----+--------------------+
|predictedLabel|label|            features|
+--------------+-----+--------------------+
|           0.0|  0.0|(692,[121,122,123...|
|           0.0|  0.0|(692,[123,124,125...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[125,126,127...|
+--------------+-----+--------------------+
only showing top 5 rows

Test Error = 0.030303030303030276
Learned classification tree model:
 DecisionTreeClassificationModel (uid=dtc_a286075ebc4c) of depth 2 with 5 nodes
  If (feature 406 <= 22.0)
   If (feature 99 in {2.0})
    Predict: 0.0
   Else (feature 99 not in {2.0})
    Predict: 1.0
  Else (feature 406 > 22.0)
   Predict: 0.0

*/
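To make the printed tree easier to read, the sketch below rewrites it by hand as nested conditionals in plain Scala. It is only an illustration of how the debug string can be interpreted, not how Spark evaluates the model. The object `TreeSketch` and the function `predict` are hypothetical names, and the sketch assumes that feature indices are 0-based positions in the `indexedFeatures` vector, represented here as a dense array of 692 values.

```scala
object TreeSketch {
  // Hand-written equivalent of the tree printed in the output above.
  // `indexedFeatures` stands for one row of the "indexedFeatures" column.
  def predict(indexedFeatures: Array[Double]): Double =
    if (indexedFeatures(406) <= 22.0) {
      // Feature 99 was indexed as categorical, so the split tests set membership
      // rather than comparing against a threshold.
      if (Set(2.0).contains(indexedFeatures(99))) 0.0 else 1.0
    } else {
      0.0
    }
}
```

The returned value corresponds to the `prediction` column, that is, an indexed label; in the pipeline above, `IndexToString` maps it back to the original label in `predictedLabel`. The continuous split on feature 406 and the categorical split on feature 99 show the two kinds of tests that the `VectorIndexer` metadata makes possible.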
Last modified 1yr ago