Gradient-boosted trees (GBTs)
Gradient-boosted trees (GBTs) are a popular classification and regression method that uses ensembles of decision trees. More information about the spark.ml implementation can be found later in the section on GBTs.
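To make "ensembles of decision trees" concrete, here is a toy gradient-boosting sketch in plain Scala with no Spark dependency. It assumes squared-error loss, a single numeric feature, and depth-1 trees ("stumps"). spark.ml's `GBTClassifier` uses logistic loss by default and deeper trees, so read this as an illustration of the idea, not its implementation. `GbtSketch`, `Stump`, `fit`, and `predict` are hypothetical names. The first tree gets weight 1.0 and later trees are scaled by a step size, which matches the tree weights printed in the example's debug output.

```scala
import scala.collection.mutable.ArrayBuffer

// Toy gradient boosting: squared-error loss, one numeric feature, depth-1
// regression trees ("stumps"). An illustration, not spark.ml's implementation.
object GbtSketch {
  // A stump predicts `left` when x <= threshold, otherwise `right`.
  final case class Stump(threshold: Double, left: Double, right: Double) {
    def predict(x: Double): Double = if (x <= threshold) left else right
  }

  private def mean(v: Seq[Double]): Double =
    if (v.isEmpty) 0.0 else v.sum / v.size

  // Fit one stump by trying each distinct x as the split point and keeping
  // the one with the lowest sum of squared errors (xs must be non-empty).
  def fitStump(xs: Seq[Double], targets: Seq[Double]): Stump = {
    val pairs = xs.zip(targets)
    val scored = xs.distinct.sorted.map { t =>
      val (l, r) = pairs.partition(_._1 <= t)
      val stump = Stump(t, mean(l.map(_._2)), mean(r.map(_._2)))
      val sse = pairs.map { case (x, y) => math.pow(y - stump.predict(x), 2) }.sum
      (stump, sse)
    }
    scored.minBy(_._2)._1
  }

  // Each round fits a stump to the current residuals (the negative gradient of
  // squared error). The first tree gets weight 1.0, later ones `stepSize`.
  def fit(xs: Seq[Double], ys: Seq[Double], rounds: Int, stepSize: Double): Seq[(Stump, Double)] = {
    val ensemble = ArrayBuffer.empty[(Stump, Double)]
    var preds: Seq[Double] = Seq.fill(xs.size)(0.0)
    for (i <- 0 until rounds) {
      val residuals = ys.zip(preds).map { case (y, p) => y - p }
      val stump = fitStump(xs, residuals)
      val weight = if (i == 0) 1.0 else stepSize
      ensemble += ((stump, weight))
      preds = xs.zip(preds).map { case (x, p) => p + weight * stump.predict(x) }
    }
    ensemble.toList
  }

  // The ensemble's prediction is the weighted sum of all stump outputs.
  def predict(ensemble: Seq[(Stump, Double)], x: Double): Double =
    ensemble.map { case (s, w) => w * s.predict(x) }.sum
}
```

In `GBTClassifier`, the analogous settings are `setMaxIter` (the number of trees) and `setStepSize`.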
Examples
The following example loads a dataset in LibSVM format, splits it into training and test sets, trains on the training set, and then evaluates on the held-out test set. We use two feature transformers to prepare the data; these index categories for the label and for categorical features, adding metadata to the DataFrame that the tree-based algorithms can recognize.
```scala
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.Pipeline

// Assumes `spark` is an existing SparkSession, as provided by spark-shell.
// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a GBT model.
val gbt = new GBTClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(10)
  .setFeatureSubsetStrategy("auto")

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and GBT in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${1.0 - accuracy}")

// stages(2) is the fitted GBT (0: labelIndexer, 1: featureIndexer, 3: labelConverter).
val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]
println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}")

/*
Output:
+--------------+-----+--------------------+
|predictedLabel|label|            features|
+--------------+-----+--------------------+
|           0.0|  0.0|(692,[123,124,125...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[125,126,127...|
|           0.0|  0.0|(692,[126,127,128...|
+--------------+-----+--------------------+
only showing top 5 rows

Test Error = 0.0
Learned classification GBT model:
 GBTClassificationModel (uid=gbtc_ef6c4e8f1ddc) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 434 <= 88.5)
     If (feature 99 in {2.0})
      Predict: -1.0
     Else (feature 99 not in {2.0})
      Predict: 1.0
    Else (feature 434 > 88.5)
     Predict: -1.0
  Tree 1 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 549 <= 253.5)
      If (feature 400 <= 159.5)
       Predict: 0.4768116880884702
      Else (feature 400 > 159.5)
       Predict: 0.4768116880884703
     Else (feature 549 > 253.5)
      Predict: -0.4768116880884694
    Else (feature 434 > 88.5)
     If (feature 267 <= 254.5)
      Predict: -0.47681168808847024
     Else (feature 267 > 254.5)
      Predict: -0.4768116880884712
  Tree 2 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 243 <= 25.0)
      Predict: -0.4381935810427206
     Else (feature 243 > 25.0)
      If (feature 182 <= 32.0)
       Predict: 0.4381935810427206
      Else (feature 182 > 32.0)
       If (feature 154 <= 9.5)
        Predict: 0.4381935810427206
       Else (feature 154 > 9.5)
        Predict: 0.43819358104272066
    Else (feature 434 > 88.5)
     If (feature 461 <= 66.5)
      Predict: -0.4381935810427206
     Else (feature 461 > 66.5)
      Predict: -0.43819358104272066
  Tree 3 (weight 0.1):
    If (feature 462 <= 62.5)
     If (feature 549 <= 253.5)
      Predict: 0.4051496802845983
     Else (feature 549 > 253.5)
      Predict: -0.4051496802845982
    Else (feature 462 > 62.5)
     If (feature 433 <= 244.0)
      Predict: -0.4051496802845983
     Else (feature 433 > 244.0)
      Predict: -0.40514968028459836
  Tree 4 (weight 0.1):
    If (feature 462 <= 62.5)
     If (feature 100 <= 193.5)
      If (feature 235 <= 80.5)
       If (feature 183 <= 88.5)
        Predict: 0.3765841318352991
       Else (feature 183 > 88.5)
        If (feature 239 <= 9.0)
         Predict: 0.3765841318352991
        Else (feature 239 > 9.0)
         Predict: 0.37658413183529915
      Else (feature 235 > 80.5)
       Predict: 0.3765841318352994
     Else (feature 100 > 193.5)
      Predict: -0.3765841318352994
    Else (feature 462 > 62.5)
     If (feature 129 <= 58.0)
      If (feature 515 <= 88.0)
       Predict: -0.37658413183529915
      Else (feature 515 > 88.0)
       Predict: -0.3765841318352994
     Else (feature 129 > 58.0)
      Predict: -0.3765841318352994
  Tree 5 (weight 0.1):
    If (feature 462 <= 62.5)
     If (feature 293 <= 253.5)
      Predict: 0.35166478958101
     Else (feature 293 > 253.5)
      Predict: -0.3516647895810099
    Else (feature 462 > 62.5)
     If (feature 433 <= 244.0)
      Predict: -0.35166478958101005
     Else (feature 433 > 244.0)
      Predict: -0.3516647895810101
  Tree 6 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 548 <= 253.5)
      If (feature 154 <= 24.0)
       Predict: 0.32974984655529926
      Else (feature 154 > 24.0)
       Predict: 0.3297498465552994
     Else (feature 548 > 253.5)
      Predict: -0.32974984655530015
    Else (feature 434 > 88.5)
     If (feature 349 <= 2.0)
      Predict: -0.32974984655529926
     Else (feature 349 > 2.0)
      Predict: -0.3297498465552994
  Tree 7 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 568 <= 253.5)
      If (feature 658 <= 252.5)
       If (feature 631 <= 27.0)
        Predict: 0.3103372455197956
       Else (feature 631 > 27.0)
        If (feature 209 <= 62.5)
         Predict: 0.3103372455197956
        Else (feature 209 > 62.5)
         Predict: 0.3103372455197957
      Else (feature 658 > 252.5)
       Predict: 0.3103372455197958
     Else (feature 568 > 253.5)
      Predict: -0.31033724551979525
    Else (feature 434 > 88.5)
     If (feature 294 <= 31.5)
      If (feature 184 <= 110.0)
       Predict: -0.3103372455197956
      Else (feature 184 > 110.0)
       Predict: -0.3103372455197957
     Else (feature 294 > 31.5)
      If (feature 350 <= 172.5)
       Predict: -0.3103372455197956
      Else (feature 350 > 172.5)
       Predict: -0.31033724551979563
  Tree 8 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 627 <= 2.5)
      Predict: -0.2930291649125432
     Else (feature 627 > 2.5)
      Predict: 0.2930291649125433
    Else (feature 434 > 88.5)
     If (feature 379 <= 11.5)
      Predict: -0.2930291649125433
     Else (feature 379 > 11.5)
      Predict: -0.2930291649125434
  Tree 9 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 243 <= 25.0)
      Predict: -0.27750666438358235
     Else (feature 243 > 25.0)
      If (feature 244 <= 10.5)
       Predict: 0.27750666438358246
      Else (feature 244 > 10.5)
       If (feature 263 <= 237.5)
        If (feature 159 <= 10.0)
         Predict: 0.27750666438358246
        Else (feature 159 > 10.0)
         Predict: 0.2775066643835826
       Else (feature 263 > 237.5)
        Predict: 0.27750666438358257
    Else (feature 434 > 88.5)
     Predict: -0.2775066643835825
*/
```
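The `labelIndexer` and `labelConverter` stages form a round trip. StringIndexer, whose default `stringOrderType` is `frequencyDesc`, assigns index 0.0 to the most frequent label (numeric labels are cast to strings first), and IndexToString maps indices back through the `labels` array. The plain-Scala sketch below imitates that mapping; `LabelIndexSketch` and its methods are hypothetical, and breaking frequency ties alphabetically is an assumption of the sketch.

```scala
// Hypothetical sketch of StringIndexer / IndexToString behavior for numeric
// labels; not Spark's implementation.
object LabelIndexSketch {
  // Distinct labels (as strings) ordered by descending frequency, with ties
  // broken alphabetically (an assumption of this sketch). Plays the role of
  // `labelIndexer.labels`.
  def fitLabels(labels: Seq[Double]): Array[String] =
    labels.map(_.toString)
      .groupBy(identity)
      .toSeq
      .sortBy { case (label, occurrences) => (-occurrences.size, label) }
      .map(_._1)
      .toArray

  // Forward mapping (what the fitted StringIndexer does): label -> index.
  def index(labels: Array[String], value: Double): Double = {
    val i = labels.indexOf(value.toString)
    require(i >= 0, s"Unseen label: $value")
    i.toDouble
  }

  // Inverse mapping (what IndexToString does): index -> original label.
  def toLabel(labels: Array[String], index: Double): String = labels(index.toInt)
}
```

Because the GBT stage only sees indices, its `prediction` column is in index space. That is why the example evaluates against `indexedLabel` and appends IndexToString to produce the human-readable `predictedLabel`.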
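The `featureIndexer` stage decides per feature whether to treat it as categorical. With `setMaxCategories(4)`, a feature with at most 4 distinct values becomes categorical and its values are replaced by category indices, while every other feature stays continuous. The sketch below approximates that decision in plain Scala on dense rows; `CategoricalFeatureSketch` is hypothetical, and assigning indices in ascending value order is a simplification rather than a statement of Spark's exact ordering.

```scala
// Hypothetical sketch of the maxCategories decision made by VectorIndexer,
// on dense rows. Not Spark's implementation.
object CategoricalFeatureSketch {
  // Returns featureIndex -> (originalValue -> categoryIndex) for every feature
  // with at most maxCategories distinct values; other features are omitted,
  // meaning they stay continuous. Indices follow ascending value order here.
  def categoryMaps(rows: Seq[Array[Double]], maxCategories: Int): Map[Int, Map[Double, Int]] = {
    require(rows.nonEmpty, "need at least one row")
    val numFeatures = rows.head.length
    (0 until numFeatures).flatMap { j =>
      val distinctValues = rows.map(row => row(j)).distinct.sorted
      if (distinctValues.size <= maxCategories) Some(j -> distinctValues.zipWithIndex.toMap)
      else None
    }.toMap
  }
}
```

Since the trees are trained on `indexedFeatures`, splits such as `feature 99 in {2.0}` in the debug output are categorical splits, and the value in braces should be read as a category index rather than a raw pixel value.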
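`data.randomSplit(Array(0.7, 0.3))` normalizes the weights and assigns each row at random, so the test set holds roughly, not exactly, 30% of the rows. Without a seed, each run gets a different split and may report a different test error than the 0.0 shown. `Dataset.randomSplit` also has an overload that takes a seed, for example `data.randomSplit(Array(0.7, 0.3), seed = 1234L)`. The plain-Scala sketch below imitates the per-row assignment conceptually; `RandomSplitSketch` is hypothetical and ignores Spark's per-partition sampling.

```scala
import scala.util.Random

// Hypothetical sketch of a weighted random split: one uniform draw per row
// decides which split the row lands in. Not Spark's implementation.
object RandomSplitSketch {
  def randomSplit[T](rows: Seq[T], weights: Array[Double], seed: Long): Array[Seq[T]] = {
    require(weights.nonEmpty && weights.forall(_ >= 0.0) && weights.sum > 0.0, "invalid weights")
    // Cumulative upper bounds of each split in [0, 1], e.g. 0.7 and 1.0.
    val bounds = weights.scanLeft(0.0)(_ + _).tail.map(_ / weights.sum)
    val rng = new Random(seed)
    val assigned = rows.map { row =>
      val u = rng.nextDouble()
      val i = bounds.indexWhere(u < _)
      // Guard against floating-point rounding leaving the last bound below 1.0.
      (if (i < 0) bounds.length - 1 else i, row)
    }
    Array.tabulate(weights.length)(k => assigned.filter(_._1 == k).map(_._2))
  }
}
```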
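The debug output prints each tree as nested `If`/`Else` conditions and gives each tree a weight: 1.0 for the first tree and 0.1 (the default step size) for the rest. For binary classification, spark.ml sums the weighted tree outputs into a margin and, by default, predicts 1.0 when the margin is positive; with the default logistic loss, the probability of class 1 is `1 / (1 + exp(-2 * margin))`. Tree 0's leaves of -1.0 and 1.0 reflect labels 0 and 1 being rescaled to -1 and +1 for that loss. The plain-Scala sketch below encodes Tree 0 from the output above and evaluates it. The node types and `GbtPredictSketch` are hypothetical stand-ins rather than Spark's `Node` classes, and missing features default to 0.0 as in a sparse vector.

```scala
// Hypothetical stand-ins for the node types printed by toDebugString.
object GbtPredictSketch {
  sealed trait Node
  final case class Leaf(prediction: Double) extends Node
  final case class ContinuousSplit(feature: Int, threshold: Double, left: Node, right: Node) extends Node
  final case class CategoricalSplit(feature: Int, leftCategories: Set[Double], left: Node, right: Node) extends Node

  // Walk one tree: go left when the condition printed as "If (...)" holds.
  // Features absent from the map are treated as 0.0, as in a sparse vector.
  def predict(node: Node, features: Map[Int, Double]): Double = node match {
    case Leaf(p) => p
    case ContinuousSplit(f, t, l, r) =>
      if (features.getOrElse(f, 0.0) <= t) predict(l, features) else predict(r, features)
    case CategoricalSplit(f, cats, l, r) =>
      if (cats.contains(features.getOrElse(f, 0.0))) predict(l, features) else predict(r, features)
  }

  // Ensemble margin: weighted sum of the individual tree outputs.
  def margin(trees: Seq[(Node, Double)], features: Map[Int, Double]): Double =
    trees.map { case (tree, weight) => weight * predict(tree, features) }.sum

  // Class 1.0 when the margin is positive, otherwise 0.0.
  def predictClass(m: Double): Double = if (m > 0.0) 1.0 else 0.0

  // Logistic-loss conversion from margin to P(class 1).
  def probabilityOfOne(m: Double): Double = 1.0 / (1.0 + math.exp(-2.0 * m))

  // Tree 0 from the debug output above.
  val tree0: Node = ContinuousSplit(434, 88.5,
    CategoricalSplit(99, Set(2.0), Leaf(-1.0), Leaf(1.0)),
    Leaf(-1.0))
}
```

For the real model, the trees and weights are available from the fitted `GBTClassificationModel` through its `trees` and `treeWeights` members.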
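With `setMetricName("accuracy")`, the evaluator reports the fraction of rows whose `prediction` equals `indexedLabel`, and the example prints `1.0 - accuracy` as the test error. A minimal plain-Scala equivalent, assuming unweighted rows and a hypothetical `AccuracySketch` helper:

```scala
// Hypothetical helper mirroring the "accuracy" metric on unweighted rows.
object AccuracySketch {
  // Fraction of (prediction, label) pairs that match. Both values must be in
  // the same (indexed) label space, like "prediction" and "indexedLabel".
  def accuracy(pairs: Seq[(Double, Double)]): Double = {
    require(pairs.nonEmpty, "need at least one prediction")
    pairs.count { case (p, l) => p == l }.toDouble / pairs.size
  }
}
```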