Random forests
Random forests are a popular family of classification and regression methods. More information about the spark.ml implementation can be found later, in the section on random forests.
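A forest handles both tasks by combining many trees' outputs: classification by vote, regression by average. The sketch below is plain Scala (no Spark) showing that combination step. The `ForestAggregation` object, its tree-as-function types, and the tie-breaking rule are assumptions made for illustration. The random parts of training (bootstrap sampling, per-split feature subsets) are omitted. Spark's spark.ml classifier may combine per-tree class-probability estimates rather than hard votes, so treat this as the textbook formulation.

```scala
object ForestAggregation {
  // Hypothetical tree representations for this sketch: each tree is just a function
  // from a feature vector to its output.
  type ClassTree = Vector[Double] => Int // returns a class index
  type RegressionTree = Vector[Double] => Double

  // Classification: every tree casts one vote; the class with the most votes wins.
  // Ties go to the smaller class index (an assumption of this sketch).
  def majorityVote(trees: Seq[ClassTree], x: Vector[Double]): Int = {
    val counts: Map[Int, Int] =
      trees.map(tree => tree(x)).groupBy(identity).map { case (cls, votes) => cls -> votes.size }
    counts.toSeq.sortBy { case (cls, n) => (-n, cls) }.head._1
  }

  // Regression: the forest prediction is the mean of the per-tree predictions.
  def average(trees: Seq[RegressionTree], x: Vector[Double]): Double =
    trees.map(tree => tree(x)).sum / trees.size

  def main(args: Array[String]): Unit = {
    // Three hypothetical decision stumps.
    val stumps: Seq[ClassTree] = Seq(
      x => if (x(0) <= 2.5) 0 else 1,
      x => if (x(1) <= 1.0) 0 else 1,
      _ => 1
    )
    println(majorityVote(stumps, Vector(3.0, 0.5))) // votes 1, 0, 1 -> prints 1
    val regressors: Seq[RegressionTree] = Seq(_ => 1.0, _ => 2.0, _ => 6.0)
    println(average(regressors, Vector(0.0))) // prints 3.0
  }
}
```

Suppose three trees predict classes 1, 0 and 1; `majorityVote` then returns 1. If three regression trees predict 1.0, 2.0 and 6.0, `average` returns 3.0.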
Examples
The following example loads a dataset in LibSVM format, splits it into training and test sets, trains on the training set, and then evaluates on the held-out test set. Two feature transformers prepare the data. They index the categories of the label and of the categorical features, adding metadata to the DataFrame that the tree-based algorithms can recognize.
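What the two indexers compute can be approximated in plain Scala. This is a sketch under stated assumptions, not Spark's code. Both `IndexerSketch` functions are hypothetical:

- `indexLabels` mimics StringIndexer's default frequency ordering, where the most frequent label gets 0.0. It breaks ties alphabetically.
- `categoricalFeatures` mimics the `maxCategories` rule: a feature with at most `maxCategories` distinct values is treated as categorical. Mapping its sorted distinct values to 0, 1, 2, ... is a simplification, and Spark's exact value-to-index assignment may differ.

```scala
object IndexerSketch {
  // StringIndexer-like label indexing (sketch): the most frequent label maps to 0.0,
  // the next to 1.0, and so on. Ties are broken alphabetically (assumption of this sketch).
  def indexLabels(labels: Seq[String]): Map[String, Double] =
    labels
      .groupBy(identity)
      .toSeq
      .sortBy { case (label, occurrences) => (-occurrences.size, label) }
      .zipWithIndex
      .map { case ((label, _), i) => label -> i.toDouble }
      .toMap

  // VectorIndexer-like categorical detection (sketch): a feature column with at most
  // maxCategories distinct values is treated as categorical, and its sorted distinct
  // values are mapped to 0, 1, 2, ... Columns with more distinct values stay continuous
  // and are left out of the result.
  def categoricalFeatures(rows: Seq[Vector[Double]], maxCategories: Int): Map[Int, Map[Double, Int]] = {
    val numFeatures = rows.head.length
    (0 until numFeatures).flatMap { j =>
      val distinctValues = rows.map(row => row(j)).distinct.sortWith(_ < _)
      if (distinctValues.size <= maxCategories) Some(j -> distinctValues.zipWithIndex.toMap)
      else None
    }.toMap
  }

  def main(args: Array[String]): Unit = {
    println(indexLabels(Seq("b", "a", "b", "c", "a", "b")))
    // Hypothetical rows: feature 0 has 3 distinct values, feature 1 has 5.
    val rows = Seq(Vector(0.0, 1.5), Vector(1.0, 2.5), Vector(0.0, 3.5), Vector(2.0, 4.5), Vector(1.0, 5.5))
    println(categoricalFeatures(rows, maxCategories = 4))
  }
}
```

With `maxCategories = 4`, only feature 0 is treated as categorical in this toy data. Feature 1 has five distinct values, so it stays continuous.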
```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Assumes `spark` is an existing SparkSession (spark-shell predefines one).

// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on the whole dataset to include all labels in the index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)

// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a RandomForest model.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(10)

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${(1.0 - accuracy)}")

// The trained forest is the third pipeline stage (index 2).
val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
println(s"Learned classification forest model:\n ${rfModel.toDebugString}")

/*
Example output (results vary between runs, for example because randomSplit is called without a seed):
+--------------+-----+--------------------+
|predictedLabel|label|            features|
+--------------+-----+--------------------+
|           0.0|  0.0|(692,[95,96,97,12...|
|           0.0|  0.0|(692,[98,99,100,1...|
|           0.0|  0.0|(692,[123,124,125...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[124,125,126...|
+--------------+-----+--------------------+
only showing top 5 rows

Test Error = 0.0
Learned classification forest model:
RandomForestClassificationModel (uid=rfc_b10be44ab730) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 510 <= 2.5)
     If (feature 495 <= 111.0)
      Predict: 0.0
     Else (feature 495 > 111.0)
      Predict: 1.0
    Else (feature 510 > 2.5)
     Predict: 1.0
  Tree 1 (weight 1.0):
    If (feature 567 <= 8.0)
     If (feature 456 <= 31.5)
      If (feature 373 <= 11.5)
       Predict: 0.0
      Else (feature 373 > 11.5)
       Predict: 1.0
     Else (feature 456 > 31.5)
      Predict: 1.0
    Else (feature 567 > 8.0)
     If (feature 317 <= 9.5)
      Predict: 0.0
     Else (feature 317 > 9.5)
      If (feature 491 <= 49.5)
       Predict: 1.0
      Else (feature 491 > 49.5)
       Predict: 0.0
  Tree 2 (weight 1.0):
    If (feature 540 <= 87.0)
     If (feature 576 <= 221.5)
      Predict: 0.0
     Else (feature 576 > 221.5)
      If (feature 490 <= 15.5)
       Predict: 1.0
      Else (feature 490 > 15.5)
       Predict: 0.0
    Else (feature 540 > 87.0)
     Predict: 1.0
  Tree 3 (weight 1.0):
    If (feature 518 <= 18.0)
     If (feature 350 <= 97.5)
      Predict: 1.0
     Else (feature 350 > 97.5)
      If (feature 356 <= 16.0)
       Predict: 0.0
      Else (feature 356 > 16.0)
       Predict: 1.0
    Else (feature 518 > 18.0)
     Predict: 0.0
  Tree 4 (weight 1.0):
    If (feature 429 <= 11.5)
     If (feature 358 <= 12.0)
      Predict: 0.0
     Else (feature 358 > 12.0)
      Predict: 1.0
    Else (feature 429 > 11.5)
     Predict: 1.0
  Tree 5 (weight 1.0):
    If (feature 462 <= 62.5)
     If (feature 240 <= 253.5)
      Predict: 1.0
     Else (feature 240 > 253.5)
      Predict: 0.0
    Else (feature 462 > 62.5)
     Predict: 0.0
  Tree 6 (weight 1.0):
    If (feature 385 <= 4.0)
     If (feature 545 <= 3.0)
      If (feature 346 <= 2.0)
       Predict: 0.0
      Else (feature 346 > 2.0)
       Predict: 1.0
     Else (feature 545 > 3.0)
      Predict: 0.0
    Else (feature 385 > 4.0)
     Predict: 1.0
  Tree 7 (weight 1.0):
    If (feature 512 <= 8.0)
     If (feature 350 <= 7.0)
      If (feature 298 <= 152.5)
       Predict: 1.0
      Else (feature 298 > 152.5)
       Predict: 0.0
     Else (feature 350 > 7.0)
      Predict: 0.0
    Else (feature 512 > 8.0)
     Predict: 1.0
  Tree 8 (weight 1.0):
    If (feature 462 <= 62.5)
     If (feature 324 <= 253.5)
      Predict: 1.0
     Else (feature 324 > 253.5)
      Predict: 0.0
    Else (feature 462 > 62.5)
     Predict: 0.0
  Tree 9 (weight 1.0):
    If (feature 301 <= 30.0)
     If (feature 517 <= 20.5)
      If (feature 630 <= 5.0)
       Predict: 0.0
      Else (feature 630 > 5.0)
       Predict: 1.0
     Else (feature 517 > 20.5)
      Predict: 0.0
    Else (feature 301 > 30.0)
     Predict: 1.0
*/
```
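Because `randomSplit` assigns rows at random, "30% held out" is a target proportion rather than an exact count. The sketch below is plain Scala with a hypothetical `SplitSketch` object. It draws one uniform number per row and sends the row to the training set when the draw is below the training fraction. Spark's implementation differs in detail but likewise yields approximate sizes. Passing a seed to Spark's method, as in `data.randomSplit(Array(0.7, 0.3), seed = 1234L)`, makes the split reproducible.

```scala
import scala.util.Random

object SplitSketch {
  // Hypothetical helper: each row independently lands in the training set with
  // probability trainFraction, otherwise in the test set. A fixed seed makes it repeatable.
  def randomSplit[A](rows: Seq[A], trainFraction: Double, seed: Long): (Seq[A], Seq[A]) = {
    val rng = new Random(seed)
    rows.partition(_ => rng.nextDouble() < trainFraction)
  }

  def main(args: Array[String]): Unit = {
    val (training, test) = randomSplit(1 to 1000, trainFraction = 0.7, seed = 42L)
    // Sizes are close to, but usually not exactly, 700 and 300.
    println(s"training = ${training.size}, test = ${test.size}")
  }
}
```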
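The evaluator's `accuracy` metric is the fraction of test rows whose `prediction` equals `indexedLabel`, and the example prints `1.0 - accuracy` as the test error. `IndexToString` then maps each predicted index back to an original label through the fitted label array. Both steps are sketched below in plain Scala. The `EvalSketch` object and its sample data are hypothetical, and every row is assumed to have equal weight.

```scala
object EvalSketch {
  // Accuracy: fraction of (prediction, label) pairs that match (unweighted rows assumed).
  def accuracy(predictionAndLabel: Seq[(Double, Double)]): Double =
    predictionAndLabel.count { case (prediction, label) => prediction == label }.toDouble /
      predictionAndLabel.size

  // Test error as printed by the example: 1.0 - accuracy.
  def testError(predictionAndLabel: Seq[(Double, Double)]): Double =
    1.0 - accuracy(predictionAndLabel)

  // IndexToString-like conversion: the predicted index selects an entry of the label array.
  def toOriginalLabel(labels: Array[String], prediction: Double): String =
    labels(prediction.toInt)

  def main(args: Array[String]): Unit = {
    // Hypothetical (prediction, indexedLabel) pairs: 3 of 4 match.
    val pairs = Seq((0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (0.0, 0.0))
    println(s"accuracy = ${accuracy(pairs)}, test error = ${testError(pairs)}") // accuracy = 0.75, test error = 0.25
    // Hypothetical fitted label array: index 0.0 -> "1.0", index 1.0 -> "0.0".
    println(toOriginalLabel(Array("1.0", "0.0"), 0.0)) // prints 1.0
  }
}
```

The last call shows why the pipeline needs `IndexToString`. An indexed prediction of 0.0 need not mean an original label of 0.0; it means whichever label the indexer placed first.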
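Each tree in the `toDebugString` output reads top-down. At an `If (feature f <= t)` node, a row follows the first branch when its value for feature `f` is at most `t`; otherwise it follows the `Else` branch. The sketch below transcribes Tree 0 from the example output into a small Scala ADT (a hypothetical `TreeSketch` object, not Spark's internal node classes). It evaluates that tree on sparse vectors in which absent indices are 0.0, as in LibSVM data. It models only continuous `<=` splits, and the thresholds apply to the `indexedFeatures` column the pipeline feeds to the forest.

```scala
object TreeSketch {
  // Minimal decision-tree ADT for this sketch (hypothetical; not Spark's internal Node classes).
  sealed trait Node
  final case class Leaf(prediction: Double) extends Node
  // Continuous split: go to `left` when features(feature) <= threshold, else to `right`.
  final case class Split(feature: Int, threshold: Double, left: Node, right: Node) extends Node

  // Features are sparse: indices missing from the map are treated as 0.0, as in LibSVM data.
  def predict(node: Node, features: Map[Int, Double]): Double = node match {
    case Leaf(prediction) => prediction
    case Split(feature, threshold, left, right) =>
      if (features.getOrElse(feature, 0.0) <= threshold) predict(left, features)
      else predict(right, features)
  }

  // Tree 0 from the example output, transcribed by hand.
  val tree0: Node =
    Split(510, 2.5,
      Split(495, 111.0, Leaf(0.0), Leaf(1.0)),
      Leaf(1.0))

  def main(args: Array[String]): Unit = {
    println(predict(tree0, Map.empty))         // all-zero row: prints 0.0
    println(predict(tree0, Map(495 -> 200.0))) // feature 495 above 111.0: prints 1.0
    println(predict(tree0, Map(510 -> 10.0)))  // feature 510 above 2.5: prints 1.0
  }
}
```

The full forest evaluates all ten trees in this way and combines their outputs into a single prediction.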