`spark.ml` includes the accelerated failure time (AFT) model, a parametric survival regression model for censored data. It models the log of the survival time, so it's often called a log-linear model for survival analysis. Unlike the proportional hazards model, which is designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently.
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression

// Fits an accelerated failure time (AFT) survival regression model on a tiny
// in-memory dataset, then prints the learned parameters and per-row predictions.
// Assumes a SparkSession named `spark` is already in scope (e.g. in spark-shell).

// Each row is (label, censor, features):
//   label    - observed survival time; must be positive because the model works on log(time)
//   censor   - 1.0 if the event occurred (uncensored), 0.0 if the observation is censored
//   features - covariate vector for the instance
val training = spark.createDataFrame(Seq(
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")

// Probabilities at which survival-time quantiles are reported for every row.
// The quantiles are only emitted when a quantiles column is set.
val quantileProbabilities = Array(0.3, 0.6)
val aft = new AFTSurvivalRegression()
  // "censor" is also Spark's default; set explicitly to tie it to the column named above.
  .setCensorCol("censor")
  .setQuantileProbabilities(quantileProbabilities)
  .setQuantilesCol("quantiles")

val model = aft.fit(training)

// Print the coefficients, intercept and scale parameter for AFT survival regression
println(s"Coefficients: ${model.coefficients}")
println(s"Intercept: ${model.intercept}")
println(s"Scale: ${model.scale}")

// Show predictions (expected survival time) and the requested quantiles without truncating cells.
model.transform(training).show(truncate = false)
/*
Output:
Coefficients: [-0.4963111466650683,0.19844437699933606]
Intercept: 2.638094615104006
Scale: 1.547234557436469
+-----+------+--------------+------------------+---------------------------------------+
|label|censor|features |prediction |quantiles |
+-----+------+--------------+------------------+---------------------------------------+
|1.218|1.0 |[1.56,-0.605] |5.718979487634987 |[1.1603238947151624,4.9954560102747525]|
|2.949|0.0 |[0.346,2.158] |18.076521181495465|[3.6675458454717664,15.789611866277742]|
|3.627|0.0 |[1.38,0.231] |7.381861804239101 |[1.4977061305190837,6.447962612338965] |
|0.273|1.0 |[0.52,1.151] |13.577612501425325|[2.754762148150694,11.859872224069736] |
|4.199|0.0 |[0.795,-0.226]|9.013097744073871 |[1.8286676321297772,7.872826505878406] |
+-----+------+--------------+------------------+---------------------------------------+
*/