# Gradient-boosted trees (GBTs)

Gradient-boosted trees (GBTs) are a popular classification and regression method that uses ensembles of decision trees. More information about the `spark.ml` implementation can be found in the section on GBTs.

## Examples

The following example loads a dataset in LibSVM format, splits it into training and test sets, trains on the training set, and then evaluates on the held-out test set. Two feature transformers prepare the data: they index the categories of the label and of the categorical features, adding metadata to the DataFrame that the tree-based algorithm can recognize.

```scala
import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}
import org.apache.spark.ml.Pipeline

// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a GBT model.
val gbt = new GBTClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setMaxIter(10)
  .setFeatureSubsetStrategy("auto")

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and GBT in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, gbt, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${1.0 - accuracy}")

val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel]
println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}")

/*
Output:
+--------------+-----+--------------------+
|predictedLabel|label|            features|
+--------------+-----+--------------------+
|           0.0|  0.0|(692,[123,124,125...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[125,126,127...|
|           0.0|  0.0|(692,[126,127,128...|
+--------------+-----+--------------------+
only showing top 5 rows

Test Error = 0.0
Learned classification GBT model:
 GBTClassificationModel (uid=gbtc_ef6c4e8f1ddc) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 434 <= 88.5)
     If (feature 99 in {2.0})
      Predict: -1.0
     Else (feature 99 not in {2.0})
      Predict: 1.0
    Else (feature 434 > 88.5)
     Predict: -1.0
  Tree 1 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 549 <= 253.5)
      If (feature 400 <= 159.5)
       Predict: 0.4768116880884702
      Else (feature 400 > 159.5)
       Predict: 0.4768116880884703
     Else (feature 549 > 253.5)
      Predict: -0.4768116880884694
    Else (feature 434 > 88.5)
     If (feature 267 <= 254.5)
      Predict: -0.47681168808847024
     Else (feature 267 > 254.5)
      Predict: -0.4768116880884712
  Tree 2 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 243 <= 25.0)
      Predict: -0.4381935810427206
     Else (feature 243 > 25.0)
      If (feature 182 <= 32.0)
       Predict: 0.4381935810427206
      Else (feature 182 > 32.0)
       If (feature 154 <= 9.5)
        Predict: 0.4381935810427206
       Else (feature 154 > 9.5)
        Predict: 0.43819358104272066
    Else (feature 434 > 88.5)
     If (feature 461 <= 66.5)
      Predict: -0.4381935810427206
     Else (feature 461 > 66.5)
      Predict: -0.43819358104272066
  Tree 3 (weight 0.1):
    If (feature 462 <= 62.5)
     If (feature 549 <= 253.5)
      Predict: 0.4051496802845983
     Else (feature 549 > 253.5)
      Predict: -0.4051496802845982
    Else (feature 462 > 62.5)
     If (feature 433 <= 244.0)
      Predict: -0.4051496802845983
     Else (feature 433 > 244.0)
      Predict: -0.40514968028459836
  Tree 4 (weight 0.1):
    If (feature 462 <= 62.5)
     If (feature 100 <= 193.5)
      If (feature 235 <= 80.5)
       If (feature 183 <= 88.5)
        Predict: 0.3765841318352991
       Else (feature 183 > 88.5)
        If (feature 239 <= 9.0)
         Predict: 0.3765841318352991
        Else (feature 239 > 9.0)
         Predict: 0.37658413183529915
      Else (feature 235 > 80.5)
       Predict: 0.3765841318352994
     Else (feature 100 > 193.5)
      Predict: -0.3765841318352994
    Else (feature 462 > 62.5)
     If (feature 129 <= 58.0)
      If (feature 515 <= 88.0)
       Predict: -0.37658413183529915
      Else (feature 515 > 88.0)
       Predict: -0.3765841318352994
     Else (feature 129 > 58.0)
      Predict: -0.3765841318352994
  Tree 5 (weight 0.1):
    If (feature 462 <= 62.5)
     If (feature 293 <= 253.5)
      Predict: 0.35166478958101
     Else (feature 293 > 253.5)
      Predict: -0.3516647895810099
    Else (feature 462 > 62.5)
     If (feature 433 <= 244.0)
      Predict: -0.35166478958101005
     Else (feature 433 > 244.0)
      Predict: -0.3516647895810101
  Tree 6 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 548 <= 253.5)
      If (feature 154 <= 24.0)
       Predict: 0.32974984655529926
      Else (feature 154 > 24.0)
       Predict: 0.3297498465552994
     Else (feature 548 > 253.5)
      Predict: -0.32974984655530015
    Else (feature 434 > 88.5)
     If (feature 349 <= 2.0)
      Predict: -0.32974984655529926
     Else (feature 349 > 2.0)
      Predict: -0.3297498465552994
  Tree 7 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 568 <= 253.5)
      If (feature 658 <= 252.5)
       If (feature 631 <= 27.0)
        Predict: 0.3103372455197956
       Else (feature 631 > 27.0)
        If (feature 209 <= 62.5)
         Predict: 0.3103372455197956
        Else (feature 209 > 62.5)
         Predict: 0.3103372455197957
      Else (feature 658 > 252.5)
       Predict: 0.3103372455197958
     Else (feature 568 > 253.5)
      Predict: -0.31033724551979525
    Else (feature 434 > 88.5)
     If (feature 294 <= 31.5)
      If (feature 184 <= 110.0)
       Predict: -0.3103372455197956
      Else (feature 184 > 110.0)
       Predict: -0.3103372455197957
     Else (feature 294 > 31.5)
      If (feature 350 <= 172.5)
       Predict: -0.3103372455197956
      Else (feature 350 > 172.5)
       Predict: -0.31033724551979563
  Tree 8 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 627 <= 2.5)
      Predict: -0.2930291649125432
     Else (feature 627 > 2.5)
      Predict: 0.2930291649125433
    Else (feature 434 > 88.5)
     If (feature 379 <= 11.5)
      Predict: -0.2930291649125433
     Else (feature 379 > 11.5)
      Predict: -0.2930291649125434
  Tree 9 (weight 0.1):
    If (feature 434 <= 88.5)
     If (feature 243 <= 25.0)
      Predict: -0.27750666438358235
     Else (feature 243 > 25.0)
      If (feature 244 <= 10.5)
       Predict: 0.27750666438358246
      Else (feature 244 > 10.5)
       If (feature 263 <= 237.5)
        If (feature 159 <= 10.0)
         Predict: 0.27750666438358246
        Else (feature 159 > 10.0)
         Predict: 0.2775066643835826
       Else (feature 263 > 237.5)
        Predict: 0.27750666438358257
    Else (feature 434 > 88.5)
     Predict: -0.2775066643835825


*/
```
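
The debug output above lists every tree with a weight (1.0 for the first tree, 0.1 for the rest). As a rough sketch of how such an ensemble could turn into a class label, the snippet below sums `weight * treePrediction` over all trees and maps a positive sum to class 1.0. This is our understanding of how the binary `spark.ml` classifier behaves when no thresholds are set, so treat it as an assumption rather than a description of the implementation. `WeightedTree`, `margin`, and `predictLabel` are hypothetical names for illustration, not part of the Spark API. The two toy trees only loosely mirror Tree 0 and Tree 3, with rounded leaf values and feature indices shrunk to 0 and 1.

```scala
// Hypothetical stand-in for one boosted tree: a weight plus a prediction function.
final case class WeightedTree(weight: Double, predict: Array[Double] => Double)

// Weighted sum of the individual tree outputs (the "margin").
def margin(trees: Seq[WeightedTree], features: Array[Double]): Double =
  trees.map(t => t.weight * t.predict(features)).sum

// Assumption: a positive margin maps to class 1.0, anything else to class 0.0.
def predictLabel(trees: Seq[WeightedTree], features: Array[Double]): Double =
  if (margin(trees, features) > 0.0) 1.0 else 0.0

// Two toy trees, simplified from the debug output (not the real learned trees).
val trees = Seq(
  WeightedTree(1.0, f => if (f(0) <= 88.5) 1.0 else -1.0),
  WeightedTree(0.1, f => if (f(1) <= 62.5) 0.405 else -0.405)
)

println(predictLabel(trees, Array(10.0, 10.0)))   // margin =  1.0 + 0.0405, prints 1.0
println(predictLabel(trees, Array(200.0, 10.0)))  // margin = -1.0 + 0.0405, prints 0.0
```

The snippet needs no Spark dependency and can be pasted into the Scala REPL or `spark-shell`. It also suggests why the later trees in the output carry small leaf values: with a weight of 0.1 each one only nudges the margin set by the first tree.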


