Untyped User-Defined Aggregate Functions

Users have to extend the UserDefinedAggregateFunction abstract class to implement a custom untyped aggregate function.

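API reference (Spark 1.6.2 Javadoc; note that the example below uses `SparkSession`, which requires Spark 2.0 or later): https://spark.apache.org/docs/1.6.2/api/java/org/apache/spark/sql/expressions/UserDefinedAggregateFunction.html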

For example, a user-defined average can look like:

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql
   .expressions.MutableAggregationBuffer
import org.apache.spark.sql
   .expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._

// Untyped user-defined aggregate that computes the arithmetic mean of a
// single Long column. It keeps a running (sum, count) pair in the
// aggregation buffer and divides the two at the end.
object MyAverage extends UserDefinedAggregateFunction {

  // Position of the running sum inside the aggregation buffer.
  private val SumSlot = 0
  // Position of the running count inside the aggregation buffer.
  private val CountSlot = 1

  // Schema of the arguments this aggregate accepts: one Long column.
  def inputSchema: StructType =
    StructType(Seq(StructField("inputColumn", LongType)))

  // Schema of the intermediate state: a Long sum followed by a Long count.
  // The field order here must agree with SumSlot / CountSlot above.
  def bufferSchema: StructType =
    StructType(Seq(
      StructField("sum", LongType),
      StructField("count", LongType)
    ))

  // Type of the value produced by `evaluate`.
  def dataType: DataType = DoubleType

  // The same input always produces the same output.
  def deterministic: Boolean = true

  // Resets the buffer to the empty state (sum = 0, count = 0).
  // The buffer is a `Row` that, besides the usual read accessors
  // (get(), getBoolean(), ...), also allows its slots to be overwritten.
  // Arrays and maps stored inside the buffer remain immutable.
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(SumSlot, 0L)
    buffer.update(CountSlot, 0L)
  }

  // Folds one input row into `buffer`. Rows whose first column is null
  // are ignored, so they contribute to neither the sum nor the count.
  def update(buffer: MutableAggregationBuffer, input: Row): Unit =
    if (!input.isNullAt(0)) {
      val newSum = buffer.getLong(SumSlot) + input.getLong(0)
      val newCount = buffer.getLong(CountSlot) + 1
      buffer.update(SumSlot, newSum)
      buffer.update(CountSlot, newCount)
    }

  // Combines two partial aggregates by adding `buffer2` into `buffer1`;
  // the merged state is written back to `buffer1`.
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    val mergedSum = buffer1.getLong(SumSlot) + buffer2.getLong(SumSlot)
    val mergedCount = buffer1.getLong(CountSlot) + buffer2.getLong(CountSlot)
    buffer1.update(SumSlot, mergedSum)
    buffer1.update(CountSlot, mergedCount)
  }

  // Produces the final average as sum / count using floating-point
  // division. If no non-null input was ever seen, both slots are still 0
  // and the result is NaN.
  def evaluate(buffer: Row): Double = {
    val total = buffer.getLong(SumSlot)
    val count = buffer.getLong(CountSlot)
    total.toDouble / count
  }
}

// Make the aggregate callable from SQL under the name "myAverage".
spark.udf.register("myAverage", MyAverage)

// Read the sample employee records, expose them to SQL as the
// "employees" temporary view, and print the loaded rows.
val employeesPath = "file:///home/dv6/spark/spark/examples/src/main/resources/employees.json"
val df = spark.read.json(employeesPath)
df.createOrReplaceTempView("employees")
df.show()

/*
+-------+------+
|   name|salary|
+-------+------+
|Michael|  3000|
|   Andy|  4500|
| Justin|  3500|
|  Berta|  4000|
+-------+------+
*/

// Apply the registered "myAverage" aggregate to the salary column of the
// "employees" view and print the single-row result.
val averageSalaryQuery = "SELECT myAverage(salary) as average_salary FROM employees"
val result = spark.sql(averageSalaryQuery)
result.show()

/*
+--------------+
|average_salary|
+--------------+
|        3750.0|
+--------------+
*/

Last updated