Gradient-boosted tree regression

Gradient-boosted trees (GBTs) are a popular regression method that uses ensembles of decision trees. More information about the spark.ml implementation can be found in the section on GBTs.

Examples

Note: For this example dataset, GBTRegressor actually only needs 1 iteration, but that will not be true in general.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor}

// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")

// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a GBT model.
val gbt = new GBTRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(10)

// Chain indexer and GBT in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(featureIndexer, gbt))

// Train model. This also runs the indexer.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("prediction", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) on test data = $rmse")

val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel]
println(s"Learned regression GBT model:\n ${gbtModel.toDebugString}")

/*
Output:
+----------+-----+--------------------+
|prediction|label|            features|
+----------+-----+--------------------+
|       0.0|  0.0|(692,[100,101,102...|
|       0.0|  0.0|(692,[121,122,123...|
|       0.0|  0.0|(692,[122,123,148...|
|       0.0|  0.0|(692,[123,124,125...|
|       0.0|  0.0|(692,[124,125,126...|
+----------+-----+--------------------+
only showing top 5 rows

Root Mean Squared Error (RMSE) on test data = 0.0
Learned regression GBT model:
 GBTRegressionModel (uid=gbtr_6fc160e3a65f) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 434 <= 88.5)
     If (feature 99 in {0.0,3.0})
      Predict: 0.0
     Else (feature 99 not in {0.0,3.0})
      Predict: 1.0
    Else (feature 434 > 88.5)
     Predict: 1.0
  Tree 1 (weight 0.1):
    Predict: 0.0
  Tree 2 (weight 0.1):
    Predict: 0.0
  Tree 3 (weight 0.1):
    Predict: 0.0
  Tree 4 (weight 0.1):
    Predict: 0.0
  Tree 5 (weight 0.1):
    Predict: 0.0
  Tree 6 (weight 0.1):
    Predict: 0.0
  Tree 7 (weight 0.1):
    Predict: 0.0
  Tree 8 (weight 0.1):
    Predict: 0.0
  Tree 9 (weight 0.1):
    Predict: 0.0


*/
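
As the note above mentions, a single iteration is enough for this toy dataset, but in practice maxIter and the other tree parameters are usually tuned. The following is a minimal sketch, not part of the original example, of how that tuning could be done with CrossValidator; it reuses the pipeline, gbt, evaluator, trainingData, and testData values defined above, and the grid values and numFolds setting are illustrative choices only.

import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Illustrative parameter grid over the number of boosting iterations and tree depth.
val paramGrid = new ParamGridBuilder()
  .addGrid(gbt.maxIter, Array(5, 10, 20))
  .addGrid(gbt.maxDepth, Array(3, 5))
  .build()

// 3-fold cross-validation; with an RMSE evaluator, CrossValidator selects the
// parameter combination that minimizes the metric.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

// Fit on the training set and report the RMSE of the best model on the test set.
val cvModel = cv.fit(trainingData)
val tunedRmse = evaluator.evaluate(cvModel.transform(testData))
println(s"RMSE on test data with tuned GBT = $tunedRmse")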
