Konvertieren Sie die Spark-DataFrame-Spalte in eine Python-Liste


102

Ich arbeite an einem Datenrahmen mit zwei Spalten, mvv und count.

+---+-----+
|mvv|count|
+---+-----+
| 1 |  5  |
| 2 |  9  |
| 3 |  3  |
| 4 |  1  |

Ich möchte zwei Listen mit MVV-Werten und Zählwert erhalten. Etwas wie

mvv = [1,2,3,4]
count = [5,9,3,1]

Also habe ich den folgenden Code ausprobiert: Die erste Zeile sollte eine Python-Zeilenliste zurückgeben. Ich wollte den ersten Wert sehen:

mvv_list = mvv_count_df.select('mvv').collect()
firstvalue = mvv_list[0].getInt(0)

Aber ich bekomme eine Fehlermeldung mit der zweiten Zeile:

AttributeError: getInt


Ab Spark 2.3 ist dieser Code der schnellste und am wenigsten wahrscheinliche, der OutOfMemory-Ausnahmen verursacht : list(df.select('mvv').toPandas()['mvv']). Arrow wurde in PySpark integriert, was sich toPandaserheblich beschleunigte . Verwenden Sie nicht die anderen Ansätze, wenn Sie Spark 2.3+ verwenden. Weitere Benchmarking-Details finden Sie in meiner Antwort.
Powers

Antworten:


140

Sehen Sie, warum diese Art und Weise, wie Sie es tun, nicht funktioniert. Zunächst versuchen Sie, eine Ganzzahl aus einem Zeilentyp abzurufen. Die Ausgabe Ihrer Sammlung sieht folgendermaßen aus:

>>> mvv_list = mvv_count_df.select('mvv').collect()
>>> mvv_list[0]
Out: Row(mvv=1)

Wenn Sie so etwas nehmen:

>>> firstvalue = mvv_list[0].mvv
Out: 1

Sie erhalten den mvvWert. Wenn Sie alle Informationen des Arrays möchten, können Sie Folgendes tun:

>>> mvv_array = [int(row.mvv) for row in mvv_list.collect()]
>>> mvv_array
Out: [1,2,3,4]

Wenn Sie jedoch dasselbe für die andere Spalte versuchen, erhalten Sie:

>>> mvv_count = [int(row.count) for row in mvv_list.collect()]
Out: TypeError: int() argument must be a string or a number, not 'builtin_function_or_method'

Dies geschieht, weil countes sich um eine integrierte Methode handelt. Und die Spalte hat den gleichen Namen wie count. Eine Problemumgehung hierfür ist das Ändern des Spaltennamens countin _count:

>>> mvv_list = mvv_list.selectExpr("mvv as mvv", "count as _count")
>>> mvv_count = [int(row._count) for row in mvv_list.collect()]

Diese Problemumgehung ist jedoch nicht erforderlich, da Sie über die Wörterbuchsyntax auf die Spalte zugreifen können:

>>> mvv_array = [int(row['mvv']) for row in mvv_list.collect()]
>>> mvv_count = [int(row['count']) for row in mvv_list.collect()]

Und es wird endlich funktionieren!


es funktioniert gut für die erste Spalte, aber es funktioniert nicht für die Spaltenanzahl, die ich wegen (der Funktionsanzahl des Funkens)
denke

Können Sie hinzufügen, was Sie mit der Zählung machen? Fügen Sie hier in den Kommentaren hinzu.
Thiago Baldim

Vielen Dank für Ihre Antwort. Diese Zeile funktioniert also mvv_list = [int (i.mvv) für i in mvv_count.select ('mvv'). collect ()], aber nicht diese eine count_list = [int (i.count) für i in mvv_count .select ('count'). collect ()] gibt eine ungültige Syntax zurück
a.moussa

Sie müssen diese select('count')Verwendung nicht wie count_list = [int(i.count) for i in mvv_list.collect()]folgt hinzufügen : Ich werde das Beispiel zur Antwort hinzufügen.
Thiago Baldim

1
@ a.moussa [i.['count'] for i in mvv_list.collect()]arbeitet, um es explizit zu machen, die Spalte 'count' und nicht die countFunktion zu verwenden
user989762

103

Wenn Sie einem Liner folgen, erhalten Sie die gewünschte Liste.

mvv = mvv_count_df.select("mvv").rdd.flatMap(lambda x: x).collect()

3
In Bezug auf die Leistung ist diese Lösung viel schneller als Ihre Lösung mvv_list = [int (i.mvv) für i in mvv_count.select ('mvv'). Collect ()]
Chanaka Fernando

Dies ist bei weitem die beste Lösung, die ich je gesehen habe. Vielen Dank.
Hui Chen

22

Dadurch erhalten Sie alle Elemente als Liste.

mvv_list = list(
    mvv_count_df.select('mvv').toPandas()['mvv']
)

1
Dies ist die schnellste und effizienteste Lösung für Spark 2.3+. Siehe die Benchmarking-Ergebnisse in meiner Antwort.
Powers

15

Der folgende Code hilft Ihnen dabei

mvv_count_df.select('mvv').rdd.map(lambda row : row[0]).collect()

3
Dies sollte die akzeptierte Antwort sein. Der Grund dafür ist, dass Sie sich während des gesamten Prozesses in einem Funkenkontext befinden und dann am Ende sammeln, anstatt früher aus dem Funkenkontext auszusteigen, was je nach Ihrer Tätigkeit zu einer größeren Sammlung führen kann.
AntiPawn79

15

Für meine Daten habe ich folgende Benchmarks erhalten:

>>> data.select(col).rdd.flatMap(lambda x: x).collect()

0,52 Sek

>>> [row[col] for row in data.collect()]

0,271 Sek

>>> list(data.select(col).toPandas()[col])

0,427 Sek

Das Ergebnis ist das gleiche


1
Wenn Sie toLocalIteratorstattdessen verwenden collect, sollte es sogar speichereffizienter sein[row[col] for row in data.toLocalIterator()]
oglop

5

Wenn Sie den folgenden Fehler erhalten:

AttributeError: Das Objekt 'list' hat kein Attribut 'collect'.

Dieser Code löst Ihre Probleme:

mvv_list = mvv_count_df.select('mvv').collect()

mvv_array = [int(i.mvv) for i in mvv_list]

Ich habe auch diesen Fehler bekommen und diese Lösung hat das Problem gelöst. Aber warum habe ich den Fehler bekommen? (Viele andere scheinen das nicht zu
verstehen

1

Ich habe eine Benchmarking-Analyse durchgeführt und bin list(mvv_count_df.select('mvv').toPandas()['mvv'])die schnellste Methode. Ich bin sehr überrascht.

Ich habe die verschiedenen Ansätze für 100.000 / 100 Millionen Zeilendatensätze mit einem i3.xlarge-Cluster mit 5 Knoten (jeder Knoten verfügt über 30,5 GB RAM und 4 Kerne) mit Spark 2.4.5 ausgeführt. Die Daten wurden gleichmäßig auf 20 bissig komprimierte Parkettdateien mit einer einzigen Spalte verteilt.

Hier sind die Benchmarking-Ergebnisse (Laufzeiten in Sekunden):

+-------------------------------------------------------------+---------+-------------+
|                          Code                               | 100,000 | 100,000,000 |
+-------------------------------------------------------------+---------+-------------+
| df.select("col_name").rdd.flatMap(lambda x: x).collect()    |     0.4 | 55.3        |
| list(df.select('col_name').toPandas()['col_name'])          |     0.4 | 17.5        |
| df.select('col_name').rdd.map(lambda row : row[0]).collect()|     0.9 | 69          |
| [row[0] for row in df.select('col_name').collect()]         |     1.0 | OOM         |
| [r[0] for r in mid_df.select('col_name').toLocalIterator()] |     1.2 | *           |
+-------------------------------------------------------------+---------+-------------+

* cancelled after 800 seconds

Goldene Regeln beim Sammeln von Daten auf dem Treiberknoten:

  • Versuchen Sie, das Problem mit anderen Ansätzen zu lösen. Das Sammeln von Daten auf dem Treiberknoten ist teuer, nutzt die Leistung des Spark-Clusters nicht und sollte nach Möglichkeit vermieden werden.
  • Sammle so wenige Zeilen wie möglich. Aggregieren, deduplizieren, filtern und bereinigen Sie Spalten, bevor Sie die Daten erfassen. Senden Sie so wenig Daten wie möglich an den Treiberknoten.

toPandas wurde in Spark 2.3 signifikant verbessert . Dies ist wahrscheinlich nicht der beste Ansatz, wenn Sie eine Spark-Version vor 2.3 verwenden.

Sehen Sie hier für weitere Informationen / Benchmarking Ergebnisse.


0

Eine mögliche Lösung ist die Verwendung der collect_list()Funktion von pyspark.sql.functions. Dadurch werden alle Spaltenwerte zu einem Pyspark-Array zusammengefasst, das beim Sammeln in eine Python-Liste konvertiert wird:

mvv_list   = df.select(collect_list("mvv")).collect()[0][0]
count_list = df.select(collect_list("count")).collect()[0][0] 
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.