Linear Regression
In statistics, linear regression is a linear approach to modeling the relationship between a scalar response (or dependent variable) and one or more explanatory variables (or independent variables). The case of one explanatory variable is called simple linear regression.
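For simple linear regression, the least-squares line has a closed form. The slope is the sum of (x - x̄)(y - ȳ) divided by the sum of (x - x̄)², and the intercept is ȳ minus the slope times x̄.

The sketch below implements that formula in plain Scala with no Spark dependency. The object name SimpleLinearRegression and its fit method are hypothetical, chosen for illustration; they are not part of Spark's API.

```scala
// Simple linear regression y ≈ b0 + b1 * x, fitted by ordinary least squares.
// Closed form: b1 = Σ(x - x̄)(y - ȳ) / Σ(x - x̄)², b0 = ȳ - b1 * x̄.
// SimpleLinearRegression is a hypothetical helper, not a Spark class.
object SimpleLinearRegression {
  /** Returns (intercept, slope) minimizing the sum of squared residuals. */
  def fit(xs: Seq[Double], ys: Seq[Double]): (Double, Double) = {
    require(xs.length == ys.length && xs.length >= 2, "need at least two paired points")
    val n = xs.length.toDouble
    val xMean = xs.sum / n
    val yMean = ys.sum / n
    // Covariance and variance numerators; the 1/n factors cancel in the ratio.
    val sxy = xs.zip(ys).map { case (x, y) => (x - xMean) * (y - yMean) }.sum
    val sxx = xs.map(x => (x - xMean) * (x - xMean)).sum
    require(sxx != 0.0, "x values must not all be equal")
    val slope = sxy / sxx
    val intercept = yMean - slope * xMean
    (intercept, slope)
  }

  def main(args: Array[String]): Unit = {
    // Points lying exactly on y = 1 + 2x, so the fit recovers intercept 1 and slope 2.
    val xs = Seq(0.0, 1.0, 2.0, 3.0)
    val ys = xs.map(x => 1.0 + 2.0 * x)
    val (b0, b1) = fit(xs, ys)
    println(s"intercept = $b0, slope = $b1") // prints: intercept = 1.0, slope = 2.0
  }
}
```

Spark's LinearRegression generalizes this to many explanatory variables and adds optional regularization.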
The following example fits a linear regression model with Spark ML's LinearRegression estimator. It assumes a spark-shell session, in which spark is the predefined SparkSession.

import org.apache.spark.ml.regression.LinearRegression

// Load training data (LIBSVM format: a label followed by index:value feature pairs)
val training = spark.read.format("libsvm")
  .load("file:///opt/spark/data/mllib/sample_linear_regression_data.txt")

// Configure the estimator:
//   maxIter         - at most 10 optimizer iterations
//   regParam        - overall regularization strength (lambda) = 0.3
//   elasticNetParam - L1/L2 mixing (alpha) = 0.8; 0.0 is pure L2 (ridge), 1.0 is pure L1 (lasso)
val lr = new LinearRegression()
  .setMaxIter(10)
  .setRegParam(0.3)
  .setElasticNetParam(0.8)

// Fit the model
val lrModel = lr.fit(training)

// Print the coefficients and intercept for linear regression
println(s"Coefficients: ${lrModel.coefficients} Intercept: ${lrModel.intercept}")

// Summarize the model over the training set and print out some metrics
val trainingSummary = lrModel.summary
println(s"numIterations: ${trainingSummary.totalIterations}")
println(s"objectiveHistory: [${trainingSummary.objectiveHistory.mkString(",")}]")
trainingSummary.residuals.show()
println(s"RMSE: ${trainingSummary.rootMeanSquaredError}")
println(s"r2: ${trainingSummary.r2}")

/*
Output:
Coefficients: [0.0,0.32292516677405936,-0.3438548034562218,1.9156017023458414,0.05288058680386263,0.765962720459771,0.0,-0.15105392669186682,-0.21587930360904642,0.22025369188813426] Intercept: 0.1598936844239736
numIterations: 7
objectiveHistory: [0.49999999999999994,0.4967620357443381,0.4936361664340463,0.4936351537897608,0.4936351214177871,0.49363512062528014,0.4936351206216114]
+--------------------+
|           residuals|
+--------------------+
|  -9.889232683103197|
|  0.5533794340053554|
|  -5.204019455758823|
| -20.566686715507508|
|    -9.4497405180564|
|  -6.909112502719486|
|  -10.00431602969873|
|   2.062397807050484|
|  3.1117508432954772|
| -15.893608229419382|
|  -5.036284254673026|
|   6.483215876994333|
|  12.429497299109002|
|  -20.32003219007654|
| -2.0049838218725005|
| -17.867901734183793|
|   7.646455887420495|
| -2.2653482182417406|
|-0.10308920436195645|
|  -1.380034070385301|
+--------------------+
only showing top 20 rows

RMSE: 10.189077167598475
r2: 0.022861466913958184
*/
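The estimator above combines squared-error loss with an elastic-net penalty, controlled by regParam (λ) and elasticNetParam (α). The Spark ML documentation for LinearRegression gives the minimized quantity as:

(1/2n) Σ (w·xᵢ + b - yᵢ)² + λ(α‖w‖₁ + (1 - α)/2 ‖w‖₂²)

The intercept b is not penalized. The plain-Scala sketch below evaluates that objective for given weights. The object ElasticNetObjective and its tiny dataset are hypothetical illustrations, not Spark code.

Spark also standardizes features by default (the standardization parameter). So the objectiveHistory values printed above should not be expected to match what this function returns on raw data.

```scala
// Elastic-net-regularized least-squares objective, in the form documented for
// Spark ML's LinearRegression with squared-error loss:
//   (1 / 2n) * Σ (w·x_i + b - y_i)^2 + λ * (α * ||w||_1 + (1 - α) / 2 * ||w||_2^2)
// λ corresponds to regParam and α to elasticNetParam; the intercept b is not penalized.
// ElasticNetObjective is a hypothetical helper, not a Spark class.
object ElasticNetObjective {
  def objective(
      xs: Seq[Array[Double]],
      ys: Seq[Double],
      w: Array[Double],
      b: Double,
      regParam: Double,
      elasticNetParam: Double
  ): Double = {
    require(xs.length == ys.length && xs.nonEmpty, "need matching, non-empty data")
    require(xs.forall(_.length == w.length), "every feature vector must match the weight length")
    val n = xs.length.toDouble
    // Half the mean of squared residuals.
    val loss = xs.zip(ys).map { case (x, y) =>
      val pred = x.zip(w).map { case (xi, wi) => xi * wi }.sum + b
      val r = pred - y
      r * r
    }.sum / (2.0 * n)
    val l1 = w.map(wi => math.abs(wi)).sum
    val l2Squared = w.map(wi => wi * wi).sum
    val penalty = regParam * (elasticNetParam * l1 + (1.0 - elasticNetParam) / 2.0 * l2Squared)
    loss + penalty
  }

  def main(args: Array[String]): Unit = {
    // Tiny hypothetical dataset with two features, chosen so residuals are zero.
    val xs = Seq(Array(1.0, 0.0), Array(0.0, 1.0))
    val ys = Seq(1.0, 2.0)
    val w = Array(1.0, 2.0)
    // Only the penalty remains: 0.3 * (0.8 * 3 + 0.1 * 5) ≈ 0.87 (up to floating-point rounding)
    println(objective(xs, ys, w, b = 0.0, regParam = 0.3, elasticNetParam = 0.8))
  }
}
```

With elasticNetParam = 0.8, most of the penalty weight goes to the L1 term, which tends to drive some coefficients exactly to zero. The two 0.0 entries in the printed coefficients are consistent with that.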
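The RMSE and r2 values printed by the training summary follow the standard definitions:

- RMSE is the square root of the mean squared residual, expressed in the units of the label.
- r2 is 1 - SS_res / SS_tot.

The sketch below computes both from plain sequences of labels and predictions. RegressionMetricsSketch is a hypothetical name, not Spark's class.

Read this way, r2 ≈ 0.023 means the fitted model accounts for only about 2% of the variance of the training labels. The RMSE of about 10.19 is of the same order as the residuals listed in the table.

```scala
// Standard regression metrics computed from labels and predictions.
// RegressionMetricsSketch is a hypothetical helper, not a Spark class.
object RegressionMetricsSketch {
  /** Root mean squared error: sqrt(mean((y - prediction)^2)). */
  def rmse(labels: Seq[Double], preds: Seq[Double]): Double = {
    require(labels.length == preds.length && labels.nonEmpty, "need matching, non-empty sequences")
    val mse = labels.zip(preds).map { case (y, p) => (y - p) * (y - p) }.sum / labels.length
    math.sqrt(mse)
  }

  /** Coefficient of determination: 1 - SS_res / SS_tot. */
  def r2(labels: Seq[Double], preds: Seq[Double]): Double = {
    require(labels.length == preds.length && labels.nonEmpty, "need matching, non-empty sequences")
    val mean = labels.sum / labels.length
    val ssRes = labels.zip(preds).map { case (y, p) => (y - p) * (y - p) }.sum
    val ssTot = labels.map(y => (y - mean) * (y - mean)).sum
    require(ssTot != 0.0, "labels must not all be equal")
    1.0 - ssRes / ssTot
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical labels and predictions for illustration.
    val labels = Seq(1.0, 2.0, 3.0, 4.0)
    val preds = Seq(1.5, 2.0, 2.5, 4.0)
    println(s"RMSE = ${rmse(labels, preds)}, r2 = ${r2(labels, preds)}")
  }
}
```

A model that always predicts the label mean scores r2 = 0, so a value near 0.023 indicates the features explain little of the label's variation in this dataset.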