// OneVsRest
import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest}
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
// One-vs-Rest example: fits a OneVsRest meta-classifier backed by
// LogisticRegression and prints the error rate on a held-out test split.
//
// Every statement is deliberately kept on a single line. When the snippet is
// pasted line by line into spark-shell, a line that starts with `.` is applied
// to the previous result as a separate expression instead of continuing the
// `val` definition above it, which would leave the `val` half-configured.

// Load the LIBSVM-formatted data file.
// `spark` is expected to be an existing SparkSession (e.g. the one spark-shell provides).
val dataPath = "file:///opt/spark/data/mllib/sample_libsvm_data.txt"
val inputData = spark.read.format("libsvm").load(dataPath)

// Generate the 80% / 20% train/test split.
// No seed is passed, so the rows in each part can differ between runs.
val Array(train, test) = inputData.randomSplit(Array(0.8, 0.2))

// Instantiate the base classifier: at most 10 iterations, convergence
// tolerance 1E-6, with an intercept term.
val classifier = new LogisticRegression().setMaxIter(10).setTol(1E-6).setFitIntercept(true)

// Instantiate the One-vs-Rest classifier around the base classifier.
val ovr = new OneVsRest().setClassifier(classifier)

// Train the multiclass model on the training split.
val ovrModel = ovr.fit(train)

// Score the model on the test split.
val predictions = ovrModel.transform(test)

// Obtain an evaluator that reports accuracy.
val evaluator = new MulticlassClassificationEvaluator().setMetricName("accuracy")

// Compute accuracy on the test predictions; the classification error is its complement.
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${1 - accuracy}")
/*
Sample output (the train/test split is unseeded, so the value may vary between runs):
Test Error = 0.0
*/