大数据——Spark-SQL自定义函数UDF、UDAF、UDTF
Spark-SQL自定义函数UDF、UDAF、UDTF
- 自定义函数分类
-
- UDF
- UDAF
- UDTF
自定义函数分类
与Hive当中的自定义函数类似,Spark同样可以使用自定义函数来实现新的功能
Spark中的自定义函数有三类:
-
UDF(User-Defined-Function)
输入一行,输出一行
-
UDAF(User-Defined Aggregation Function)
输入多行,输出一行
-
UDTF(User-Defined Table-Generating Functions)
输入一行,输出多行
UDF
-
需求:用户行为喜好个数统计
-
hobbies.txt
alice jogging,Coding,cooking
lina travel,dance
-
代码展示:
package nj.zb.kb09.sql
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.StringTrimRight
import org.apache.spark.sql.{DataFrame, SparkSession}

/** One record of hobbies.txt: a user name and that user's comma-separated hobbies. */
case class Hobbies(name:String,hobbies:String)

/**
 * UDF demo: registers `hobby_num`, which counts the comma-separated hobbies
 * of each user, and applies it through Spark SQL.
 */
object SparkUDFDemo {
  /**
   * Reads `in/hobbies.txt` (one `name hobby1,hobby2,...` record per line),
   * prints the parsed data, then prints each user with the number of hobbies.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("SparkUDFDemo").getOrCreate()
    try {
      import spark.implicits._
      val sc: SparkContext = spark.sparkContext
      val rdd: RDD[String] = sc.textFile("in/hobbies.txt")
      // Keep only lines that have both a name and a hobbies field; a blank or
      // truncated line is skipped instead of failing on a missing index.
      val df: DataFrame = rdd
        .map(_.split(" "))
        .collect { case Array(name, hobbies, _*) => Hobbies(name, hobbies) }
        .toDF()
      df.printSchema()
      df.show()

      // createOrReplaceTempView replaces the deprecated registerTempTable.
      df.createOrReplaceTempView("hobbies")

      // A null or empty hobbies value counts as 0; empty tokens are ignored.
      spark.udf.register("hobby_num", (v: String) =>
        Option(v).map(_.split(",").count(_.trim.nonEmpty)).getOrElse(0))

      val frame: DataFrame = spark.sql("select name,hobbies,hobby_num(hobbies) as hobbynum from hobbies")
      frame.show()
    } finally {
      // Release the local Spark resources even if the job fails.
      spark.stop()
    }
  }
}
结果展示:
-
需求:将每一行数据转换成大写
-
udf2.txt
Hello
abc
study
small
-
代码展示:
package nj.zb.kb09.sql
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.StringTrimRight
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

/**
 * UDF demo: registers `smallToBig`, which upper-cases a string, and applies
 * it to every line of a text file through Spark SQL.
 */
object SparkUDFDemo2 {
  /**
   * Reads `in/udf2.txt`, prints it, then prints each line next to its
   * upper-cased form.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    // Create the SparkSession.
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("SparkUDFDemo2").getOrCreate()
    try {
      // Read the file; each line becomes one row in the single column "value".
      val fileDs: Dataset[String] = spark.read.textFile("in/udf2.txt")
      fileDs.printSchema()
      fileDs.show()

      // Register smallToBig: takes a String and returns it upper-cased.
      // Locale.ROOT keeps the result independent of the JVM default locale,
      // and a null input yields null instead of a NullPointerException.
      spark.udf.register("smallToBig", (str: String) =>
        Option(str).map(_.toUpperCase(java.util.Locale.ROOT)).orNull)

      // Expose the dataset to SQL as a view.
      fileDs.createOrReplaceTempView("t_word")

      // Use the custom function.
      val df: DataFrame = spark.sql("select value,smallToBig(value) from t_word")
      df.printSchema()
      df.show()
    } finally {
      // Release the local Spark resources even if the job fails.
      spark.stop()
    }
  }
}
结果展示:
UDAF
-
继承UserDefinedAggregateFunction方法重写说明
- inputSchema:输入数据的类型
- bufferSchema:产生中间结果的数据类型
- dataType:最终返回的结果类型
- deterministic:确保一致性,一般用true
- initialize:指定初始值
- update:每有一条数据参与运算就更新一下中间结果(update相当于在每一个分区中的运算)
- merge:全局聚合(将每个分区的结果进行聚合)
- evaluate:计算最终的结果
-
需求:求不同性别的平均年龄
-
udaf.json
{"id": 1001, "name": "foo", "sex": "man", "age": 20}
{"id": 1002, "name": "bar", "sex": "man", "age": 24}
{"id": 1003, "name": "baz", "sex": "man", "age": 18}
{"id": 1004, "name": "foo1", "sex": "woman", "age": 17}
{"id": 1005, "name": "bar2", "sex": "woman", "age": 19}
{"id": 1006, "name": "baz3", "sex": "woman", "age": 20}
-
代码展示:
package nj.zb.kb09.sql
import org.apache.spark.SparkContext
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * UDAF demo: registers [[MyAgeAvgFunction]] as `myAvgAge` and compares its
 * result with the built-in `avg` on the same data.
 */
object SparkUDAFDemo {
  /**
   * Reads `in/udaf.json`, prints it, then prints the average age per sex
   * computed first by the custom UDAF and then by the built-in `avg`.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("SparkUDAFDemo").getOrCreate()
    try {
      val df: DataFrame = spark.read.json("in/udaf.json")
      df.printSchema()
      df.show()

      // Create and register the custom UDAF.
      val myUdaf = new MyAgeAvgFunction
      spark.udf.register("myAvgAge", myUdaf)

      // Create the temporary view; the "OrReplace" variant does not fail when
      // a view with this name already exists in the session.
      df.createOrReplaceTempView("userinfo")

      // Use the custom function.
      val resultDF: DataFrame = spark.sql("select sex,myAvgAge(age) from userinfo group by sex")
      resultDF.printSchema()
      resultDF.show()

      // Use the built-in avg function for comparison.
      println("-----------------")
      spark.sql("select sex,avg(age) from userinfo group by sex").show()
    } finally {
      // Release the local Spark resources even if the job fails.
      spark.stop()
    }
  }
}

/**
 * Aggregate function that averages a Long column.
 *
 * The buffer holds a running `sum` and `count`. Null inputs are skipped, and a
 * group with no non-null input evaluates to null rather than NaN.
 */
class MyAgeAvgFunction extends UserDefinedAggregateFunction {

  /** Schema of the input: a single Long column named "age". */
  override def inputSchema: StructType =
    new StructType().add("age", LongType)

  /** Schema of the aggregation buffer: running sum and running count. */
  override def bufferSchema: StructType =
    new StructType().add("sum", LongType).add("count", LongType)

  /** Type of the final result. */
  override def dataType: DataType = DoubleType

  /** The same input always produces the same output. */
  override def deterministic: Boolean = true

  /** Resets the buffer: sum = 0, count = 0. */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  /**
   * Folds one input row into the buffer.
   *
   * @param buffer running (sum, count)
   * @param input  row whose column 0 is the age
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    // A null age contributes to neither the sum nor the count.
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  /**
   * Merges a second buffer into the first.
   *
   * @param buffer1 buffer that receives the merged result
   * @param buffer2 buffer to merge in
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    // Total of all ages.
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    // Total number of rows counted.
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  /**
   * Computes the final result from the buffer.
   *
   * @return sum / count as a Double, or null when count is 0
   */
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) null else buffer.getLong(0).toDouble / count
  }
}
结果展示:
UDTF
-
需求:遍历ls学的大数据组件
-
udtf.txt
01//zs//Hadoop scala spark hive hbase
02//ls//Hadoop scala kafka hive hbase Oozie
03//ww//Hadoop scala spark hive sqoop
-
代码展示:
package nj.zb.kb09.sql
import java.util
import org.apache.hadoop.hive.ql.exec.UDFArgumentException
import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, PrimitiveObjectInspector, StructObjectInspector}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * Hive UDTF that splits its single string argument on spaces and emits one
 * output row per token, in a single String column named "type".
 *
 * Example: the input "Hadoop scala kafka hive hbase Oozie" produces six rows:
 * Hadoop, scala, kafka, hive, hbase, Oozie.
 */
class MyUDTF extends GenericUDTF {

  /**
   * Validates the arguments and declares the output columns.
   *
   * @param argOIs inspectors for the arguments passed to the function
   * @return a struct inspector describing one String column named "type"
   * @throws UDFArgumentException if there is not exactly one argument, or the
   *                              argument is not a primitive
   */
  override def initialize(argOIs: Array[ObjectInspector]): StructObjectInspector = {
    if (argOIs.length != 1) {
      throw new UDFArgumentException("有且只能有一个参数传入")
    }
    if (argOIs(0).getCategory != ObjectInspector.Category.PRIMITIVE) {
      throw new UDFArgumentException("参数类型不匹配")
    }
    val fieldNames = new util.ArrayList[String]
    val fieldOIs = new util.ArrayList[ObjectInspector]
    fieldNames.add("type")
    // The type of the output column is declared here.
    fieldOIs.add(PrimitiveObjectInspectorFactory.javaStringObjectInspector)
    ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs)
  }

  /**
   * Handles one input row: splits the argument on spaces and forwards each
   * non-empty token as its own output row.
   *
   * @param objects the arguments of the current row; only element 0 is read
   */
  override def process(objects: Array[AnyRef]): Unit = {
    // A null argument produces no output rows instead of throwing.
    Option(objects(0)).foreach { value =>
      // Split the string into space-separated tokens.
      value.toString.split(" ").filter(_.nonEmpty).foreach { component =>
        // forward is given an array with one element per output column,
        // so a single-column row is a one-element array.
        forward(Array[String](component))
      }
    }
  }

  /** Nothing to release. */
  override def close(): Unit = {
  }
}

/**
 * UDTF demo: registers [[MyUDTF]] as a temporary Hive function and uses it to
 * list the components studied by the student named "ls".
 */
object SparkUDTFDemo {
  /**
   * Reads `in/udtf.txt` (records of the form `id//name//space-separated components`),
   * keeps the record whose name is "ls", and prints one row per component.
   *
   * @param args command-line arguments (unused)
   */
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().master("local[*]").appName("SparkUDTFDemo").enableHiveSupport().getOrCreate()
    try {
      import spark.implicits._
      val sc: SparkContext = spark.sparkContext
      val lines: RDD[String] = sc.textFile("in/udtf.txt")

      // The length check skips malformed lines that lack one of the three fields.
      val stuDf: DataFrame = lines
        .map(_.split("//"))
        .filter(fields => fields.length >= 3 && fields(1) == "ls")
        .map(fields => (fields(0), fields(1), fields(2)))
        .toDF("id", "name", "class")
      stuDf.printSchema()
      stuDf.show()

      stuDf.createOrReplaceTempView("student")
      spark.sql("CREATE TEMPORARY FUNCTION MyUDTF AS 'nj.zb.kb09.sql.MyUDTF'")

      val resultDF: DataFrame = spark.sql("select MyUDTF(class) from student")
      resultDF.printSchema()
      resultDF.show()
    } finally {
      // Release the local Spark resources even if the job fails.
      spark.stop()
    }
  }
}
结果展示: