Survival regression
Spark's spark.ml package includes the accelerated failure time (AFT) model, a parametric survival regression model for censored data. It models the log of survival time, so it is often called a log-linear model for survival analysis. Unlike a proportional hazards model designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently.
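To make "log-linear" and "each instance contributes independently" concrete, a generic AFT formulation can be written as below. This is a sketch of the standard textbook form, not a transcription of Spark's source code. The symbols are assumptions introduced here: T_i is the survival time of instance i, x_i its feature vector, \beta_0 the intercept, \beta the coefficients, \sigma the scale, \delta_i an indicator that equals 1 when the event was observed (uncensored), and f_0 and S_0 the baseline density and survivor function of the error term \epsilon_i.

```latex
% Log-linear model for the survival time
\log T_i = \beta_0 + x_i^{\top}\beta + \sigma\,\epsilon_i

% Log-likelihood: a plain sum over instances, so each term can be
% computed independently and then aggregated
\ell(\beta_0, \beta, \sigma)
  = \sum_{i=1}^{n} \Big[ -\delta_i \log\sigma
      + \delta_i \log f_0(\epsilon_i)
      + (1 - \delta_i) \log S_0(\epsilon_i) \Big],
\qquad
\epsilon_i = \frac{\log t_i - \beta_0 - x_i^{\top}\beta}{\sigma}
```

Because the objective is a sum of per-instance terms, it can be evaluated partition by partition and combined, which is what makes the model straightforward to parallelize.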
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression

// `spark` is assumed to be an existing SparkSession (as in spark-shell or a notebook).
// censor: 1.0 means the event occurred (uncensored), 0.0 means the observation is censored.
val training = spark.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")
val quantileProbabilities = Array(0.3, 0.6)
val aft = new AFTSurvivalRegression()
  .setQuantileProbabilities(quantileProbabilities)
  .setQuantilesCol("quantiles")

val model = aft.fit(training)

// Print the coefficients, intercept and scale parameter for AFT survival regression
println(s"Coefficients: ${model.coefficients}")
println(s"Intercept: ${model.intercept}")
println(s"Scale: ${model.scale}")
model.transform(training).show(false)

/*
Output:
Coefficients: [-0.4963111466650683,0.19844437699933606]
Intercept: 2.638094615104006
Scale: 1.547234557436469
+-----+------+--------------+------------------+---------------------------------------+
|label|censor|features      |prediction        |quantiles                              |
+-----+------+--------------+------------------+---------------------------------------+
|1.218|1.0   |[1.56,-0.605] |5.718979487634987 |[1.1603238947151624,4.9954560102747525]|
|2.949|0.0   |[0.346,2.158] |18.076521181495465|[3.6675458454717664,15.789611866277742]|
|3.627|0.0   |[1.38,0.231]  |7.381861804239101 |[1.4977061305190837,6.447962612338965] |
|0.273|1.0   |[0.52,1.151]  |13.577612501425325|[2.754762148150694,11.859872224069736] |
|4.199|0.0   |[0.795,-0.226]|9.013097744073871 |[1.8286676321297772,7.872826505878406] |
+-----+------+--------------+------------------+---------------------------------------+

*/
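To see how the fitted parameters relate to the prediction and quantiles columns, the first row of the output can be recomputed by hand. The sketch below uses plain Scala with no Spark dependency and assumes the Weibull-based AFT form: prediction = exp(intercept + coefficients · features), and the quantile at probability p = prediction * exp(scale * log(-log(1 - p))). The object name AftByHand and its helpers predict and quantile are hypothetical names introduced for illustration. The numeric constants are copied from the printed output.

```scala
object AftByHand {
  // Fitted parameters, copied from the printed output of the example.
  val coefficients: Array[Double] = Array(-0.4963111466650683, 0.19844437699933606)
  val intercept: Double = 2.638094615104006
  val scale: Double = 1.547234557436469

  // Predicted survival time: exp(intercept + dot(coefficients, features)).
  def predict(features: Array[Double]): Double = {
    val dot = coefficients.zip(features).map { case (c, x) => c * x }.sum
    math.exp(intercept + dot)
  }

  // Survival-time quantile at probability p, assuming a Weibull-based AFT model.
  def quantile(features: Array[Double], p: Double): Double =
    predict(features) * math.exp(scale * math.log(-math.log(1.0 - p)))

  def main(args: Array[String]): Unit = {
    val firstRow = Array(1.560, -0.605) // features of the first training row
    println(s"prediction:    ${predict(firstRow)}")
    println(s"quantile(0.3): ${quantile(firstRow, 0.3)}")
    println(s"quantile(0.6): ${quantile(firstRow, 0.6)}")
  }
}
```

Under these assumptions, the three printed values should land close to the first row of the table (about 5.719, 1.160 and 4.995), up to floating-point rounding.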