Decision Tree Regression
Decision trees are a popular family of classification and regression methods. More information about the spark.ml implementation can be found in the section on decision trees.
Examples
The following example loads a dataset in LibSVM format, splits it into training and test sets, trains on the training set, and then evaluates on the held-out test set. We use a feature transformer (VectorIndexer) to index categorical features, adding metadata to the DataFrame that the decision tree algorithm can recognize.
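To make the indexing step more concrete, here is a minimal plain-Scala sketch of the idea behind a "maximum categories" threshold: a feature with few distinct values is treated as categorical and its values are mapped to 0-based indices, while any other feature is left as continuous. This is a simplified illustration, not Spark's implementation; `IndexerSketch`, `categoryMaps`, and `index` are hypothetical names, and assigning indices in ascending order of the raw values is an assumption made for the sketch.

```scala
// Simplified illustration only; this is NOT how Spark's VectorIndexer is written.
// All names in this object are hypothetical.
object IndexerSketch {

  /** For every feature column with at most `maxCategories` distinct values,
    * build a map from raw value to a 0-based category index.
    * Columns with more distinct values get no map (treated as continuous). */
  def categoryMaps(rows: Seq[Array[Double]], maxCategories: Int): Map[Int, Map[Double, Int]] = {
    val numFeatures = rows.headOption.map(_.length).getOrElse(0)
    (0 until numFeatures).flatMap { j =>
      val distinctValues = rows.map(_(j)).distinct.sorted
      if (distinctValues.size <= maxCategories) Some(j -> distinctValues.zipWithIndex.toMap)
      else None
    }.toMap
  }

  /** Replace categorical values with their index; leave continuous values as they are.
    * A categorical value that was not seen when building the maps throws an exception here. */
  def index(row: Array[Double], maps: Map[Int, Map[Double, Int]]): Array[Double] =
    row.zipWithIndex.map { case (value, j) =>
      maps.get(j).map(m => m(value).toDouble).getOrElse(value)
    }

  def main(args: Array[String]): Unit = {
    // Feature 0 has 3 distinct values (categorical for maxCategories = 4);
    // feature 1 has 5 distinct values (continuous).
    val rows = Seq(
      Array(10.0, 0.5),
      Array(20.0, 1.7),
      Array(10.0, 2.9),
      Array(30.0, 3.1),
      Array(20.0, 4.4)
    )
    val maps = categoryMaps(rows, maxCategories = 4)
    println(index(Array(20.0, 1.7), maps).mkString(", ")) // prints "1.0, 1.7"
  }
}
```

In the Spark example, the fitted VectorIndexer plays this role and also records in the DataFrame metadata which features are categorical, so the tree can use set-membership splits for them.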
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.regression.DecisionTreeRegressor

// Load the data stored in LIBSVM format as a DataFrame.
val data = spark.read.format("libsvm").load("file:///opt/spark/data/mllib/sample_libsvm_data.txt")

// Automatically identify categorical features, and index them.
// Here, we treat features with > 4 distinct values as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(4)
  .fit(data)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))

// Train a DecisionTree model.
val dt = new DecisionTreeRegressor()
  .setLabelCol("label")
  .setFeaturesCol("indexedFeatures")

// Chain indexer and tree in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(featureIndexer, dt))

// Train model. This also runs the indexer.
val model = pipeline.fit(trainingData)

// Make predictions.
val predictions = model.transform(testData)

// Select example rows to display.
predictions.select("prediction", "label", "features").show(5)

// Select (prediction, true label) and compute test error.
val evaluator = new RegressionEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  .setMetricName("rmse")
val rmse = evaluator.evaluate(predictions)
println(s"Root Mean Squared Error (RMSE) on test data = $rmse")

val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel]
println(s"Learned regression tree model:\n ${treeModel.toDebugString}")

/*
Output:
+----------+-----+--------------------+
|prediction|label|            features|
+----------+-----+--------------------+
|       0.0|  0.0|(692,[123,124,125...|
|       0.0|  0.0|(692,[124,125,126...|
|       0.0|  0.0|(692,[124,125,126...|
|       0.0|  0.0|(692,[126,127,128...|
|       0.0|  0.0|(692,[126,127,128...|
+----------+-----+--------------------+
only showing top 5 rows

Root Mean Squared Error (RMSE) on test data = 0.19611613513818404
Learned regression tree model:
 DecisionTreeRegressionModel (uid=dtr_f30a452bc6d9) of depth 2 with 5 nodes
  If (feature 406 <= 126.5)
   If (feature 99 in {0.0,3.0})
    Predict: 0.0
   Else (feature 99 not in {0.0,3.0})
    Predict: 1.0
  Else (feature 406 > 126.5)
   Predict: 1.0
*/
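As an aid to reading the output above, the following plain-Scala sketch transcribes the printed tree into nested conditionals and computes RMSE by hand. It is only an illustration: `ReadingTheOutput`, `predict`, and `rmse` are hypothetical names, and the 26-row test set with a single wrong prediction is an assumed scenario, chosen because sqrt(1/26) is approximately 0.1961 and so matches the reported value. The output itself does not show the actual size of the test set.

```scala
// Illustration only; names and the sample data below are hypothetical.
object ReadingTheOutput {

  /** The printed tree as nested conditionals.
    * f406: value of feature 406, split on a threshold (treated as continuous).
    * f99:  value of feature 99 in the indexed column, i.e. a category index
    *       assigned by VectorIndexer, tested by set membership. */
  def predict(f406: Double, f99: Double): Double =
    if (f406 <= 126.5) {
      if (f99 == 0.0 || f99 == 3.0) 0.0 else 1.0
    } else {
      1.0
    }

  /** RMSE = sqrt(mean((prediction - label)^2)). */
  def rmse(predictions: Seq[Double], labels: Seq[Double]): Double = {
    require(predictions.nonEmpty && predictions.size == labels.size,
      "predictions and labels must be non-empty and of equal length")
    val squaredErrors = predictions.zip(labels).map { case (p, y) => (p - y) * (p - y) }
    math.sqrt(squaredErrors.sum / squaredErrors.size)
  }

  def main(args: Array[String]): Unit = {
    // Assumed scenario: 26 rows with 0/1 labels and exactly one wrong prediction.
    val labels = Seq.fill(13)(0.0) ++ Seq.fill(13)(1.0)
    val predictions = labels.updated(0, 1.0)
    println(rmse(predictions, labels)) // approximately 0.1961
  }
}
```

Because `randomSplit` is called without a seed in the example, the rows in the test set, the learned tree, and the RMSE can all differ from run to run.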