Random Forest Regression
Random forests are a popular family of classification and regression methods. More information about the spark.ml implementation can be found later in the section on random forests.
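For the regression case, spark.ml's RandomForestRegressor predicts the average of its individual trees' predictions. The minimal plain-Scala sketch below (no Spark required) illustrates only that averaging step. The `Tree` type, `stump` and `forestPredict` are hypothetical helpers, and the hand-built one-split stumps stand in for the trees Spark would actually learn from random subsamples of the training data.

```scala
// Hypothetical stand-in for a trained decision tree: features => prediction.
type Tree = Array[Double] => Double

// A one-split "stump": if (feature f <= threshold) left else right,
// the same shape as the depth-1 trees printed by toDebugString.
def stump(feature: Int, threshold: Double, left: Double, right: Double): Tree =
  x => if (x(feature) <= threshold) left else right

// A regression forest's prediction is the mean of its trees' predictions
// (every tree has weight 1.0, as in the debug output).
def forestPredict(trees: Seq[Tree], x: Array[Double]): Double =
  trees.map(tree => tree(x)).sum / trees.size

// 20 stumps: 18 split on feature 0, 2 split on feature 1.
val trees: Seq[Tree] =
  Seq.fill(18)(stump(feature = 0, threshold = 88.5, left = 0.0, right = 1.0)) ++
    Seq.fill(2)(stump(feature = 1, threshold = 15.5, left = 0.0, right = 1.0))

// Feature 0 is below its threshold; feature 1 is above its threshold.
val x = Array(10.0, 50.0)
val p = forestPredict(trees, x)
println(p) // prints 0.1
```

Here only 2 of the 20 stumps predict 1.0, so the forest predicts 2/20 = 0.1. A fractional prediction such as the 0.1 in the sample output is consistent with this kind of averaging over 0/1-valued trees.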
Examples
The following example loads a dataset in LibSVM format, splits it into training and test sets, trains on the training set, and then evaluates on the held-out test set. We use a feature transformer (VectorIndexer) to index categorical features, adding metadata to the DataFrame that the tree-based algorithms can recognize.
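The rule that VectorIndexer applies through `maxCategories` can be sketched in plain Scala: a feature with at most `maxCategories` distinct values is treated as categorical and its values are mapped to 0-based indices, while any other feature is left as continuous. This is a simplified illustration, not Spark's implementation. `categoryMaps` is a hypothetical helper, and Spark's exact assignment of values to indices may differ from the plain ascending order used here.

```scala
// Simplified sketch of the maxCategories rule (not Spark's VectorIndexer code).
// Returns, for each feature judged categorical, a map from raw value to index.
def categoryMaps(rows: Seq[Array[Double]], maxCategories: Int): Map[Int, Map[Double, Int]] = {
  require(rows.nonEmpty, "need at least one row")
  val numFeatures = rows.head.length
  (0 until numFeatures).flatMap { j =>
    val distinctValues = rows.map(row => row(j)).distinct.sorted
    // At most maxCategories distinct values: categorical, index them 0..k-1.
    if (distinctValues.size <= maxCategories) Some(j -> distinctValues.zipWithIndex.toMap)
    // More distinct values: treated as continuous, so no category map.
    else None
  }.toMap
}

// Toy data: feature 0 has 2 distinct values, feature 1 has 5, feature 2 has 3.
val rows = Seq(
  Array(0.0, 1.5, 3.0),
  Array(1.0, 2.5, 3.0),
  Array(0.0, 3.5, 7.0),
  Array(1.0, 4.5, 7.0),
  Array(0.0, 5.5, 9.0)
)
val maps = categoryMaps(rows, maxCategories = 4)
println(maps) // features 0 and 2 are categorical; feature 1 is continuous
```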
```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}

// Assumes `spark` is an existing SparkSession, as provided by spark-shell.
// In spark-shell, use :paste mode to enter the chained multi-line statements.

// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")

// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a RandomForest model.
val rf = new RandomForestRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")

// Chain indexer and forest in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(featureIndexer, rf))

// Train model. This also runs the indexer.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("prediction", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) on test data = $rmse")

// The forest is the second stage (index 1) of the fitted pipeline.
val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel]
println(s"Learned regression forest model:\n ${rfModel.toDebugString}")

/*
Output:
+----------+-----+--------------------+
|prediction|label|            features|
+----------+-----+--------------------+
|       0.0|  0.0|(692,[98,99,100,1...|
|       0.1|  0.0|(692,[122,123,148...|
|       0.0|  0.0|(692,[123,124,125...|
|       0.0|  0.0|(692,[124,125,126...|
|       0.0|  0.0|(692,[124,125,126...|
+----------+-----+--------------------+
only showing top 5 rows

Root Mean Squared Error (RMSE) on test data = 0.14711511527663346
Learned regression forest model:
 RandomForestRegressionModel (uid=rfr_f01d85b28d5a) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 434 <= 88.5)
     Predict: 0.0
    Else (feature 434 > 88.5)
     Predict: 1.0
  Tree 1 (weight 1.0):
    If (feature 490 <= 15.5)
     Predict: 0.0
    Else (feature 490 > 15.5)
     Predict: 1.0
  Tree 2 (weight 1.0):
    If (feature 462 <= 62.5)
     Predict: 0.0
    Else (feature 462 > 62.5)
     Predict: 1.0
  Tree 3 (weight 1.0):
    If (feature 461 <= 71.0)
     If (feature 343 <= 253.5)
      Predict: 0.0
     Else (feature 343 > 253.5)
      Predict: 1.0
    Else (feature 461 > 71.0)
     Predict: 1.0
  Tree 4 (weight 1.0):
    If (feature 483 <= 15.5)
     If (feature 318 <= 223.0)
      Predict: 1.0
     Else (feature 318 > 223.0)
      Predict: 0.0
    Else (feature 483 > 15.5)
     Predict: 0.0
  Tree 5 (weight 1.0):
    If (feature 405 <= 106.0)
     If (feature 490 <= 15.5)
      Predict: 0.0
     Else (feature 490 > 15.5)
      Predict: 1.0
    Else (feature 405 > 106.0)
     Predict: 1.0
  Tree 6 (weight 1.0):
    If (feature 490 <= 44.5)
     Predict: 0.0
    Else (feature 490 > 44.5)
     Predict: 1.0
  Tree 7 (weight 1.0):
    If (feature 400 <= 4.5)
     If (feature 375 <= 103.0)
      Predict: 1.0
     Else (feature 375 > 103.0)
      Predict: 0.0
    Else (feature 400 > 4.5)
     Predict: 0.0
  Tree 8 (weight 1.0):
    If (feature 406 <= 126.5)
     Predict: 0.0
    Else (feature 406 > 126.5)
     Predict: 1.0
  Tree 9 (weight 1.0):
    If (feature 490 <= 44.5)
     Predict: 0.0
    Else (feature 490 > 44.5)
     Predict: 1.0
  Tree 10 (weight 1.0):
    If (feature 345 <= 6.5)
     Predict: 1.0
    Else (feature 345 > 6.5)
     Predict: 0.0
  Tree 11 (weight 1.0):
    If (feature 406 <= 126.5)
     If (feature 436 <= 1.5)
      Predict: 0.0
     Else (feature 436 > 1.5)
      Predict: 1.0
    Else (feature 406 > 126.5)
     Predict: 1.0
  Tree 12 (weight 1.0):
    If (feature 489 <= 1.5)
     Predict: 0.0
    Else (feature 489 > 1.5)
     Predict: 1.0
  Tree 13 (weight 1.0):
    If (feature 462 <= 62.5)
     Predict: 0.0
    Else (feature 462 > 62.5)
     Predict: 1.0
  Tree 14 (weight 1.0):
    If (feature 435 <= 32.5)
     If (feature 488 <= 141.0)
      Predict: 0.0
     Else (feature 488 > 141.0)
      Predict: 1.0
    Else (feature 435 > 32.5)
     Predict: 1.0
  Tree 15 (weight 1.0):
    If (feature 489 <= 1.5)
     If (feature 519 <= 146.0)
      Predict: 0.0
     Else (feature 519 > 146.0)
      Predict: 1.0
    Else (feature 489 > 1.5)
     Predict: 1.0
  Tree 16 (weight 1.0):
    If (feature 434 <= 88.5)
     Predict: 0.0
    Else (feature 434 > 88.5)
     Predict: 1.0
  Tree 17 (weight 1.0):
    If (feature 378 <= 18.0)
     Predict: 0.0
    Else (feature 378 > 18.0)
     Predict: 1.0
  Tree 18 (weight 1.0):
    If (feature 434 <= 88.5)
     Predict: 0.0
    Else (feature 434 > 88.5)
     Predict: 1.0
  Tree 19 (weight 1.0):
    If (feature 490 <= 44.5)
     Predict: 0.0
    Else (feature 490 > 44.5)
     Predict: 1.0
*/
```
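The "rmse" metric selected with `setMetricName("rmse")` in the example is the square root of the mean squared difference between prediction and label. The plain-Scala sketch below computes it by hand; `rmse` here is a hypothetical helper, not Spark's RegressionEvaluator, and it ignores row weights. It is applied to the five rows shown in the sample output.

```scala
// Root mean squared error over (prediction, label) pairs:
// sqrt( (1/n) * sum((prediction - label)^2) ).
def rmse(predictionsAndLabels: Seq[(Double, Double)]): Double = {
  require(predictionsAndLabels.nonEmpty, "need at least one (prediction, label) pair")
  val squaredErrors = predictionsAndLabels.map { case (prediction, label) =>
    val err = prediction - label
    err * err
  }
  math.sqrt(squaredErrors.sum / squaredErrors.size)
}

// The five (prediction, label) rows displayed by show(5) in the sample output.
val shownRows = Seq((0.0, 0.0), (0.1, 0.0), (0.0, 0.0), (0.0, 0.0), (0.0, 0.0))
println(rmse(shownRows)) // roughly 0.0447
```

This value is smaller than the RMSE printed in the output because the evaluator scores every row of the test set, not only the five rows displayed.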