Untyped User-Defined Aggregate Functions
To implement a custom untyped aggregate function, users extend the UserDefinedAggregateFunction abstract class. For example, a user-defined average can look like this:
```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

object MyAverage extends UserDefinedAggregateFunction {
  // Data types of input arguments of this aggregate function
  def inputSchema: StructType =
    StructType(StructField("inputColumn", LongType) :: Nil)
  // Data types of values in the aggregation buffer
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }
  // The data type of the returned value
  def dataType: DataType = DoubleType
  // Whether this function always returns the same output on the identical input
  def deterministic: Boolean = true
  // Initializes the given aggregation buffer. The buffer itself is a `Row` that,
  // in addition to standard methods like retrieving a value at an index
  // (e.g., get(), getBoolean()), provides the opportunity to update its values.
  // Note that arrays and maps inside the buffer are still immutable.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }
  // Updates the given aggregation buffer `buffer` with new input data from `input`
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }
  // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Calculates the final result
  def evaluate(buffer: Row): Double =
    buffer.getLong(0).toDouble / buffer.getLong(1)
}

// Register the function to access it
spark.udf.register("myAverage", MyAverage)

val df = spark.read.json("file:///home/dv6/spark/spark/examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()

/*
+-------+------+
|   name|salary|
+-------+------+
|Michael|  3000|
|   Andy|  4500|
| Justin|  3500|
|  Berta|  4000|
+-------+------+
*/

val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()

/*
+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+
*/
```
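
Once registered, the UDAF is not limited to SQL strings; it can also be invoked through the DataFrame API. A minimal sketch, assuming the `spark` session and `df` defined above (the `dfResult` name is illustrative):

```scala
import org.apache.spark.sql.functions.{callUDF, col}

// Invoke the registered UDAF through the DataFrame API instead of a SQL string
val dfResult = df.agg(callUDF("myAverage", col("salary")).as("average_salary"))
dfResult.show()
// Should print the same average_salary of 3750.0 as the SQL query above
```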
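
Note that UserDefinedAggregateFunction is deprecated as of Spark 3.0 in favor of the typed Aggregator, which can be registered for untyped use via functions.udaf. Below is a minimal sketch of the same average under that API, assuming Spark 3.0+ (the names `AverageBuffer` and `MyAverageAggregator` are illustrative):

```scala
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.functions.udaf

// Mutable buffer holding the running sum and count (illustrative helper type)
case class AverageBuffer(var sum: Long, var count: Long)

object MyAverageAggregator extends Aggregator[Long, AverageBuffer, Double] {
  // The neutral/zero value for the buffer
  def zero: AverageBuffer = AverageBuffer(0L, 0L)
  // Fold one input value into the buffer
  def reduce(buffer: AverageBuffer, value: Long): AverageBuffer = {
    buffer.sum += value
    buffer.count += 1
    buffer
  }
  // Combine two partial buffers
  def merge(b1: AverageBuffer, b2: AverageBuffer): AverageBuffer = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // Compute the final result from the buffer
  def finish(buffer: AverageBuffer): Double = buffer.sum.toDouble / buffer.count
  // Encoders for the intermediate buffer and the output
  def bufferEncoder: Encoder[AverageBuffer] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register the typed Aggregator as an untyped UDAF usable from SQL
spark.udf.register("myAverage2", udaf(MyAverageAggregator))
spark.sql("SELECT myAverage2(salary) AS average_salary FROM employees").show()
```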