Random Forest Regression

Random forests are a popular family of classification and regression methods. More information about the spark.ml implementation can be found in the section on random forests.
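
The example below trains the forest with the estimator's default settings. The main tuning knobs (the number of trees, tree depth, per-tree subsampling, and the number of features considered at each split) are exposed as setters on RandomForestRegressor. A minimal sketch follows; the values shown are purely illustrative, not recommendations:

import org.apache.spark.ml.regression.RandomForestRegressor

// Illustrative settings only; sensible values depend on the dataset.
val tunedRf = new RandomForestRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(50)                  // number of trees in the ensemble
  .setMaxDepth(10)                  // maximum depth of each tree
  .setSubsamplingRate(0.8)          // fraction of the training data sampled per tree
  .setFeatureSubsetStrategy("sqrt") // number of features considered at each split
  .setSeed(42L)                     // fix the seed for reproducible results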

Examples

The following example loads a dataset in LibSVM format, splits it into training and test sets, trains on the training set, and then evaluates on the held-out test set. We use a feature transformer, VectorIndexer, to index categorical features, adding metadata to the DataFrame that the tree-based algorithms can recognize.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor}

// Load and parse the data file, converting it to a DataFrame.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")

// Automatically identify categorical features, and index them.
// Set maxCategories so features with > 4 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a RandomForest model.
val rf = new RandomForestRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")

// Chain indexer and forest in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(featureIndexer, rf))

// Train model. This also runs the indexer.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("prediction", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) on test data = $rmse")

val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel]
println(s"Learned regression forest model:\n ${rfModel.toDebugString}")

/*
Output:
+----------+-----+--------------------+
|prediction|label|            features|
+----------+-----+--------------------+
|       0.0|  0.0|(692,[98,99,100,1...|
|       0.1|  0.0|(692,[122,123,148...|
|       0.0|  0.0|(692,[123,124,125...|
|       0.0|  0.0|(692,[124,125,126...|
|       0.0|  0.0|(692,[124,125,126...|
+----------+-----+--------------------+
only showing top 5 rows

Root Mean Squared Error (RMSE) on test data = 0.14711511527663346
Learned regression forest model:
 RandomForestRegressionModel (uid=rfr_f01d85b28d5a) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 434 <= 88.5)
     Predict: 0.0
    Else (feature 434 > 88.5)
     Predict: 1.0
  Tree 1 (weight 1.0):
    If (feature 490 <= 15.5)
     Predict: 0.0
    Else (feature 490 > 15.5)
     Predict: 1.0
  Tree 2 (weight 1.0):
    If (feature 462 <= 62.5)
     Predict: 0.0
    Else (feature 462 > 62.5)
     Predict: 1.0
  Tree 3 (weight 1.0):
    If (feature 461 <= 71.0)
     If (feature 343 <= 253.5)
      Predict: 0.0
     Else (feature 343 > 253.5)
      Predict: 1.0
    Else (feature 461 > 71.0)
     Predict: 1.0
  Tree 4 (weight 1.0):
    If (feature 483 <= 15.5)
     If (feature 318 <= 223.0)
      Predict: 1.0
     Else (feature 318 > 223.0)
      Predict: 0.0
    Else (feature 483 > 15.5)
     Predict: 0.0
  Tree 5 (weight 1.0):
    If (feature 405 <= 106.0)
     If (feature 490 <= 15.5)
      Predict: 0.0
     Else (feature 490 > 15.5)
      Predict: 1.0
    Else (feature 405 > 106.0)
     Predict: 1.0
  Tree 6 (weight 1.0):
    If (feature 490 <= 44.5)
     Predict: 0.0
    Else (feature 490 > 44.5)
     Predict: 1.0
  Tree 7 (weight 1.0):
    If (feature 400 <= 4.5)
     If (feature 375 <= 103.0)
      Predict: 1.0
     Else (feature 375 > 103.0)
      Predict: 0.0
    Else (feature 400 > 4.5)
     Predict: 0.0
  Tree 8 (weight 1.0):
    If (feature 406 <= 126.5)
     Predict: 0.0
    Else (feature 406 > 126.5)
     Predict: 1.0
  Tree 9 (weight 1.0):
    If (feature 490 <= 44.5)
     Predict: 0.0
    Else (feature 490 > 44.5)
     Predict: 1.0
  Tree 10 (weight 1.0):
    If (feature 345 <= 6.5)
     Predict: 1.0
    Else (feature 345 > 6.5)
     Predict: 0.0
  Tree 11 (weight 1.0):
    If (feature 406 <= 126.5)
     If (feature 436 <= 1.5)
      Predict: 0.0
     Else (feature 436 > 1.5)
      Predict: 1.0
    Else (feature 406 > 126.5)
     Predict: 1.0
  Tree 12 (weight 1.0):
    If (feature 489 <= 1.5)
     Predict: 0.0
    Else (feature 489 > 1.5)
     Predict: 1.0
  Tree 13 (weight 1.0):
    If (feature 462 <= 62.5)
     Predict: 0.0
    Else (feature 462 > 62.5)
     Predict: 1.0
  Tree 14 (weight 1.0):
    If (feature 435 <= 32.5)
     If (feature 488 <= 141.0)
      Predict: 0.0
     Else (feature 488 > 141.0)
      Predict: 1.0
    Else (feature 435 > 32.5)
     Predict: 1.0
  Tree 15 (weight 1.0):
    If (feature 489 <= 1.5)
     If (feature 519 <= 146.0)
      Predict: 0.0
     Else (feature 519 > 146.0)
      Predict: 1.0
    Else (feature 489 > 1.5)
     Predict: 1.0
  Tree 16 (weight 1.0):
    If (feature 434 <= 88.5)
     Predict: 0.0
    Else (feature 434 > 88.5)
     Predict: 1.0
  Tree 17 (weight 1.0):
    If (feature 378 <= 18.0)
     Predict: 0.0
    Else (feature 378 > 18.0)
     Predict: 1.0
  Tree 18 (weight 1.0):
    If (feature 434 <= 88.5)
     Predict: 0.0
    Else (feature 434 > 88.5)
     Predict: 1.0
  Tree 19 (weight 1.0):
    If (feature 490 <= 44.5)
     Predict: 0.0
    Else (feature 490 > 44.5)
     Predict: 1.0
*/
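
Beyond printing the individual trees with toDebugString, the fitted RandomForestRegressionModel exposes featureImportances, a vector with one entry per feature, normalized to sum to 1. The following sketch reuses the rfModel value extracted above to list the most important features:

import org.apache.spark.ml.linalg.Vector

// List the non-zero feature importances, most important first.
val importances: Vector = rfModel.featureImportances
importances.toArray.zipWithIndex
  .filter { case (imp, _) => imp > 0.0 }
  .sortBy { case (imp, _) => -imp }
  .take(10)
  .foreach { case (imp, idx) => println(f"feature $idx%4d -> importance $imp%.4f") }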
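
The example reports RMSE, but RegressionEvaluator supports other metrics through setMetricName as well (for example "mae", "mse", and "r2"). A short sketch, using a small local helper defined here, that re-evaluates the predictions DataFrame from above:

import org.apache.spark.ml.evaluation.RegressionEvaluator

// Helper to evaluate the predictions with a given metric name.
def evalMetric(metric: String): Double = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName(metric)
  .evaluate(predictions)

println(s"MAE = ${evalMetric("mae")}, R^2 = ${evalMetric("r2")}")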
