Bagaimana cara memilih baris pertama dari setiap grup?

143

Saya memiliki DataFrame yang dihasilkan sebagai berikut:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc))

Hasilnya terlihat seperti:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+

Seperti yang Anda lihat, DataFrame dipesan dengan Hourdalam urutan yang meningkat, kemudian dengan TotalValueurutan menurun.

Saya ingin memilih baris teratas dari setiap grup, yaitu

  • dari grup Hour == 0 select (0, cat26,30.9)
  • dari grup Hour == 1 pilih (1, cat67,28.5)
  • dari grup Hour == 2 pilih (2, cat56,39.6)
  • dan seterusnya

Jadi output yang diinginkan adalah:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+

Mungkin berguna untuk dapat memilih baris N teratas dari setiap grup juga.

Setiap bantuan sangat dihargai.

Rami
sumber

Jawaban:

231

Fungsi jendela :

Sesuatu seperti ini harus melakukan trik:

import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Metode ini akan tidak efisien jika kemiringan data yang signifikan.

Agregasi SQL biasa diikuti olehjoin :

Atau Anda dapat bergabung dengan bingkai data gabungan:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Ini akan menyimpan nilai duplikat (jika ada lebih dari satu kategori per jam dengan nilai total yang sama). Anda dapat menghapus ini sebagai berikut:

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))

Menggunakan pemesanan lebihstructs :

Rapi, meskipun tidak diuji dengan sangat baik, trik yang tidak memerlukan fungsi gabungan atau jendela:

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Dengan API DataSet (Spark 1.6+, 2.0+):

Spark 1.6 :

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 atau lebih baru :

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

Dua metode terakhir dapat memanfaatkan gabungan sisi peta dan tidak memerlukan pengocokan penuh sehingga sebagian besar waktu harus menunjukkan kinerja yang lebih baik dibandingkan dengan fungsi jendela dan gabungan. Tongkat ini juga dapat digunakan dengan Streaming Terstruktur dalam completedmode keluaran.

Jangan gunakan :

df.orderBy(...).groupBy(...).agg(first(...), ...)

Ini mungkin terlihat bekerja (terutama dalam localmode) tetapi tidak dapat diandalkan (lihat SPARK-16207 , kredit untuk Tzach Zohar untuk menghubungkan masalah JIRA yang relevan , dan SPARK-30335 ).

Catatan yang sama berlaku untuk

df.orderBy(...).dropDuplicates(...)

yang secara internal menggunakan rencana eksekusi yang setara.

nol323
sumber
3
Sepertinya sejak percikan 1.6 itu adalah row_number () bukannya rowNumber
Adam Szałucha
Tentang Jangan gunakan df.orderBy (...). GropBy (...). Dalam keadaan apa kita dapat bergantung pada orderBy (...)? atau jika kita tidak dapat memastikan apakah orderBy () akan memberikan hasil yang benar, alternatif apa yang kita miliki?
Ignacio Alorre
Saya mungkin mengabaikan sesuatu, tetapi secara umum disarankan untuk menghindari groupByKey , sebagai gantinya mengurangiByKey harus digunakan. Anda juga akan menyimpan satu baris.
Thomas
3
@ Thomas menghindari groupBy / groupByKey hanya ketika berhadapan dengan RDD, Anda akan melihat bahwa api Dataset bahkan tidak memiliki fungsi mengurangiByKey.
soote
16

Untuk Spark 2.0.2 dengan pengelompokan berdasarkan beberapa kolom:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
Antonín Hoskovec
sumber
8

Ini adalah persis sama zero323 's jawaban tetapi dalam SQL cara query.

Dengan asumsi bahwa dataframe dibuat dan terdaftar sebagai

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Fungsi jendela:

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Agregasi SQL biasa diikuti dengan bergabung:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Menggunakan pemesanan lebih dari struct:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Cara dataSets dan jangan lakukan sama dengan di jawaban asli

Ramesh Maharjan
sumber
2

Polanya adalah kelompok dengan tombol => melakukan sesuatu untuk setiap kelompok misalnya mengurangi => kembali ke bingkai data

Saya pikir abstraksi Dataframe agak rumit dalam hal ini jadi saya menggunakan fungsionalitas RDD

 val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)
Bebek karet
sumber
1

Solusi di bawah ini hanya melakukan satu groupBy dan mengekstrak baris dataframe Anda yang berisi maxValue dalam satu kesempatan. Tidak perlu Bergabung lagi, atau Windows.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame

//df is the dataframe with Day, Category, TotalValue

implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}}
elghoto
sumber
Tapi itu mengocok segalanya terlebih dahulu. Ini bukan perbaikan (mungkin tidak lebih buruk dari fungsi jendela, tergantung pada data).
Alper t. Turker
Anda memiliki tempat pertama di grup, yang akan memicu shuffle. Ini tidak lebih buruk daripada fungsi jendela karena dalam fungsi jendela itu akan mengevaluasi jendela untuk setiap baris dalam kerangka data.
elghoto
1

Cara yang baik untuk melakukan ini dengan api dataframe adalah menggunakan logika argmax seperti itu

  val df = Seq(
    (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
    (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
    (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
    (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")

  df.groupBy($"Hour")
    .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
    .select($"Hour", $"argmax.*").show

 +----+----------+--------+
 |Hour|TotalValue|Category|
 +----+----------+--------+
 |   1|      28.5|   cat67|
 |   3|      35.6|    cat8|
 |   2|      39.6|   cat56|
 |   0|      30.9|   cat26|
 +----+----------+--------+
randal25
sumber
0

Di sini Anda dapat melakukannya seperti ini -

   val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")

data.withColumnRenamed("_1","Hour").show
Shubham Agrawal
sumber
-2

Kita dapat menggunakan fungsi jendela rank () (di mana Anda akan memilih peringkat rank = 1) hanya menambahkan angka untuk setiap baris grup (dalam hal ini akan menjadi jam)

ini sebuah contoh. (dari https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank )

val dataset = spark.range(9).withColumn("bucket", 'id % 3)

import org.apache.spark.sql.expressions.Window
val byBucket = Window.partitionBy('bucket).orderBy('id)

scala> dataset.withColumn("rank", rank over byBucket).show
+---+------+----+
| id|bucket|rank|
+---+------+----+
|  0|     0|   1|
|  3|     0|   2|
|  6|     0|   3|
|  1|     1|   1|
|  4|     1|   2|
|  7|     1|   3|
|  2|     2|   1|
|  5|     2|   2|
|  8|     2|   3|
+---+------+----+
Vasile Surdu
sumber