Train-Validation Split
In addition to CrossValidator, Spark also offers TrainValidationSplit for hyper-parameter tuning. TrainValidationSplit evaluates each combination of parameters only once, as opposed to k times in the case of CrossValidator. It is therefore less expensive, but its results will be less reliable than CrossValidator's when the training dataset is not sufficiently large.
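A rough way to see the cost difference is to count how many times the Estimator is fitted. The sketch below assumes one fit per ParamMap per split, plus one final refit of the winning ParamMap. The grid size of 12 corresponds to the grid in the example code (2 regParam × 2 fitIntercept × 3 elasticNetParam values), and numFolds = 3 for CrossValidator is an illustrative choice. `TuningCost` and its functions are hypothetical names, not Spark APIs.

```scala
// Assumption: one fit per ParamMap per split, plus one final refit of the
// best ParamMap on the entire dataset.
object TuningCost {
  // TrainValidationSplit: a single (training, validation) split.
  def trainValidationSplitFits(numParamMaps: Int): Int = numParamMaps + 1

  // CrossValidator: every ParamMap is fitted once per fold.
  def crossValidatorFits(numParamMaps: Int, numFolds: Int): Int = numParamMaps * numFolds + 1
}

val numParamMaps = 2 * 2 * 3 // regParam x fitIntercept x elasticNetParam in the example grid
println(TuningCost.trainValidationSplitFits(numParamMaps))         // 13
println(TuningCost.crossValidatorFits(numParamMaps, numFolds = 3)) // 37
```

With 12 ParamMaps, TrainValidationSplit needs roughly a third of the fits that 3-fold CrossValidator does (13 versus 37), and the gap grows linearly with the number of folds.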
Unlike CrossValidator, TrainValidationSplit creates a single (training, validation) dataset pair. It splits the dataset into these two parts using the trainRatio parameter. For example, with trainRatio=0.75, TrainValidationSplit generates a training and validation dataset pair where 75% of the data is used for training and 25% for validation.
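The idea of a ratio-based random split can be illustrated in plain Scala, without Spark. In this sketch each row is independently assigned to the training side with probability trainRatio, using a seeded generator so the split is reproducible; as a consequence the proportions are approximate rather than exact. This is an illustration of the concept, not Spark's implementation, and `splitByRatio` is a hypothetical helper.

```scala
import scala.util.Random

object RatioSplit {
  // Hypothetical helper: each row goes to the training side with probability
  // trainRatio and to the validation side otherwise. A fixed seed makes the
  // split reproducible; sizes are only approximately trainRatio : (1 - trainRatio).
  def splitByRatio[A](rows: Seq[A], trainRatio: Double, seed: Long): (Seq[A], Seq[A]) = {
    require(trainRatio > 0.0 && trainRatio < 1.0, "trainRatio must be in (0, 1)")
    val rng = new Random(seed)
    val (trainPairs, validPairs) =
      rows.map(row => (row, rng.nextDouble())).partition { case (_, draw) => draw < trainRatio }
    (trainPairs.map(_._1), validPairs.map(_._1))
  }
}

val rows = 1 to 1000
val (train, validation) = RatioSplit.splitByRatio(rows, trainRatio = 0.75, seed = 12345L)
println(s"training rows: ${train.size}, validation rows: ${validation.size}")
```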
Like CrossValidator, TrainValidationSplit finally fits the Estimator using the best ParamMap and the entire dataset.
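The overall procedure (fit each candidate once on the training part, score it once on the validation part, keep the best, then refit on all the data) can be sketched in plain Scala. The "estimator" here is a hypothetical one-parameter model, a ridge-style slope through the origin with w = Σxy / (Σx² + λ), and the candidate λ values play the role of the ParamMaps. Candidates are scored by RMSE, where lower is better; RMSE is also RegressionEvaluator's default metric. None of the names below are Spark APIs, and the synthetic data (true slope 3 plus Gaussian noise) is made up for illustration.

```scala
import scala.util.Random

object TinyTrainValidationSplit {
  final case class Point(x: Double, y: Double)

  // Hypothetical estimator: ridge-style slope through the origin,
  // w = sum(x*y) / (sum(x*x) + lambda). Larger lambda shrinks w toward 0.
  def fit(data: Seq[Point], lambda: Double): Double =
    data.map(p => p.x * p.y).sum / (data.map(p => p.x * p.x).sum + lambda)

  // Root-mean-squared error of the model y = w * x on the given data.
  def rmse(data: Seq[Point], w: Double): Double =
    math.sqrt(data.map(p => math.pow(p.y - w * p.x, 2)).sum / data.size)

  // Fit each candidate once on the training part, score it once on the
  // validation part, pick the lowest RMSE, then refit on the full dataset.
  def tune(data: Seq[Point], lambdas: Seq[Double], trainRatio: Double, seed: Long): (Double, Double) = {
    val rng = new Random(seed)
    val (trainPairs, validPairs) =
      data.map(p => (p, rng.nextDouble())).partition { case (_, draw) => draw < trainRatio }
    val train = trainPairs.map(_._1)
    val validation = validPairs.map(_._1)
    require(train.nonEmpty && validation.nonEmpty, "both splits must be non-empty")
    val bestLambda = lambdas.minBy(lambda => rmse(validation, fit(train, lambda)))
    (bestLambda, fit(data, bestLambda))
  }
}

import TinyTrainValidationSplit._

val noise = new Random(42)
val data = (1 to 200).map { i =>
  val x = i / 20.0
  Point(x, 3.0 * x + noise.nextGaussian())
}
val (bestLambda, finalSlope) =
  tune(data, lambdas = Seq(0.0, 10.0, 1000.0), trainRatio = 0.8, seed = 7L)
println(s"best lambda = $bestLambda, final slope = $finalSlope")
```

The heavily regularized candidate (λ = 1000) shrinks the slope far below 3 and loses on the validation part, so one of the milder candidates wins and is refit on all 200 points.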
```scala
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}

// Prepare training and test data.
// `spark` is the SparkSession that spark-shell creates automatically.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_linear_regression_data.txt")
val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345)

val lr = new LinearRegression()
  .setMaxIter(10)

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// TrainValidationSplit will try all combinations of values and determine the best model
// using the evaluator.
val paramGrid = new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .addGrid(lr.fitIntercept)
  .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0))
  .build()

// In this case the estimator is simply the linear regression.
// A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
val trainValidationSplit = new TrainValidationSplit()
  .setEstimator(lr)
  .setEvaluator(new RegressionEvaluator)
  .setEstimatorParamMaps(paramGrid)
  // 80% of the data will be used for training and the remaining 20% for validation.
  .setTrainRatio(0.8)
  // Evaluate up to 2 parameter settings in parallel.
  .setParallelism(2)

// Run train validation split, and choose the best set of parameters.
val model = trainValidationSplit.fit(training)

// Make predictions on test data. `model` uses the best combination of parameters,
// refit on all of `training`.
model.transform(test)
  .select("features", "label", "prediction")
  .show()

/*
Output:
+--------------------+--------------------+--------------------+
|            features|               label|          prediction|
+--------------------+--------------------+--------------------+
|(10,[0,1,2,3,4,5,...|  -23.51088409032297| -1.6659388625179559|
|(10,[0,1,2,3,4,5,...| -21.432387764165806|  0.3400877302576284|
|(10,[0,1,2,3,4,5,...| -12.977848725392104|-0.02335359093652395|
|(10,[0,1,2,3,4,5,...| -11.827072996392571|  2.5642684021108417|
|(10,[0,1,2,3,4,5,...| -10.945919657782932| -0.1631314487734783|
|(10,[0,1,2,3,4,5,...|  -10.58331129986813|   2.517790654691453|
|(10,[0,1,2,3,4,5,...| -10.288657252388708| -0.9443474180536754|
|(10,[0,1,2,3,4,5,...|  -8.822357870425154|  0.6872889429113783|
|(10,[0,1,2,3,4,5,...|  -8.772667465932606|  -1.485408580416465|
|(10,[0,1,2,3,4,5,...|  -8.605713514762092|   1.110272909026478|
|(10,[0,1,2,3,4,5,...|  -6.544633229269576|  3.0454559778611285|
|(10,[0,1,2,3,4,5,...|  -5.055293333055445|  0.6441174575094268|
|(10,[0,1,2,3,4,5,...|  -5.039628433467326|  0.9572366607107066|
|(10,[0,1,2,3,4,5,...|  -4.937258492902948|  0.2292114538379546|
|(10,[0,1,2,3,4,5,...|  -3.741044592262687|   3.343205816009816|
|(10,[0,1,2,3,4,5,...|  -3.731112242951253| -2.6826413698701064|
|(10,[0,1,2,3,4,5,...|  -2.109441044710089| -2.1930034039595445|
|(10,[0,1,2,3,4,5,...| -1.8722161156986976| 0.49547270330052423|
|(10,[0,1,2,3,4,5,...| -1.1009750789589774| -0.9441633113006601|
|(10,[0,1,2,3,4,5,...|-0.48115211266405217| -0.6756196573079968|
+--------------------+--------------------+--------------------+
only showing top 20 rows
*/
```
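The parameter grid above expands to every combination of the listed values, and `addGrid(lr.fitIntercept)` with no value list contributes both true and false. As a plain-Scala sketch (no Spark needed), that Cartesian product can be enumerated with a for-comprehension. `Candidate` is a hypothetical case class standing in for Spark's ParamMap; the values are copied from the example's addGrid calls. Each of the resulting 2 × 2 × 3 = 12 candidates is fitted once on 80% of `training` and scored on the remaining 20%.

```scala
object GridSketch {
  // Hypothetical stand-in for one Spark ParamMap from the example grid.
  final case class Candidate(regParam: Double, fitIntercept: Boolean, elasticNetParam: Double)

  // Cartesian product of the values given to addGrid in the example;
  // addGrid(lr.fitIntercept) contributes both true and false.
  val grid: Seq[Candidate] = for {
    regParam <- Seq(0.1, 0.01)
    fitIntercept <- Seq(true, false)
    elasticNetParam <- Seq(0.0, 0.5, 1.0)
  } yield Candidate(regParam, fitIntercept, elasticNetParam)
}

GridSketch.grid.foreach(println)
println(s"${GridSketch.grid.size} parameter combinations") // prints "12 parameter combinations"
```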