Survival regression

The spark.ml package includes the accelerated failure time (AFT) model, a parametric survival regression model for censored data. It models the logarithm of the survival time, so it is often called a log-linear model for survival analysis. Unlike a proportional hazards model designed for the same purpose, the AFT model is easy to parallelize, because each instance contributes to the objective function independently.
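To illustrate why the objective parallelizes, here is a minimal sketch in plain Scala, assuming the Weibull form that the spark.ml AFT model uses: log t_i = x_i·β + b + σ ε_i, where ε_i follows a standard extreme-value distribution and δ_i is 1 for an observed event and 0 for a censored one. Each instance's log-likelihood contribution, δ_i(ε_i − log σ) − exp(ε_i) (dropping a term that does not depend on the parameters), involves only that instance, so the total is a plain sum that can be computed per partition and then combined. `AftObjectiveSketch` and its members are hypothetical names for illustration, not Spark's API or source.

```scala
// Hedged sketch: per-instance log-likelihood of a Weibull AFT model in plain Scala.
// This is NOT Spark's internal implementation; all names here are hypothetical.
object AftObjectiveSketch {
  // One observation: t = survival time, delta = 1.0 if the event occurred
  // (uncensored) and 0.0 if the time is censored, x = covariates.
  final case class Obs(t: Double, delta: Double, x: Array[Double])

  // epsilon_i = (log t_i - (x_i . beta + b)) / sigma
  def epsilon(o: Obs, beta: Array[Double], b: Double, sigma: Double): Double = {
    val eta = o.x.zip(beta).map { case (xi, bi) => xi * bi }.sum + b
    (math.log(o.t) - eta) / sigma
  }

  // Contribution of a single instance (the constant -delta * log t term is dropped):
  // delta * (epsilon - log sigma) - exp(epsilon)
  def logLik(o: Obs, beta: Array[Double], b: Double, sigma: Double): Double = {
    val e = epsilon(o, beta, b, sigma)
    o.delta * (e - math.log(sigma)) - math.exp(e)
  }

  // The total is a plain sum over instances, so partial sums can be added.
  def total(data: Seq[Obs], beta: Array[Double], b: Double, sigma: Double): Double =
    data.map(logLik(_, beta, b, sigma)).sum
}

// The training rows from the example in this section.
val data = Seq(
  AftObjectiveSketch.Obs(1.218, 1.0, Array(1.560, -0.605)),
  AftObjectiveSketch.Obs(2.949, 0.0, Array(0.346, 2.158)),
  AftObjectiveSketch.Obs(3.627, 0.0, Array(1.380, 0.231)),
  AftObjectiveSketch.Obs(0.273, 1.0, Array(0.520, 1.151)),
  AftObjectiveSketch.Obs(4.199, 0.0, Array(0.795, -0.226))
)
// Illustrative parameter values (chosen by hand, not fitted).
val beta = Array(-0.5, 0.2)
val (b, sigma) = (2.6, 1.5)

// Evaluate the objective on the whole dataset and on two "partitions".
val (part1, part2) = data.splitAt(2)
val whole = AftObjectiveSketch.total(data, beta, b, sigma)
val split = AftObjectiveSketch.total(part1, beta, b, sigma) +
  AftObjectiveSketch.total(part2, beta, b, sigma)
println(s"whole = $whole, split = $split")
```

Splitting the data at any point gives the same total up to floating-point rounding, which is the property a distributed optimizer relies on. A proportional hazards partial likelihood, by contrast, ties each event's term to its whole risk set.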

import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.regression.AFTSurvivalRegression

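// `spark` is the SparkSession provided by spark-shell.
// Columns: label = survival time; censor = 1.0 if the event occurred (uncensored),
// 0.0 if the time is censored; features = covariate vector.
val training = spark.createDataFrame(Seq(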
  (1.218, 1.0, Vectors.dense(1.560, -0.605)),
  (2.949, 0.0, Vectors.dense(0.346, 2.158)),
  (3.627, 0.0, Vectors.dense(1.380, 0.231)),
  (0.273, 1.0, Vectors.dense(0.520, 1.151)),
  (4.199, 0.0, Vectors.dense(0.795, -0.226))
)).toDF("label", "censor", "features")
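// Also output the 0.3 and 0.6 quantiles of each predicted survival-time distribution
// in an extra "quantiles" column. The label, censor and features columns use their
// default names ("label", "censor", "features"), so they need no setters here.
val quantileProbabilities = Array(0.3, 0.6)
val aft = new AFTSurvivalRegression()
  .setQuantileProbabilities(quantileProbabilities)
  .setQuantilesCol("quantiles")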

val model = aft.fit(training)

// Print the coefficients, intercept and scale parameter for AFT survival regression
println(s"Coefficients: ${model.coefficients}")
println(s"Intercept: ${model.intercept}")
println(s"Scale: ${model.scale}")
model.transform(training).show(false)

/*
Output:
Coefficients: [-0.4963111466650683,0.19844437699933606]
Intercept: 2.638094615104006
Scale: 1.547234557436469
+-----+------+--------------+------------------+---------------------------------------+
|label|censor|features      |prediction        |quantiles                              |
+-----+------+--------------+------------------+---------------------------------------+
|1.218|1.0   |[1.56,-0.605] |5.718979487634987 |[1.1603238947151624,4.9954560102747525]|
|2.949|0.0   |[0.346,2.158] |18.076521181495465|[3.6675458454717664,15.789611866277742]|
|3.627|0.0   |[1.38,0.231]  |7.381861804239101 |[1.4977061305190837,6.447962612338965] |
|0.273|1.0   |[0.52,1.151]  |13.577612501425325|[2.754762148150694,11.859872224069736] |
|4.199|0.0   |[0.795,-0.226]|9.013097744073871 |[1.8286676321297772,7.872826505878406] |
+-----+------+--------------+------------------+---------------------------------------+

*/
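Because the spark.ml AFT model assumes Weibull-distributed survival times, the `prediction` and `quantiles` columns can be recomputed from the printed parameters: the prediction is exp(x·β + b), and the p-quantile is that value multiplied by (−log(1 − p))^σ, where σ is the printed scale. The plain-Scala sketch below recomputes the first row of the table; the object and method names are hypothetical (this is not Spark's source), and the hard-coded numbers are copied from the output above.

```scala
// Hedged sketch: recompute AFT predictions and quantiles by hand from the fitted
// parameters printed above. Plain Scala; AftPredictSketch is a hypothetical name.
object AftPredictSketch {
  // Values copied from the example output.
  val coefficients = Array(-0.4963111466650683, 0.19844437699933606)
  val intercept = 2.638094615104006
  val scale = 1.547234557436469

  // Predicted survival time: exp(x . beta + b)
  def predict(x: Array[Double]): Double =
    math.exp(x.zip(coefficients).map { case (xi, bi) => xi * bi }.sum + intercept)

  // p-quantile under the Weibull assumption: predict(x) * (-log(1 - p))^scale
  def quantile(x: Array[Double], p: Double): Double =
    predict(x) * math.pow(-math.log1p(-p), scale)
}

// Features of the first training row.
val x0 = Array(1.560, -0.605)
println(AftPredictSketch.predict(x0)) // should be close to 5.71898 (first table row)
println(Seq(0.3, 0.6).map(AftPredictSketch.quantile(x0, _))) // roughly 1.16 and 5.00
```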
