Train-Validation Split

In addition to CrossValidator, Spark also offers TrainValidationSplit for hyper-parameter tuning. TrainValidationSplit evaluates each combination of parameters only once, whereas CrossValidator evaluates each combination k times. It is therefore less expensive, but its results are less reliable when the training dataset is not sufficiently large.
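To make the cost difference concrete, the following sketch counts model fits for both tuners. The helper names (TuningCost, crossValidatorFits, trainValidationSplitFits) are hypothetical and not part of the Spark API. The grid size is taken from the example later in this section (2 regParam values, 2 fitIntercept values, 3 elasticNetParam values), and k = 3 folds is an assumption chosen for illustration. The "+ 1" accounts for the final refit on the entire dataset with the best ParamMap.

```scala
// Hypothetical helpers for illustration only; not part of the Spark API.
object TuningCost {
  // One fit per (ParamMap, fold) pair, plus one final refit on the full dataset.
  def crossValidatorFits(gridSize: Int, numFolds: Int): Int = gridSize * numFolds + 1

  // One fit per ParamMap, plus one final refit on the full dataset.
  def trainValidationSplitFits(gridSize: Int): Int = gridSize + 1
}

// Grid from the example in this section: 2 x 2 x 3 = 12 parameter combinations.
val gridSize = 2 * 2 * 3

val cvFits = TuningCost.crossValidatorFits(gridSize, numFolds = 3)
val tvsFits = TuningCost.trainValidationSplitFits(gridSize)

println(s"CrossValidator fits: $cvFits")        // 37
println(s"TrainValidationSplit fits: $tvsFits") // 13
```

The snippet is written in REPL style, so it can be pasted into spark-shell or any Scala REPL.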

Unlike CrossValidator, TrainValidationSplit creates a single (training, validation) dataset pair. It splits the dataset into these two parts using the trainRatio parameter. For example, with trainRatio=0.75, TrainValidationSplit generates a dataset pair in which 75% of the data is used for training and 25% for validation.
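The effect of trainRatio can be approximated by hand with DataFrame.randomSplit. This is only a sketch of the idea, not a statement of how TrainValidationSplit is implemented internally. The session name, the synthetic 1000-row DataFrame, and the seed are assumptions made for illustration. Because randomSplit samples rows randomly, the resulting sizes are close to, but usually not exactly, 75% and 25%.

```scala
import org.apache.spark.sql.SparkSession

// Reuses the existing session in spark-shell; otherwise starts a local one.
val session = SparkSession.builder()
  .master("local[*]")
  .appName("TrainRatioSketch") // hypothetical application name
  .getOrCreate()

val trainRatio = 0.75

// Synthetic data for illustration: 1000 rows with a single "id" column.
val df = session.range(0, 1000).toDF("id")

// Split into a training part and a validation part according to trainRatio.
val Array(trainingPart, validationPart) =
  df.randomSplit(Array(trainRatio, 1.0 - trainRatio), seed = 42L)

println(s"training rows: ${trainingPart.count()}, validation rows: ${validationPart.count()}")
```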

Like CrossValidator, TrainValidationSplit finally fits the Estimator using the best ParamMap and the entire dataset.

import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

// Prepare training and test data.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_linear_regression_data.txt")
val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)

val lr = new LinearRegression()
  .setMaxIter(10)

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// TrainValidationSplit will try all combinations of values and determine the best model
// using the evaluator.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept)
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

// In this case the estimator is simply the linear regression.
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  // 80% of the data will be used for training and the remaining 20% for validation.
  .setTrainRatio(0.8)
  // Evaluate up to 2 parameter settings in parallel.
  .setParallelism(2)

// Run train validation split, and choose the best set of parameters.
val model = trainValidationSplit.fit(training)

// Make predictions on the test data. `model` is the model with the combination of
// parameters that performed best.
model.transform(test)
  .select("features", "label", "prediction")
  .show()
  
/*
Output:
+--------------------+--------------------+--------------------+
|            features|               label|          prediction|
+--------------------+--------------------+--------------------+
|(10,[0,1,2,3,4,5,...|  -23.51088409032297| -1.6659388625179559|
|(10,[0,1,2,3,4,5,...| -21.432387764165806|  0.3400877302576284|
|(10,[0,1,2,3,4,5,...| -12.977848725392104|-0.02335359093652395|
|(10,[0,1,2,3,4,5,...| -11.827072996392571|  2.5642684021108417|
|(10,[0,1,2,3,4,5,...| -10.945919657782932| -0.1631314487734783|
|(10,[0,1,2,3,4,5,...|  -10.58331129986813|   2.517790654691453|
|(10,[0,1,2,3,4,5,...| -10.288657252388708| -0.9443474180536754|
|(10,[0,1,2,3,4,5,...|  -8.822357870425154|  0.6872889429113783|
|(10,[0,1,2,3,4,5,...|  -8.772667465932606|  -1.485408580416465|
|(10,[0,1,2,3,4,5,...|  -8.605713514762092|   1.110272909026478|
|(10,[0,1,2,3,4,5,...|  -6.544633229269576|  3.0454559778611285|
|(10,[0,1,2,3,4,5,...|  -5.055293333055445|  0.6441174575094268|
|(10,[0,1,2,3,4,5,...|  -5.039628433467326|  0.9572366607107066|
|(10,[0,1,2,3,4,5,...|  -4.937258492902948|  0.2292114538379546|
|(10,[0,1,2,3,4,5,...|  -3.741044592262687|   3.343205816009816|
|(10,[0,1,2,3,4,5,...|  -3.731112242951253| -2.6826413698701064|
|(10,[0,1,2,3,4,5,...|  -2.109441044710089| -2.1930034039595445|
|(10,[0,1,2,3,4,5,...| -1.8722161156986976| 0.49547270330052423|
|(10,[0,1,2,3,4,5,...| -1.1009750789589774| -0.9441633113006601|
|(10,[0,1,2,3,4,5,...|-0.48115211266405217| -0.6756196573079968|
+--------------------+--------------------+--------------------+
only showing top 20 rows

*/
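To see how each parameter combination scored on the validation split, the fitted TrainValidationSplitModel exposes validationMetrics alongside its estimator ParamMaps. The sketch below assumes the `model` value from the example above is in scope; the helper name printValidationMetrics is hypothetical. With the default RegressionEvaluator the metric is RMSE, so lower values are better.

```scala
import org.apache.spark.ml.tuning.TrainValidationSplitModel

// Hypothetical helper: prints every ParamMap together with its validation metric.
def printValidationMetrics(tvsModel: TrainValidationSplitModel): Unit = {
  tvsModel.getEstimatorParamMaps
    .zip(tvsModel.validationMetrics)
    .foreach { case (params, metric) =>
      println(s"metric = $metric for\n$params\n")
    }
}

// Assumes `model` is the TrainValidationSplitModel fitted in the example above.
printValidationMetrics(model)

// The parameters of the winning model can be listed as well.
println(model.bestModel.extractParamMap())
```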
