Gradient-boosted trees (GBTs)

Gradient-boosted trees (GBTs) are a popular classification and regression method using ensembles of decision trees. More information about the spark.ml implementation can be found in the ensembles section of the Spark MLlib classification and regression guide.

Examples

The following examples load a dataset in LibSVM format, split it into training and test sets, train on the training set, and then evaluate on the held-out test set. We use two feature transformers to prepare the data; these help index categories for the label and categorical features, adding metadata to the DataFrame which the tree-based algorithms can recognize.

import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.Pipeline

// Load the data file in LibSVM format, producing a DataFrame with
// "label" (Double) and "features" (Vector) columns.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on the whole dataset (not just the training split) so every label
// value is present in the index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
// NOTE: no seed is given, so the split (and hence the exact output below)
// varies from run to run.
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Configure the GBT classifier. It consumes the indexed label/feature
// columns produced by the two indexers above.
val gbt = new GBTClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(10)
  .setFeatureSubsetStrategy("auto")

// Convert indexed labels back to original labels.
// StringIndexerModel.labels is deprecated since Spark 3.0; labelsArray(0)
// is the replacement for a single-input-column indexer.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labelsArray(0))

// Chain indexers and GBT in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))

// Train model. This also runs the indexer stages.
val model = pipeline.fit(trainingData)

// Make predictions on the held-out test set.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
// The evaluator compares against the indexed label, matching the column
// the model was trained on.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${1.0 - accuracy}")

// The GBT model is stage index 2 in the pipeline (after the two indexers).
val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]
println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}")

/*
Output:
+--------------+-----+--------------------+
|predictedLabel|label|            features|
+--------------+-----+--------------------+
|           0.0|  0.0|(692,[123,124,125...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[125,126,127...|
|           0.0|  0.0|(692,[126,127,128...|
+--------------+-----+--------------------+
only showing top 5 rows

Test Error = 0.0
Learned classification GBT model:
 GBTClassificationModel (uid=gbtc_ef6c4e8f1ddc) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 434 <= 88.5)
     If (feature 99 in {2.0})
      Predict: -1.0
     Else (feature 99 not in {2.0})
      Predict: 1.0
    Else (feature 434 > 88.5)
     Predict: -1.0
  Tree 1 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 549 <= 253.5)
      If (feature 400 <= 159.5)
       Predict: 0.4768116880884702
      Else (feature 400 > 159.5)
       Predict: 0.4768116880884703
     Else (feature 549 > 253.5)
      Predict: -0.4768116880884694
    Else (feature 434 > 88.5)
     If (feature 267 <= 254.5)
      Predict: -0.47681168808847024
     Else (feature 267 > 254.5)
      Predict: -0.4768116880884712
  Tree 2 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 243 <= 25.0)
      Predict: -0.4381935810427206
     Else (feature 243 > 25.0)
      If (feature 182 <= 32.0)
       Predict: 0.4381935810427206
      Else (feature 182 > 32.0)
       If (feature 154 <= 9.5)
        Predict: 0.4381935810427206
       Else (feature 154 > 9.5)
        Predict: 0.43819358104272066
    Else (feature 434 > 88.5)
     If (feature 461 <= 66.5)
      Predict: -0.4381935810427206
     Else (feature 461 > 66.5)
      Predict: -0.43819358104272066
  Tree 3 (weight 0.1):
    If (feature 462 <= 62.5)
     If (feature 549 <= 253.5)
      Predict: 0.4051496802845983
     Else (feature 549 > 253.5)
      Predict: -0.4051496802845982
    Else (feature 462 > 62.5)
     If (feature 433 <= 244.0)
      Predict: -0.4051496802845983
     Else (feature 433 > 244.0)
      Predict: -0.40514968028459836
  Tree 4 (weight 0.1):
    If (feature 462 <= 62.5)
     If (feature 100 <= 193.5)
      If (feature 235 <= 80.5)
       If (feature 183 <= 88.5)
        Predict: 0.3765841318352991
       Else (feature 183 > 88.5)
        If (feature 239 <= 9.0)
         Predict: 0.3765841318352991
        Else (feature 239 > 9.0)
         Predict: 0.37658413183529915
      Else (feature 235 > 80.5)
       Predict: 0.3765841318352994
     Else (feature 100 > 193.5)
      Predict: -0.3765841318352994
    Else (feature 462 > 62.5)
     If (feature 129 <= 58.0)
      If (feature 515 <= 88.0)
       Predict: -0.37658413183529915
      Else (feature 515 > 88.0)
       Predict: -0.3765841318352994
     Else (feature 129 > 58.0)
      Predict: -0.3765841318352994
  Tree 5 (weight 0.1):
    If (feature 462 <= 62.5)
     If (feature 293 <= 253.5)
      Predict: 0.35166478958101
     Else (feature 293 > 253.5)
      Predict: -0.3516647895810099
    Else (feature 462 > 62.5)
     If (feature 433 <= 244.0)
      Predict: -0.35166478958101005
     Else (feature 433 > 244.0)
      Predict: -0.3516647895810101
  Tree 6 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 548 <= 253.5)
      If (feature 154 <= 24.0)
       Predict: 0.32974984655529926
      Else (feature 154 > 24.0)
       Predict: 0.3297498465552994
     Else (feature 548 > 253.5)
      Predict: -0.32974984655530015
    Else (feature 434 > 88.5)
     If (feature 349 <= 2.0)
      Predict: -0.32974984655529926
     Else (feature 349 > 2.0)
      Predict: -0.3297498465552994
  Tree 7 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 568 <= 253.5)
      If (feature 658 <= 252.5)
       If (feature 631 <= 27.0)
        Predict: 0.3103372455197956
       Else (feature 631 > 27.0)
        If (feature 209 <= 62.5)
         Predict: 0.3103372455197956
        Else (feature 209 > 62.5)
         Predict: 0.3103372455197957
      Else (feature 658 > 252.5)
       Predict: 0.3103372455197958
     Else (feature 568 > 253.5)
      Predict: -0.31033724551979525
    Else (feature 434 > 88.5)
     If (feature 294 <= 31.5)
      If (feature 184 <= 110.0)
       Predict: -0.3103372455197956
      Else (feature 184 > 110.0)
       Predict: -0.3103372455197957
     Else (feature 294 > 31.5)
      If (feature 350 <= 172.5)
       Predict: -0.3103372455197956
      Else (feature 350 > 172.5)
       Predict: -0.31033724551979563
  Tree 8 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 627 <= 2.5)
      Predict: -0.2930291649125432
     Else (feature 627 > 2.5)
      Predict: 0.2930291649125433
    Else (feature 434 > 88.5)
     If (feature 379 <= 11.5)
      Predict: -0.2930291649125433
     Else (feature 379 > 11.5)
      Predict: -0.2930291649125434
  Tree 9 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 243 <= 25.0)
      Predict: -0.27750666438358235
     Else (feature 243 > 25.0)
      If (feature 244 <= 10.5)
       Predict: 0.27750666438358246
      Else (feature 244 > 10.5)
       If (feature 263 <= 237.5)
        If (feature 159 <= 10.0)
         Predict: 0.27750666438358246
        Else (feature 159 > 10.0)
         Predict: 0.2775066643835826
       Else (feature 263 > 237.5)
        Predict: 0.27750666438358257
    Else (feature 434 > 88.5)
     Predict: -0.2775066643835825


*/

Last updated