Gradient-boosted tree regression
Gradient-boosted trees (GBTs) are a popular regression method that uses ensembles of decision trees. More information about the spark.ml implementation can be found in the section on GBTs.
Examples
Note: For this example dataset, GBTRegressor needs only one iteration (in the output below, trees 1 to 9 all predict 0.0), but that will not be true in general.
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}
// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)
// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
// Train a GBT model.
val gbt = new GBTRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(10)
// Chain indexer and GBT in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(featureIndexer, gbt))
// Train model. This also runs the indexer.
val model = pipeline.fit(trainingData)
// Make predictions.
val predictions = model.transform(testData)
// Select example rows to display.
predictions.select("prediction", "label", "features").show(5)
// Select (prediction, true label) and compute test error.
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) on test data = $rmse")
// Extract the fitted GBT model (stage 1 of the pipeline) and print its trees.
val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel]
println(s"Learned regression GBT model:\n ${gbtModel.toDebugString}")
/*
Output:
+----------+-----+--------------------+
|prediction|label|            features|
+----------+-----+--------------------+
|       0.0|  0.0|(692,[100,101,102...|
|       0.0|  0.0|(692,[121,122,123...|
|       0.0|  0.0|(692,[122,123,148...|
|       0.0|  0.0|(692,[123,124,125...|
|       0.0|  0.0|(692,[124,125,126...|
+----------+-----+--------------------+
only showing top 5 rows
Root Mean Squared Error (RMSE) on test data = 0.0
Learned regression GBT model:
GBTRegressionModel (uid=gbtr_6fc160e3a65f) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 434 <= 88.5)
     If (feature 99 in {0.0,3.0})
      Predict: 0.0
     Else (feature 99 not in {0.0,3.0})
      Predict: 1.0
    Else (feature 434 > 88.5)
     Predict: 1.0
  Tree 1 (weight 0.1):
    Predict: 0.0
  Tree 2 (weight 0.1):
    Predict: 0.0
  Tree 3 (weight 0.1):
    Predict: 0.0
  Tree 4 (weight 0.1):
    Predict: 0.0
  Tree 5 (weight 0.1):
    Predict: 0.0
  Tree 6 (weight 0.1):
    Predict: 0.0
  Tree 7 (weight 0.1):
    Predict: 0.0
  Tree 8 (weight 0.1):
    Predict: 0.0
  Tree 9 (weight 0.1):
    Predict: 0.0
*/
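To make the printed model easier to read, the following is a minimal plain-Scala sketch of how such an ensemble could turn its trees into a single prediction, assuming the prediction is the weighted sum of the individual tree outputs, and of how RMSE is computed from (prediction, label) pairs. It does not use Spark: WeightedTree, GbtSketch, and the Map-based feature lookup are hypothetical names introduced only for illustration, and the tree logic is transcribed by hand from the debug string above.

```scala
// Hypothetical stand-in for one tree of the printed model: a weight plus a
// function from a sparse feature lookup (index -> value) to a raw output.
final case class WeightedTree(weight: Double, predict: Map[Int, Double] => Double)

object GbtSketch {
  // Tree 0, transcribed from the debug string. Features missing from the map
  // are treated as 0.0, as in a sparse vector. Feature 99 is assumed to hold
  // the category index produced by the VectorIndexer stage.
  val tree0: Map[Int, Double] => Double = { x =>
    val f434 = x.getOrElse(434, 0.0)
    val f99 = x.getOrElse(99, 0.0)
    if (f434 <= 88.5) {
      if (Set(0.0, 3.0).contains(f99)) 0.0 else 1.0
    } else 1.0
  }

  // Trees 1 to 9 are single leaves that always predict 0.0.
  val trees: Seq[WeightedTree] =
    WeightedTree(1.0, tree0) +: Seq.fill(9)(WeightedTree(0.1, _ => 0.0))

  // Ensemble prediction: weighted sum of the individual tree outputs.
  def predict(x: Map[Int, Double]): Double =
    trees.map(t => t.weight * t.predict(x)).sum

  // Root mean squared error over (prediction, label) pairs.
  def rmse(pairs: Seq[(Double, Double)]): Double =
    math.sqrt(pairs.map { case (p, y) => (p - y) * (p - y) }.sum / pairs.size)
}
```

For example, GbtSketch.predict(Map(434 -> 200.0)) takes the "feature 434 > 88.5" branch of tree 0, and the nine remaining trees add nothing, which matches the note that one iteration is enough for this dataset.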