Logistic regression is a popular method to predict a categorical response. It is a special case of generalized linear models that predicts the probability of the outcomes. In spark.ml, logistic regression can be used to predict a binary outcome by using binomial logistic regression, or a multiclass outcome by using multinomial logistic regression. Use the family parameter to select between these two algorithms, or leave it unset and Spark will infer the correct variant.
Multinomial logistic regression can be used for binary classification by setting the family param to “multinomial”. It will produce two sets of coefficients and two intercepts.
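The following is a minimal plain-Scala sketch (no Spark) of why either family can handle a binary outcome. `BinaryLogisticSketch` and its methods are hypothetical names used only for illustration. A binomial model passes the margin w·x + b through the logistic function. A two-class multinomial model applies the softmax to one margin per class. The softmax depends only on differences between margins, so for fixed parameters the two-class multinomial probabilities equal those of a binomial model whose coefficients and intercept are the row-1 minus row-0 differences.

```scala
// Hypothetical illustration (not Spark code): binomial vs. two-class
// multinomial logistic regression for a single feature vector.
object BinaryLogisticSketch {
  // Dot product of a coefficient vector and a feature vector.
  def dot(w: Array[Double], x: Array[Double]): Double =
    w.indices.map(i => w(i) * x(i)).sum

  // Logistic (sigmoid) function: maps a margin to a probability in (0, 1).
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Binomial model: one coefficient vector and one intercept give P(y = 1 | x).
  def binomialProb(w: Array[Double], b: Double, x: Array[Double]): Double =
    sigmoid(dot(w, x) + b)

  // Two-class multinomial model: one coefficient row and one intercept per
  // class; returns P(y = 1 | x) via the softmax over the two margins.
  def multinomialProb1(w0: Array[Double], b0: Double,
                       w1: Array[Double], b1: Double,
                       x: Array[Double]): Double = {
    val m0 = dot(w0, x) + b0
    val m1 = dot(w1, x) + b1
    val mx = math.max(m0, m1) // subtract the max margin for numerical stability
    val e0 = math.exp(m0 - mx)
    val e1 = math.exp(m1 - mx)
    e1 / (e0 + e1)
  }
}
```

Here is a worked case. Take rows (0.5, -1.0) with intercept 0.2 for class 0, and (-0.5, 1.0) with intercept -0.2 for class 1. `multinomialProb1` then agrees with `binomialProb` called with coefficients (-1.0, 2.0) and intercept -0.4.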
When fitting a LogisticRegressionModel without an intercept on a dataset with a constant nonzero column, Spark MLlib outputs zero coefficients for the constant nonzero columns. This behavior is the same as R glmnet but different from LIBSVM.
Binomial logistic regression
For more background and more details about the implementation of binomial logistic regression, refer to the documentation of logistic regression in spark.mllib.
Examples
The following example shows how to train binomial and multinomial logistic regression models for binary classification with elastic net regularization. elasticNetParam corresponds to α and regParam corresponds to λ.
import org.apache.spark.ml.classification.LogisticRegression
// Load training data
val training = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)
// Fit the model
val lrModel = lr.fit(training)
// Print the coefficients and intercept for logistic regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")
// We can also use the multinomial family for binary classification
val mlr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)
  .setFamily("multinomial")
val mlrModel = mlr.fit(training)
// Print the coefficients and intercepts for logistic regression with multinomial family
println(s"Multinomial coefficients: ${mlrModel.coefficientMatrix}")
println(s"Multinomial intercepts: ${mlrModel.interceptVector}")
/*
Output:
Coefficients: (692,[244,263,272,300,301,328,350,351,378,379,405,406,407,428,433,434,455,456,461,462,483,484,489,490,496,511,512,517,539,540,568],[-7.353983524188197E-5,-9.102738505589466E-5,-1.9467430546904298E-4,-2.0300642473486668E-4,-3.1476183314863995E-5,-6.842977602660743E-5,1.5883626898239883E-5,1.4023497091372047E-5,3.5432047524968605E-4,1.1443272898171087E-4,1.0016712383666666E-4,6.014109303795481E-4,2.840248179122762E-4,-1.1541084736508837E-4,3.85996886312906E-4,6.35019557424107E-4,-1.1506412384575676E-4,-1.5271865864986808E-4,2.804933808994214E-4,6.070117471191634E-4,-2.008459663247437E-4,-1.421075579290126E-4,2.739010341160883E-4,2.7730456244968115E-4,-9.838027027269332E-5,-3.808522443517704E-4,-2.5315198008555033E-4,2.7747714770754307E-4,-2.443619763919199E-4,-0.0015394744687597765,-2.3073328411331293E-4]) Intercept: 0.22456315961250325
Multinomial coefficients: 2 x 692 CSCMatrix
(0,244) 4.290365458958277E-5
(1,244) -4.290365458958294E-5
(0,263) 6.488313287833108E-5
(1,263) -6.488313287833092E-5
(0,272) 1.2140666790834663E-4
(1,272) -1.2140666790834657E-4
(0,300) 1.3231861518665612E-4
(1,300) -1.3231861518665607E-4
(0,350) -6.775444746760509E-7
(1,350) 6.775444746761932E-7
(0,351) -4.899237909429297E-7
(1,351) 4.899237909430322E-7
(0,378) -3.5812102770679596E-5
(1,378) 3.581210277067968E-5
(0,379) -2.3539704331222065E-5
(1,379) 2.353970433122204E-5
(0,405) -1.90295199030314E-5
(1,405) 1.90295199030314E-5
(0,406) -5.626696935778909E-4
(1,406) 5.626696935778912E-4
(0,407) -5.121519619099504E-5
(1,407) 5.1215196190995074E-5
(0,428) 8.080614545413342E-5
(1,428) -8.080614545413331E-5
(0,433) -4.256734915330487E-5
(1,433) 4.256734915330495E-5
(0,434) -7.080191510151425E-4
(1,434) 7.080191510151435E-4
(0,455) 8.094482475733589E-5
(1,455) -8.094482475733582E-5
(0,456) 1.0433687128309833E-4
(1,456) -1.0433687128309814E-4
(0,461) -5.4466605046259246E-5
(1,461) 5.4466605046259286E-5
(0,462) -5.667133061990392E-4
(1,462) 5.667133061990392E-4
(0,483) 1.2495896045528374E-4
(1,483) -1.249589604552838E-4
(0,484) 9.810519424784944E-5
(1,484) -9.810519424784941E-5
(0,489) -4.88440907254626E-5
(1,489) 4.8844090725462606E-5
(0,490) -4.324392733454803E-5
(1,490) 4.324392733454811E-5
(0,496) 6.903351855620161E-5
(1,496) -6.90335185562012E-5
(0,511) 3.946505594172827E-4
(1,511) -3.946505594172831E-4
(0,512) 2.621745995919226E-4
(1,512) -2.621745995919226E-4
(0,517) -4.459475951170906E-5
(1,517) 4.459475951170901E-5
(0,539) 2.5417562428184555E-4
(1,539) -2.5417562428184555E-4
(0,540) 5.271781246228031E-4
(1,540) -5.271781246228032E-4
(0,568) 1.860255150352447E-4
(1,568) -1.8602551503524485E-4
Multinomial intercepts: [-0.12065879445860686,0.12065879445860686]
*/
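The regParam (λ) and elasticNetParam (α) settings above blend L1 and L2 regularization. The sketch below assumes the usual elastic-net form λ(α‖w‖₁ + (1 − α)/2 · ‖w‖₂²) used in Spark's linear-methods guide. `ElasticNetSketch.penalty` is a hypothetical helper for illustration, not a Spark API. It takes only a coefficient vector and ignores any intercept.

```scala
// Hypothetical helper (not a Spark API): the elastic-net penalty term
// lambda * (alpha * ||w||_1 + (1 - alpha) / 2 * ||w||_2^2).
object ElasticNetSketch {
  def penalty(w: Array[Double], regParam: Double, elasticNetParam: Double): Double = {
    require(regParam >= 0.0, "regParam (lambda) must be >= 0")
    require(elasticNetParam >= 0.0 && elasticNetParam <= 1.0,
      "elasticNetParam (alpha) must be in [0, 1]")
    val l1 = w.map(v => math.abs(v)).sum // ||w||_1
    val l2sq = w.map(v => v * v).sum     // ||w||_2^2
    regParam * (elasticNetParam * l1 + (1.0 - elasticNetParam) / 2.0 * l2sq)
  }
}
```

Take w = (1, -2) with λ = 0.3 and α = 0.8. The penalty is 0.3 × (0.8 × 3 + 0.1 × 5) = 0.87. Setting α = 1 gives the pure lasso term 0.9, and α = 0 gives the pure ridge term 0.75.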
The spark.ml implementation of logistic regression also supports extracting a summary of the model over the training set. Note that the predictions and metrics, which are stored as DataFrames in LogisticRegressionSummary, are annotated @transient and hence only available on the driver.
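// Continue from the code above
import org.apache.spark.sql.functions.max
import spark.implicits._ // for the $"..." column syntax (pre-imported in spark-shell)
// Extract the summary from the LogisticRegressionModel instance trained in the earlier example
val trainingSummary = lrModel.binarySummary
// Obtain the objective per iteration.
val objectiveHistory = trainingSummary.objectiveHistory
println("objectiveHistory:")
objectiveHistory.foreach(loss => println(loss))
// Obtain the receiver-operating characteristic as a DataFrame and areaUnderROC.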
val roc = trainingSummary.roc
roc.show()
println(s"areaUnderROC: ${trainingSummary.areaUnderROC}")
// Set the model threshold to maximize F-Measure
val fMeasure = trainingSummary.fMeasureByThreshold
val maxFMeasure = fMeasure.select(max("F-Measure")).head().getDouble(0)
val bestThreshold = fMeasure.where($"F-Measure" === maxFMeasure)
  .select("threshold").head().getDouble(0)
lrModel.setThreshold(bestThreshold)
/*
Output:
+---+--------------------+
|FPR| TPR|
+---+--------------------+
|0.0| 0.0|
|0.0|0.017543859649122806|
|0.0| 0.03508771929824561|
|0.0| 0.05263157894736842|
|0.0| 0.07017543859649122|
|0.0| 0.08771929824561403|
|0.0| 0.10526315789473684|
|0.0| 0.12280701754385964|
|0.0| 0.14035087719298245|
|0.0| 0.15789473684210525|
|0.0| 0.17543859649122806|
|0.0| 0.19298245614035087|
|0.0| 0.21052631578947367|
|0.0| 0.22807017543859648|
|0.0| 0.24561403508771928|
|0.0| 0.2631578947368421|
|0.0| 0.2807017543859649|
|0.0| 0.2982456140350877|
|0.0| 0.3157894736842105|
|0.0| 0.3333333333333333|
+---+--------------------+
only showing top 20 rows
areaUnderROC: 1.0
*/
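In Spark, the summary's roc, areaUnderROC and fMeasureByThreshold are computed over distributed data. The hypothetical `BinaryMetricsSketch` below is a rough local analogue, not Spark's implementation, and Spark may bin thresholds differently on large data. It builds ROC points by sweeping the threshold over the distinct scores, integrates them with the trapezoidal rule, and picks the threshold that maximizes F1. That last step mirrors the threshold selection in the code above.

```scala
// Hypothetical helpers (not Spark's BinaryClassificationMetrics): an
// in-memory ROC curve, its area, and F1-based threshold selection.
object BinaryMetricsSketch {
  // (score, label) pairs, where label is 1.0 (positive) or 0.0 (negative).
  type Scored = Seq[(Double, Double)]

  // True and false positives when predicting positive for score >= t.
  private def counts(scored: Scored, t: Double): (Int, Int) = {
    val tp = scored.count { case (s, l) => s >= t && l == 1.0 }
    val fp = scored.count { case (s, l) => s >= t && l == 0.0 }
    (tp, fp)
  }

  // ROC points (FPR, TPR): sweep t over distinct scores from high to low,
  // with (0, 0) prepended; the lowest score yields (1, 1).
  def rocPoints(scored: Scored): Seq[(Double, Double)] = {
    val pos = scored.count(_._2 == 1.0).toDouble
    val neg = scored.size - pos
    require(pos > 0 && neg > 0, "need both positive and negative examples")
    val thresholds = scored.map(_._1).distinct.sortWith(_ > _)
    (0.0, 0.0) +: thresholds.map { t =>
      val (tp, fp) = counts(scored, t)
      (fp / neg, tp / pos)
    }
  }

  // Area under the ROC curve by the trapezoidal rule.
  def areaUnderROC(points: Seq[(Double, Double)]): Double =
    points.zip(points.tail).map { case ((x0, y0), (x1, y1)) =>
      (x1 - x0) * (y0 + y1) / 2.0
    }.sum

  // (threshold, F1) with the highest F1; the first one wins on ties.
  def bestThresholdByF1(scored: Scored): (Double, Double) = {
    val pos = scored.count(_._2 == 1.0)
    scored.map(_._1).distinct.map { t =>
      val (tp, fp) = counts(scored, t)
      val precision = if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
      val recall = if (pos == 0) 0.0 else tp.toDouble / pos
      val f1 =
        if (precision + recall == 0.0) 0.0
        else 2 * precision * recall / (precision + recall)
      (t, f1)
    }.reduceLeft((a, b) => if (b._2 > a._2) b else a)
  }
}
```

Consider the toy (score, label) pairs (0.9, 1), (0.8, 1), (0.7, 0), (0.6, 1), (0.2, 0). The sketch gives an AUC of 5/6, because five of the six positive/negative pairs are ranked correctly. It selects threshold 0.6, with F1 = 6/7.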
Multinomial logistic regression
Multiclass classification is supported via multinomial logistic (softmax) regression. In multinomial logistic regression, the algorithm produces K sets of coefficients, or a matrix of dimension K×J, where K is the number of outcome classes and J is the number of features. If the algorithm is fit with an intercept term, then a length-K vector of intercepts is available.
Multinomial coefficients are available as coefficientMatrix and intercepts are available as interceptVector.
The coefficients and intercept methods on a logistic regression model trained with the multinomial family are not supported. Use coefficientMatrix and interceptVector instead.
The conditional probabilities of the outcome classes k ∈ {1, 2, …, K} are modeled using the softmax function.
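Written out, the softmax model is P(Y = k | x) = exp(β_k · x + β_0k) / Σ_k' exp(β_k' · x + β_0k'). The minimal plain-Scala sketch below shows how a K × J coefficient matrix (as in coefficientMatrix) and a length-K intercept vector (as in interceptVector) would score one feature vector. `SoftmaxSketch` and its methods are hypothetical. They work on plain arrays rather than Spark's Matrix and Vector types.

```scala
// Hypothetical sketch (not Spark code): scoring one feature vector with a
// K x J coefficient matrix and a length-K intercept vector.
object SoftmaxSketch {
  // margin_k = beta_k . x + beta_0k for every class k
  def margins(coef: Array[Array[Double]], intercept: Array[Double],
              x: Array[Double]): Array[Double] =
    coef.indices.map { k =>
      coef(k).indices.map(j => coef(k)(j) * x(j)).sum + intercept(k)
    }.toArray

  // P(Y = k | x) = exp(margin_k) / sum_k' exp(margin_k'); the max margin is
  // subtracted first so exp cannot overflow (the ratios are unchanged).
  def softmax(m: Array[Double]): Array[Double] = {
    val mx = m.reduce((a, b) => math.max(a, b))
    val e = m.map(v => math.exp(v - mx))
    val total = e.sum
    e.map(_ / total)
  }

  // Predicted class: index of the largest probability (first one on ties).
  def predict(coef: Array[Array[Double]], intercept: Array[Double],
              x: Array[Double]): Int = {
    val p = softmax(margins(coef, intercept, x))
    p.indices.reduce((a, b) => if (p(b) > p(a)) b else a)
  }
}
```

Subtracting the maximum margin before exponentiating is a standard stability trick. It does not change the resulting probabilities.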
We minimize the weighted negative log-likelihood, using a multinomial response model, with elastic-net penalty to control for overfitting.
For a detailed derivation please see here.
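Below is a sketch of the objective described above. It rests on three assumptions. First, the weighted negative log-likelihood is divided by the total instance weight, making it a weighted mean. Second, the elastic-net penalty λ((1 − α)/2 · ‖β‖₂² + α‖β‖₁) is summed over every entry of the coefficient matrix. Third, the intercepts are not penalized. `MultinomialObjectiveSketch` is a hypothetical helper, not Spark's optimizer code.

```scala
// Hypothetical sketch (not Spark's optimizer): weighted mean negative
// log-likelihood of the softmax model plus an elastic-net penalty on the
// coefficient matrix; intercepts are not penalized here.
object MultinomialObjectiveSketch {
  // log(sum_k exp(m_k)), computed stably by factoring out the max margin.
  def logSumExp(m: Array[Double]): Double = {
    val mx = m.reduce((a, b) => math.max(a, b))
    mx + math.log(m.map(v => math.exp(v - mx)).sum)
  }

  // data: (features, label in 0 until K, instance weight)
  def objective(coef: Array[Array[Double]], intercept: Array[Double],
                data: Seq[(Array[Double], Int, Double)],
                regParam: Double, elasticNetParam: Double): Double = {
    val totalWeight = data.map(_._3).sum
    val weightedNll = data.map { case (x, y, w) =>
      val m = coef.indices.map { k =>
        coef(k).indices.map(j => coef(k)(j) * x(j)).sum + intercept(k)
      }.toArray
      // -log P(Y = y | x) = logSumExp(m) - m(y)
      w * (logSumExp(m) - m(y))
    }.sum
    val flat = coef.flatten
    val l1 = flat.map(v => math.abs(v)).sum
    val l2sq = flat.map(v => v * v).sum
    weightedNll / totalWeight +
      regParam * ((1.0 - elasticNetParam) / 2.0 * l2sq + elasticNetParam * l1)
  }
}
```

With all coefficients and intercepts at zero, every one of three classes gets probability 1/3 and the objective is ln 3 ≈ 1.0986. This is close to the first objectiveHistory value printed in the multiclass example output.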
Examples
The following example shows how to train a multiclass logistic regression model with elastic net regularization, as well as extract the multiclass training summary for evaluating the model.
import org.apache.spark.ml.classification.LogisticRegression
// Load training data
val training = spark
  .read
  .format("libsvm")
  .load("file:///opt/spark/data/mllib/sample_multiclass_classification_data.txt")
val lr = new LogisticRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)
// Fit the model
val lrModel = lr.fit(training)
// Print the coefficients and intercept for multinomial logistic regression
println(s"Coefficients: \n${lrModel.coefficientMatrix}")
println(s"Intercepts: \n${lrModel.interceptVector}")
val trainingSummary = lrModel.summary
// Obtain the objective per iteration
val objectiveHistory = trainingSummary.objectiveHistory
println("objectiveHistory:")
objectiveHistory.foreach(println)
// for multiclass, we can inspect metrics on a per-label basis
println("False positive rate by label:")
trainingSummary.falsePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
  println(s"label $label: $rate")
}
println("True positive rate by label:")
trainingSummary.truePositiveRateByLabel.zipWithIndex.foreach { case (rate, label) =>
  println(s"label $label: $rate")
}
println("Precision by label:")
trainingSummary.precisionByLabel.zipWithIndex.foreach { case (prec, label) =>
  println(s"label $label: $prec")
}
println("Recall by label:")
trainingSummary.recallByLabel.zipWithIndex.foreach { case (rec, label) =>
  println(s"label $label: $rec")
}
println("F-measure by label:")
trainingSummary.fMeasureByLabel.zipWithIndex.foreach { case (f, label) =>
  println(s"label $label: $f")
}
val accuracy = trainingSummary.accuracy
val falsePositiveRate = trainingSummary.weightedFalsePositiveRate
val truePositiveRate = trainingSummary.weightedTruePositiveRate
val fMeasure = trainingSummary.weightedFMeasure
val precision = trainingSummary.weightedPrecision
val recall = trainingSummary.weightedRecall
println(s"Accuracy: $accuracy\nFPR: $falsePositiveRate\nTPR: $truePositiveRate\n" +
  s"F-measure: $fMeasure\nPrecision: $precision\nRecall: $recall")
/*
Output:
Coefficients:
3 x 4 CSCMatrix
(1,2) -0.7803943459681859
(0,3) 0.3176483191238039
(1,3) -0.3769611423403096
Intercepts:
[0.05165231659832854,-0.12391224990853622,0.07225993331020768]
objectiveHistory:
1.098612288668108
1.087602085441699
1.0341156572156232
1.0289859520256006
1.0300389657358995
1.0239965158223991
1.0236097451839508
1.0231082121970012
1.023022220302788
1.0230018151780262
1.0229963739557606
False positive rate by label:
label 0: 0.22
label 1: 0.05
label 2: 0.0
True positive rate by label:
label 0: 1.0
label 1: 1.0
label 2: 0.46
Precision by label:
label 0: 0.6944444444444444
label 1: 0.9090909090909091
label 2: 1.0
Recall by label:
label 0: 1.0
label 1: 1.0
label 2: 0.46
F-measure by label:
label 0: 0.819672131147541
label 1: 0.9523809523809523
label 2: 0.6301369863013699
Accuracy: 0.82
FPR: 0.09
TPR: 0.82
F-measure: 0.8007300232766211
Precision: 0.8678451178451179
Recall: 0.82
*/
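The per-label and weighted figures above can be reproduced from a confusion matrix. The sketch below assumes that each weighted metric averages the per-label values, weighting each label by its share of the true labels. That assumption is consistent with the printed numbers. `MulticlassMetricsSketch` is hypothetical and is not Spark's MulticlassMetrics.

```scala
// Hypothetical sketch (not Spark's MulticlassMetrics): per-label and
// frequency-weighted metrics from a confusion matrix.
object MulticlassMetricsSketch {
  // cm(i)(j) = number of instances with true label i that were predicted as j.
  final case class LabelMetrics(tpr: Double, fpr: Double, precision: Double, f1: Double)

  def perLabel(cm: Array[Array[Int]]): IndexedSeq[LabelMetrics] = {
    val total = cm.map(_.sum).sum.toDouble
    cm.indices.map { k =>
      val tp = cm(k)(k).toDouble
      val actual = cm(k).sum.toDouble                    // TP + FN
      val predicted = cm.map(row => row(k)).sum.toDouble // TP + FP
      val negatives = total - actual                     // FP + TN
      val tpr = if (actual == 0.0) 0.0 else tp / actual  // recall
      val fpr = if (negatives == 0.0) 0.0 else (predicted - tp) / negatives
      val precision = if (predicted == 0.0) 0.0 else tp / predicted
      val f1 =
        if (precision + tpr == 0.0) 0.0
        else 2 * precision * tpr / (precision + tpr)
      LabelMetrics(tpr, fpr, precision, f1)
    }
  }

  // Average of one metric, weighting each label by its share of true labels.
  def weighted(cm: Array[Array[Int]], metric: LabelMetrics => Double): Double = {
    val total = cm.map(_.sum).sum.toDouble
    val metrics = perLabel(cm)
    cm.indices.map(k => cm(k).sum / total * metric(metrics(k))).sum
  }

  // Fraction of instances on the diagonal (correct predictions).
  def accuracy(cm: Array[Array[Int]]): Double =
    cm.indices.map(k => cm(k)(k)).sum.toDouble / cm.map(_.sum).sum
}
```

As an illustration, take the matrix with rows (true labels) [50, 0, 0], [0, 50, 0], [22, 5, 23]. It is not read from Spark. It is reconstructed under the assumption that each class has 50 training instances. It yields an FPR of 0.22 for label 0, a weighted FPR of 0.09 and an accuracy of 0.82, matching the output above. Under this weighting, weighted recall always equals accuracy.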