Wie wähle ich die erste Zeile jeder Gruppe aus?


143

Ich habe einen DataFrame wie folgt generiert:

df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc))

Die Ergebnisse sehen aus wie:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+

Wie Sie sehen können, wird der DataFrame Hourin aufsteigender Reihenfolge und dann TotalValuein absteigender Reihenfolge sortiert.

Ich möchte die oberste Reihe jeder Gruppe auswählen, dh

  • Wählen Sie aus der Gruppe der Stunden == 0 (0, Kat.26,30,9).
  • Wählen Sie aus der Gruppe der Stunden == 1 (1, Kat. 67,28.5).
  • Wählen Sie aus der Gruppe der Stunden == 2 (2, Kat. 56, 39,6).
  • und so weiter

Die gewünschte Ausgabe wäre also:

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+

Es kann nützlich sein, auch die oberen N Zeilen jeder Gruppe auswählen zu können.

Jede Hilfe wird sehr geschätzt.

Antworten:


231

Fensterfunktionen :

So etwas sollte den Trick machen:

import org.apache.spark.sql.functions.{row_number, max, broadcast}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Diese Methode ist im Falle eines signifikanten Datenversatzes ineffizient.

Einfache SQL-Aggregation, gefolgt vonjoin :

Alternativ können Sie sich mit einem aggregierten Datenrahmen verbinden:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Es werden doppelte Werte beibehalten (wenn es mehr als eine Kategorie pro Stunde mit demselben Gesamtwert gibt). Sie können diese wie folgt entfernen:

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))

Verwenden der Bestellung überstructs :

Ordentlicher, wenn auch nicht sehr gut getesteter Trick, der keine Verknüpfungen oder Fensterfunktionen erfordert:

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

Mit DataSet API (Spark 1.6+, 2.0+):

Spark 1.6 :

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 oder höher :

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)

Die letzten beiden Methoden können die kartenseitige Kombination nutzen und erfordern kein vollständiges Mischen, sodass die meiste Zeit eine bessere Leistung im Vergleich zu Fensterfunktionen und Verknüpfungen erzielt werden sollte. Diese können auch mit Structured Streaming in verwendet werdencompleted Ausgabemodus verwendet werden.

Verwenden Sie nicht :

df.orderBy(...).groupBy(...).agg(first(...), ...)

Es scheint zu funktionieren (insbesondere im localModus), ist aber unzuverlässig (siehe SPARK-16207 , Tzach Zohar für die Verknüpfung des relevanten JIRA-Problems und SPARK-30335) ).

Der gleiche Hinweis gilt für

df.orderBy(...).dropDuplicates(...)

die intern äquivalenten Ausführungsplan verwendet.


3
Es sieht so aus, als wäre es seit Spark 1.6 row_number () anstelle von rowNumber
Adam Szałucha

Über das Verwenden Sie nicht df.orderBy (...). GropBy (...). Unter welchen Umständen können wir uns auf orderBy (...) verlassen? oder wenn wir nicht sicher sein können, ob orderBy () das richtige Ergebnis liefert, welche Alternativen haben wir?
Ignacio Alorre

Ich könnte etwas übersehen, aber im Allgemeinen wird empfohlen, groupByKey zu vermeiden , stattdessen sollte reductByKey verwendet werden. Außerdem speichern Sie eine Zeile.
Thomas

3
@Thomas das Vermeiden von groupBy / groupByKey ist nur beim Umgang mit RDDs festzustellen, dass die Dataset-API nicht einmal über eine reduByKey-Funktion verfügt.
Ruß


16

Für Spark 2.0.2 mit Gruppierung nach mehreren Spalten:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

8

Dies ist genau das Gleiche wie die Antwort von zero323 , jedoch in SQL-Abfrage.

Angenommen, der Datenrahmen wird erstellt und registriert als

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Fensterfunktion:

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Einfache SQL-Aggregation, gefolgt von Join:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Verwenden der Bestellung über Strukturen:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

DataSets Weg und nicht tun sind die gleichen wie in der ursprünglichen Antwort


2

Das Muster ist nach Schlüsseln gruppiert => mache etwas mit jeder Gruppe, zB reduziere => kehre zum Datenrahmen zurück

Ich fand die Dataframe-Abstraktion in diesem Fall etwas umständlich, daher habe ich die RDD-Funktionalität verwendet

 val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)

1

Die folgende Lösung führt nur eine Gruppe durch und extrahiert die Zeilen Ihres Datenrahmens, die den maxValue enthalten, auf einmal. Keine weiteren Joins oder Windows erforderlich.

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.DataFrame

//df is the dataframe with Day, Category, TotalValue

implicit val dfEnc = RowEncoder(df.schema)

val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}}

Aber es mischt zuerst alles. Es ist kaum eine Verbesserung (möglicherweise nicht schlechter als Fensterfunktionen, abhängig von den Daten).
Alper t. Turker

Sie haben eine Gruppe an erster Stelle, die ein Shuffle auslöst. Es ist nicht schlechter als die Fensterfunktion, da in einer Fensterfunktion das Fenster für jede einzelne Zeile im Datenrahmen ausgewertet wird.
Elghoto

1

Eine gute Möglichkeit, dies mit der Dataframe-API zu tun, ist die Verwendung der Argmax-Logik

  val df = Seq(
    (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
    (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
    (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
    (3,"cat8",35.6)).toDF("Hour", "Category", "TotalValue")

  df.groupBy($"Hour")
    .agg(max(struct($"TotalValue", $"Category")).as("argmax"))
    .select($"Hour", $"argmax.*").show

 +----+----------+--------+
 |Hour|TotalValue|Category|
 +----+----------+--------+
 |   1|      28.5|   cat67|
 |   3|      35.6|    cat8|
 |   2|      39.6|   cat56|
 |   0|      30.9|   cat26|
 +----+----------+--------+

0

Hier können Sie so machen -

   val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour")

data.withColumnRenamed("_1","Hour").show

-2

Wir können die Funktion des Fensters rank () verwenden (wobei Sie den Rang = 1 wählen würden). Der Rang fügt nur eine Zahl für jede Zeile einer Gruppe hinzu (in diesem Fall wäre es die Stunde).

Hier ist ein Beispiel. (von https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank )

val dataset = spark.range(9).withColumn("bucket", 'id % 3)

import org.apache.spark.sql.expressions.Window
val byBucket = Window.partitionBy('bucket).orderBy('id)

scala> dataset.withColumn("rank", rank over byBucket).show
+---+------+----+
| id|bucket|rank|
+---+------+----+
|  0|     0|   1|
|  3|     0|   2|
|  6|     0|   3|
|  1|     1|   1|
|  4|     1|   2|
|  7|     1|   3|
|  2|     2|   1|
|  5|     2|   2|
|  8|     2|   3|
+---+------+----+
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.