Gradient-boosted tree regression
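Gradient-boosted trees (GBTs) are a popular regression method that uses an ensemble of decision trees, adding one tree at a time so that each new tree corrects the errors of the ensemble built so far. More information about the spark.ml implementation can be found later, in the section on GBTs.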
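The boosting idea can be sketched in plain Scala without Spark. This is a minimal illustration under simplifying assumptions, not the spark.ml implementation: the trees are depth-1 "stumps" on a single numeric feature, every observed value is tried as a split threshold, and there is no subsampling. Each new tree is fit to the residuals of the current ensemble, which for squared loss are proportional to the negative gradient. Following the weights shown in the printed model later on this page, the first tree gets weight 1.0 and each later tree gets the learning rate. The names `GbtSketch`, `Stump`, `Ensemble`, `fitStump`, `boost`, and `sse` are hypothetical.

```scala
// A minimal gradient-boosting sketch for squared loss on one numeric feature.
// Simplifications (this is not spark.ml's code): depth-1 trees ("stumps"),
// exhaustive threshold search over observed values, no subsampling.
object GbtSketch {
  // A stump predicts `left` when x <= threshold and `right` otherwise.
  final case class Stump(threshold: Double, left: Double, right: Double) {
    def predict(x: Double): Double = if (x <= threshold) left else right
  }

  // The ensemble prediction is the weighted sum of the tree predictions.
  final case class Ensemble(trees: Vector[Stump], weights: Vector[Double]) {
    def predict(x: Double): Double =
      trees.zip(weights).map { case (t, w) => w * t.predict(x) }.sum
  }

  private def mean(v: Seq[Double]): Double = if (v.isEmpty) 0.0 else v.sum / v.size

  // Sum of squared errors of any predictor on (x, y) pairs.
  def sse(predict: Double => Double, xs: Vector[Double], ys: Vector[Double]): Double =
    xs.zip(ys).map { case (x, y) => math.pow(y - predict(x), 2) }.sum

  // Fit a stump to the targets, trying every observed x value as the threshold
  // and keeping the one with the smallest squared error (first one wins ties).
  def fitStump(xs: Vector[Double], targets: Vector[Double]): Stump = {
    val candidates = xs.distinct.map { t =>
      val (l, r) = xs.zip(targets).partition(_._1 <= t)
      Stump(t, mean(l.map(_._2)), mean(r.map(_._2)))
    }
    candidates
      .map(s => (s, sse(s.predict, xs, targets)))
      .reduce((a, b) => if (b._2 < a._2) b else a)
      ._1
  }

  // Boosting loop: each tree is fit to the residuals y - F(x) of the current
  // ensemble F. The first tree (fit to the labels themselves, since the empty
  // ensemble predicts 0.0) gets weight 1.0; later trees get the learning rate.
  def boost(xs: Vector[Double], ys: Vector[Double], numTrees: Int, learningRate: Double): Ensemble =
    (0 until numTrees).foldLeft(Ensemble(Vector.empty, Vector.empty)) { (model, i) =>
      val residuals = xs.zip(ys).map { case (x, y) => y - model.predict(x) }
      val weight = if (i == 0) 1.0 else learningRate
      Ensemble(model.trees :+ fitStump(xs, residuals), model.weights :+ weight)
    }

  def main(args: Array[String]): Unit = {
    val xs = Vector(1.0, 2.0, 3.0, 4.0)
    val ys = Vector(1.0, 1.0, 3.0, 5.0)
    for (n <- Seq(1, 2, 10)) {
      val model = boost(xs, ys, n, learningRate = 0.1)
      println(f"$n%2d trees: training SSE = ${sse(model.predict, xs, ys)}%.4f")
    }
  }
}
```

Running `main` prints the training error after 1, 2, and 10 trees. In this sketch, adding a leaf-mean tree scaled by a learning rate between 0 and 1 can only keep or lower the training error, so the error shrinks gradually rather than in one step.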
Examples
Note: For this example dataset, GBTRegressor needs only one iteration, but that will not be true in general.
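One way to see why a single iteration is enough here is to transcribe the model printed at the end of the example by hand. Tree 0 (weight 1.0) apparently fits the training labels already, so trees 1 to 9 (weight 0.1 each), fit to what is left over, are single leaves predicting 0.0. The plain-Scala sketch below uses the hypothetical names `PrintedGbtModel` and `predict`. It assumes the ensemble output is the weighted sum of the tree outputs, and that feature indices refer to the indexed feature vector, so the values 0.0 and 3.0 for feature 99 are category indices assigned by `VectorIndexer`, not raw feature values. The printed trees come from one run and may differ on yours, since the train/test split is random.

```scala
// Hand transcription of the GBT model printed by toDebugString in the example.
// Assumption: features are given as a sparse map (absent index => 0.0) over
// the *indexed* feature vector produced by VectorIndexer.
object PrintedGbtModel {
  type Features = Map[Int, Double]

  private def value(f: Features, i: Int): Double = f.getOrElse(i, 0.0)

  // Tree 0: a continuous split on feature 434, then a categorical split on feature 99.
  val tree0: Features => Double = f =>
    if (value(f, 434) <= 88.5) {
      if (value(f, 99) == 0.0 || value(f, 99) == 3.0) 0.0 else 1.0
    } else 1.0

  // Trees 1 to 9 each print as the single leaf "Predict: 0.0".
  val leafZero: Features => Double = _ => 0.0

  val trees: Vector[Features => Double] = tree0 +: Vector.fill(9)(leafZero)
  val weights: Vector[Double] = 1.0 +: Vector.fill(9)(0.1)

  // Ensemble prediction: weighted sum of the individual tree predictions.
  def predict(f: Features): Double =
    trees.zip(weights).map { case (tree, w) => w * tree(f) }.sum

  def main(args: Array[String]): Unit = {
    println(predict(Map.empty))         // prints 0.0 (434 <= 88.5, category 0.0)
    println(predict(Map(99 -> 1.0)))    // prints 1.0 (category 1.0 is not in {0.0, 3.0})
    println(predict(Map(434 -> 100.0))) // prints 1.0 (434 > 88.5)
  }
}
```

Every input gets exactly Tree 0's prediction, because each of the nine later trees adds 0.1 * 0.0.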
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}

// Run in spark-shell, where `spark` (the SparkSession) is predefined.
// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")

// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
// The indexer is fit on the full dataset so that it sees every category value.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (about 30% held out for testing).
// The split is random; pass a seed to randomSplit for reproducible results.
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a GBT model with 10 boosting iterations (one tree per iteration).
// stepSize (the learning rate) keeps its default of 0.1.
val gbt = new GBTRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(10)

// Chain indexer and GBT in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(featureIndexer, gbt))

// Train model. This also runs the indexer.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("prediction", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) on test data = $rmse")

// The fitted GBT model is the second stage of the pipeline model.
val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel]
println(s"Learned regression GBT model:\n ${gbtModel.toDebugString}")

/*
Output:
+----------+-----+--------------------+
|prediction|label|            features|
+----------+-----+--------------------+
|       0.0|  0.0|(692,[100,101,102...|
|       0.0|  0.0|(692,[121,122,123...|
|       0.0|  0.0|(692,[122,123,148...|
|       0.0|  0.0|(692,[123,124,125...|
|       0.0|  0.0|(692,[124,125,126...|
+----------+-----+--------------------+
only showing top 5 rows

Root Mean Squared Error (RMSE) on test data = 0.0
Learned regression GBT model:
 GBTRegressionModel (uid=gbtr_6fc160e3a65f) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 434 <= 88.5)
     If (feature 99 in {0.0,3.0})
      Predict: 0.0
     Else (feature 99 not in {0.0,3.0})
      Predict: 1.0
    Else (feature 434 > 88.5)
     Predict: 1.0
  Tree 1 (weight 0.1):
    Predict: 0.0
  Tree 2 (weight 0.1):
    Predict: 0.0
  Tree 3 (weight 0.1):
    Predict: 0.0
  Tree 4 (weight 0.1):
    Predict: 0.0
  Tree 5 (weight 0.1):
    Predict: 0.0
  Tree 6 (weight 0.1):
    Predict: 0.0
  Tree 7 (weight 0.1):
    Predict: 0.0
  Tree 8 (weight 0.1):
    Predict: 0.0
  Tree 9 (weight 0.1):
    Predict: 0.0
*/
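`RegressionEvaluator` with `setMetricName("rmse")` reports the root mean squared error: the square root of the mean of (prediction - label)^2 over the test rows. A plain-Scala sketch of that formula for the unweighted case shows why the output above reads 0.0: every test prediction equals its label. The names `RmseSketch` and `rmse` are hypothetical, and this is not Spark's implementation.

```scala
// Root mean squared error, the quantity RegressionEvaluator reports for
// metricName "rmse" (unweighted case). A plain-Scala sketch, not Spark's code.
object RmseSketch {
  def rmse(predictions: Seq[Double], labels: Seq[Double]): Double = {
    require(predictions.nonEmpty && predictions.size == labels.size,
      "predictions and labels must be non-empty and of equal length")
    // Mean of the squared differences, then the square root.
    val mse = predictions.zip(labels).map { case (p, y) => (p - y) * (p - y) }.sum / predictions.size
    math.sqrt(mse)
  }

  def main(args: Array[String]): Unit = {
    // All predictions match their labels, as in the sample output.
    println(rmse(Seq(0.0, 0.0, 1.0), Seq(0.0, 0.0, 1.0))) // prints 0.0
    // Errors of 1 and 3 give sqrt((1 + 9) / 2) = sqrt(5).
    println(rmse(Seq(1.0, 3.0), Seq(0.0, 0.0)))
  }
}
```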