Cross-Validation
CrossValidator begins by splitting the dataset into a set of folds which are used as separate training and test datasets. E.g., with k=3 folds, CrossValidator will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. To evaluate a particular ParamMap, CrossValidator computes the average evaluation metric for the 3 Models produced by fitting the Estimator on the 3 different (training, test) dataset pairs.
After identifying the best ParamMap, CrossValidator finally re-fits the Estimator using the best ParamMap and the entire dataset.
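In other words, for each candidate ParamMap the metric is an average over the k held-out folds. A minimal sketch of that idea, assuming a generic fit function and evaluate function passed in by the caller (illustrative only; this is not CrossValidator's internal code):

// Hypothetical helper: split the data into k folds, fit on k-1 folds,
// evaluate on the held-out fold, and average the metric over the k pairs.
import org.apache.spark.sql.DataFrame

def kFoldMetric[M](data: DataFrame, k: Int,
                   fit: DataFrame => M,
                   evaluate: (M, DataFrame) => Double): Double = {
  require(k >= 2, "need at least 2 folds")
  val folds = data.randomSplit(Array.fill(k)(1.0 / k), seed = 42L)
  val metrics = folds.indices.map { i =>
    val test  = folds(i)
    val train = folds.indices.filter(_ != i).map(j => folds(j)).reduce(_ union _)
    evaluate(fit(train), test)          // metric on the held-out fold
  }
  metrics.sum / metrics.length          // average metric for this ParamMap
}

CrossValidator repeats this evaluation for every ParamMap in the grid and picks the one with the best average metric.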
Examples: model selection via cross-validation
The following example demonstrates using CrossValidator to select from a grid of parameters.
Note that cross-validation over a grid of parameters is expensive. E.g., in the example below, the parameter grid has 3 values for hashingTF.numFeatures and 2 values for lr.regParam, and CrossValidator uses 2 folds. This multiplies out to (3 × 2) × 2 = 12 different models being trained. In realistic settings, it can be common to try many more parameters and use more folds (k = 3 and k = 10 are common). In other words, using CrossValidator can be very expensive. However, it is also a well-established method for choosing parameters which is more statistically sound than heuristic hand-tuning.
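As a quick illustration of that arithmetic (the variable names here are hypothetical and separate from the example that follows):

// Total models trained = (number of ParamMaps in the grid) x (number of folds).
val valuesPerParam = Seq(3, 2)                        // e.g. 3 numFeatures values, 2 regParam values
val numFolds       = 2
val modelsTrained  = valuesPerParam.product * numFolds // (3 x 2) x 2 = 12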
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.Row

// Prepare training data from a list of (id, text, label) tuples.
val training = spark.createDataFrame(Seq(
  (0L, "a b c d e spark", 1.0),
  (1L, "b d", 0.0),
  (2L, "spark f g h", 1.0),
  (3L, "hadoop mapreduce", 0.0),
  (4L, "b spark who", 1.0),
  (5L, "g d a y", 0.0),
  (6L, "spark fly", 1.0),
  (7L, "was mapreduce", 0.0),
  (8L, "e spark program", 1.0),
  (9L, "a e c l", 0.0),
  (10L, "spark compile", 1.0),
  (11L, "hadoop software", 0.0)
)).toDF("id", "text", "label")

// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val lr = new LogisticRegression()
  .setMaxIter(10)
val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, lr))

// We use a ParamGridBuilder to construct a grid of parameters to search over.
// With 3 values for hashingTF.numFeatures and 2 values for lr.regParam,
// this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from.
val paramGrid = new ParamGridBuilder()
  .addGrid(hashingTF.numFeatures, Array(10, 100, 1000))
  .addGrid(lr.regParam, Array(0.1, 0.01))
  .build()

// We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance.
// This will allow us to jointly choose parameters for all Pipeline stages.
// A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator.
// Note that the evaluator here is a BinaryClassificationEvaluator and its default metric
// is areaUnderROC.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(new BinaryClassificationEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(2) // Use 3+ in practice
  .setParallelism(2) // Evaluate up to 2 parameter settings in parallel

// Run cross-validation, and choose the best set of parameters.
val cvModel = cv.fit(training)

// Prepare test documents, which are unlabeled (id, text) tuples.
val test = spark.createDataFrame(Seq(
  (4L, "spark i j k"),
  (5L, "l m n"),
  (6L, "mapreduce spark"),
  (7L, "apache hadoop")
)).toDF("id", "text")

// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test)
  .select("id", "text", "probability", "prediction")
  .collect()
  .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
    println(s"($id, $text) --> prob=$prob, prediction=$prediction")
  }

/*
Output:
(4, spark i j k) --> prob=[0.1256626071135742,0.8743373928864258], prediction=1.0
(5, l m n) --> prob=[0.995215441016286,0.004784558983714038], prediction=0.0
(6, mapreduce spark) --> prob=[0.30696895232626586,0.6930310476737341], prediction=1.0
(7, apache hadoop) --> prob=[0.8040279442401462,0.1959720557598538], prediction=0.0
*/
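After fitting, the returned CrossValidatorModel also exposes the cross-validated average metric for each ParamMap (avgMetrics) and the winning model (bestModel). A short sketch of inspecting them, continuing from the cvModel fitted above:

// Print the average areaUnderROC for each parameter setting in the grid,
// then look at the stages of the best PipelineModel.
import org.apache.spark.ml.PipelineModel

cvModel.avgMetrics.zip(cvModel.getEstimatorParamMaps).foreach { case (metric, params) =>
  println(s"avg metric = $metric for $params")
}
val bestPipeline = cvModel.bestModel.asInstanceOf[PipelineModel]
bestPipeline.stages.foreach(stage => println(stage))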