Random forests

Random forests are a popular family of classification and regression methods. They train an ensemble of decision trees on randomized subsets of the data and features, then combine the trees' predictions. More information about the spark.ml implementation can be found in the section on random forests further in this guide.
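To build intuition before the full pipeline example, here is a minimal, hypothetical sketch of the majority-vote idea in plain Scala. The hand-written "trees" below are stand-ins for learned decision trees, not the spark.ml API:

```scala
// Illustration only: each "tree" is modeled as a function from a feature
// vector to a predicted class. A forest predicts by majority vote.
val trees: Seq[Array[Double] => Int] = Seq(
  fv => if (fv(0) > 0.5) 1 else 0,            // stand-in for tree 0
  fv => if (fv(1) > 0.2) 1 else 0,            // stand-in for tree 1
  fv => if (fv(0) + fv(1) > 0.6) 1 else 0     // stand-in for tree 2
)

def predict(features: Array[Double]): Int =
  trees.map(t => t(features))  // collect one vote per tree
    .groupBy(identity)         // tally the votes by class
    .maxBy(_._2.size)._1       // majority class wins

predict(Array(0.9, 0.1))       // trees vote 1, 0, 1 -> majority is 1
```

For regression, the combination step averages the trees' numeric predictions instead of voting. The randomization that makes each tree different (bootstrap sampling and per-split feature subsetting) is handled internally by spark.ml.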

Examples

The following example loads a dataset in LibSVM format, splits it into training and test sets, trains a model on the training set, and then evaluates it on the held-out test set. Two feature transformers, StringIndexer and VectorIndexer, prepare the data: they index the label and the categorical features, adding metadata to the DataFrame which the tree-based algorithms can recognize.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer}

// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")

// Index labels, adding metadata to the label column.
// Fit on whole dataset to include all labels in index.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(data)
// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a RandomForest model.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(10)

// Convert indexed labels back to original labels.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("predictedLabel", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${(1.0 - accuracy)}")

val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
println(s"Learned classification forest model:\n ${rfModel.toDebugString}")
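Beyond the debug string, the fitted model also exposes per-feature importance scores via featureImportances, which returns a Vector of scores normalized to sum to 1. A short sketch, continuing from the rfModel value above:

```scala
// Inspect which features drove the splits across the forest.
// featureImportances is an org.apache.spark.ml.linalg.Vector; most entries
// are zero for this sparse dataset, so filter them out before sorting.
val importances = rfModel.featureImportances.toArray
importances.zipWithIndex
  .filter { case (imp, _) => imp > 0.0 }   // keep features that were used
  .sortBy { case (imp, _) => -imp }        // most important first
  .take(5)
  .foreach { case (imp, idx) => println(f"feature $idx%4d -> $imp%.4f") }
```

These scores are derived from how much each feature reduces impurity across all splits in all trees, so they are a quick way to sanity-check what the forest learned.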

/*
Output:
+--------------+-----+--------------------+
|predictedLabel|label|            features|
+--------------+-----+--------------------+
|           0.0|  0.0|(692,[95,96,97,12...|
|           0.0|  0.0|(692,[98,99,100,1...|
|           0.0|  0.0|(692,[123,124,125...|
|           0.0|  0.0|(692,[124,125,126...|
|           0.0|  0.0|(692,[124,125,126...|
+--------------+-----+--------------------+
only showing top 5 rows

Test Error = 0.0
Learned classification forest model:
 RandomForestClassificationModel (uid=rfc_b10be44ab730) with 10 trees
  Tree 0 (weight 1.0):
    If (feature 510 <= 2.5)
     If (feature 495 <= 111.0)
      Predict: 0.0
     Else (feature 495 > 111.0)
      Predict: 1.0
    Else (feature 510 > 2.5)
     Predict: 1.0
  Tree 1 (weight 1.0):
    If (feature 567 <= 8.0)
     If (feature 456 <= 31.5)
      If (feature 373 <= 11.5)
       Predict: 0.0
      Else (feature 373 > 11.5)
       Predict: 1.0
     Else (feature 456 > 31.5)
      Predict: 1.0
    Else (feature 567 > 8.0)
     If (feature 317 <= 9.5)
      Predict: 0.0
     Else (feature 317 > 9.5)
      If (feature 491 <= 49.5)
       Predict: 1.0
      Else (feature 491 > 49.5)
       Predict: 0.0
  Tree 2 (weight 1.0):
    If (feature 540 <= 87.0)
     If (feature 576 <= 221.5)
      Predict: 0.0
     Else (feature 576 > 221.5)
      If (feature 490 <= 15.5)
       Predict: 1.0
      Else (feature 490 > 15.5)
       Predict: 0.0
    Else (feature 540 > 87.0)
     Predict: 1.0
  Tree 3 (weight 1.0):
    If (feature 518 <= 18.0)
     If (feature 350 <= 97.5)
      Predict: 1.0
     Else (feature 350 > 97.5)
      If (feature 356 <= 16.0)
       Predict: 0.0
      Else (feature 356 > 16.0)
       Predict: 1.0
    Else (feature 518 > 18.0)
     Predict: 0.0
  Tree 4 (weight 1.0):
    If (feature 429 <= 11.5)
     If (feature 358 <= 12.0)
      Predict: 0.0
     Else (feature 358 > 12.0)
      Predict: 1.0
    Else (feature 429 > 11.5)
     Predict: 1.0
  Tree 5 (weight 1.0):
    If (feature 462 <= 62.5)
     If (feature 240 <= 253.5)
      Predict: 1.0
     Else (feature 240 > 253.5)
      Predict: 0.0
    Else (feature 462 > 62.5)
     Predict: 0.0
  Tree 6 (weight 1.0):
    If (feature 385 <= 4.0)
     If (feature 545 <= 3.0)
      If (feature 346 <= 2.0)
       Predict: 0.0
      Else (feature 346 > 2.0)
       Predict: 1.0
     Else (feature 545 > 3.0)
      Predict: 0.0
    Else (feature 385 > 4.0)
     Predict: 1.0
  Tree 7 (weight 1.0):
    If (feature 512 <= 8.0)
     If (feature 350 <= 7.0)
      If (feature 298 <= 152.5)
       Predict: 1.0
      Else (feature 298 > 152.5)
       Predict: 0.0
     Else (feature 350 > 7.0)
      Predict: 0.0
    Else (feature 512 > 8.0)
     Predict: 1.0
  Tree 8 (weight 1.0):
    If (feature 462 <= 62.5)
     If (feature 324 <= 253.5)
      Predict: 1.0
     Else (feature 324 > 253.5)
      Predict: 0.0
    Else (feature 462 > 62.5)
     Predict: 0.0
  Tree 9 (weight 1.0):
    If (feature 301 <= 30.0)
     If (feature 517 <= 20.5)
      If (feature 630 <= 5.0)
       Predict: 0.0
      Else (feature 630 > 5.0)
       Predict: 1.0
     Else (feature 517 > 20.5)
      Predict: 0.0
    Else (feature 301 > 30.0)
     Predict: 1.0
*/
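The setting numTrees(10) above is illustrative; in practice, hyperparameters such as the number of trees and maximum depth are usually chosen by cross-validation. A sketch using CrossValidator and ParamGridBuilder, reusing the pipeline, rf, and evaluator values defined above (the grid values here are arbitrary examples):

```scala
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Candidate hyperparameter combinations to search over.
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.numTrees, Array(10, 50, 100))
  .addGrid(rf.maxDepth, Array(3, 5, 10))
  .build()

// 3-fold cross-validation over the whole pipeline: each candidate is
// scored with the evaluator, and the best model is refit on all folds.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3)

val cvModel = cv.fit(trainingData)
val bestPredictions = cvModel.transform(testData)
```

Note that cross-validation fits one model per parameter combination per fold (here 3 x 3 x 3 = 27 fits), so the grid should stay small on large datasets.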