In PySpark werden mehrere Datenrahmen zeilenweise zusammengeführt


21

Ich habe 10 Datenrahmen pyspark.sql.dataframe.DataFrame, erhalten aus randomSplitwie (td1, td2, td3, td4, td5, td6, td7, td8, td9, td10) = td.randomSplit([.1, .1, .1, .1, .1, .1, .1, .1, .1, .1], seed = 100)jetzt will ich 9 beitreten td‚s in einem einzigen Datenrahmen, wie soll ich das tun?

Ich habe es schon mit probiert unionAll, aber diese Funktion akzeptiert nur zwei Argumente.

td1_2 = td1.unionAll(td2) 
# this is working fine

td1_2_3 = td1.unionAll(td2, td3) 
# error TypeError: unionAll() takes exactly 2 arguments (3 given)

Gibt es eine Möglichkeit, mehr als zwei Datenrahmen zeilenweise zu kombinieren?

Der Zweck dabei ist, dass ich die 10-fache Kreuzvalidierung manuell ohne Verwendung der PySpark- CrossValidatorMethode durchführe. Daher nehme ich 9 in das Training und 1 in die Testdaten und wiederhole sie dann für andere Kombinationen.


1
Dies beantwortet die Frage nicht direkt, aber hier gebe ich einen Vorschlag zur Verbesserung der Benennungsmethode, damit wir am Ende nicht Folgendes eingeben müssen: [td1, td2, td3, td4, td5, td6, td7 td8, td9, td10]. Stellen Sie sich dies für einen 100-fachen Lebenslauf vor. Ich mache Folgendes: portions = [0.1] * 10 cv = df7.randomSplit (portions) folds = liste (range (10)) für i in range (10): test_data = cv [i] fold_no_i = folds [: i] + faltet [i + 1:] train_data = cv [fold_no_i [0]] für j in fold_no_i [1:]: train_data = train_data.union (cv [j])
ngoc bis zum

Antworten:


37

Gestohlen von: /programming/33743978/spark-union-of-multiple-rdds

Außerhalb der Verkettung von Gewerkschaften ist dies die einzige Möglichkeit, dies für DataFrames zu tun.

from functools import reduce  # For Python 3.x
from pyspark.sql import DataFrame

def unionAll(*dfs):
    return reduce(DataFrame.unionAll, dfs)

unionAll(td2, td3, td4, td5, td6, td7, td8, td9, td10)

Was passiert, ist, dass alle Objekte, die Sie als Parameter übergeben haben, mit unionAll reduziert werden (diese Reduzierung stammt aus Python, nicht aus der Spark-Reduzierung, obwohl sie ähnlich funktionieren), wodurch sie schließlich auf einen DataFrame reduziert werden.

Wenn es sich anstelle von DataFrames um normale RDDs handelt, können Sie eine Liste davon an die Union-Funktion Ihres SparkContext übergeben

BEARBEITEN: Für Ihren Zweck schlage ich eine andere Methode vor, da Sie diese gesamte Vereinigung 10 Mal für Ihre verschiedenen Falten für die Kreuzvalidierung wiederholen müssten, würde ich Beschriftungen hinzufügen, für welche Falte eine Zeile gehört, und einfach Ihren DataFrame für jede Falte basierend auf filtern das Etikett


(+1) Eine nette Abhilfe. Es muss jedoch eine Funktion geben, die die Verkettung mehrerer Datenrahmen ermöglicht. Wäre ganz praktisch!
Dawny33

Dem kann ich nicht widersprechen
Jan van der Vegt

@JanvanderVegt Danke, es funktioniert und die Idee, Etiketten hinzuzufügen, um den Trainings- und Testdatensatz herauszufiltern, habe ich bereits gemacht. Vielen Dank für Ihre Hilfe.
Krishna Prasad

@ Jan van der Vegt Können Sie bitte die gleiche Logik für Join anwenden und diese Frage beantworten
GeorgeOfTheRF


6

Wenn die zu kombinierenden Datenrahmen nicht dieselbe Spaltenreihenfolge haben, ist es besser, df2.select (df1.columns) auszuwählen, um sicherzustellen, dass beide df vor der Vereinigung dieselbe Spaltenreihenfolge haben.

import functools 

def unionAll(dfs):
    return functools.reduce(lambda df1,df2: df1.union(df2.select(df1.columns)), dfs) 

Beispiel:

df1 = spark.createDataFrame([[1,1],[2,2]],['a','b'])
# different column order. 
df2 = spark.createDataFrame([[3,333],[4,444]],['b','a']) 
df3 = spark.createDataFrame([555,5],[666,6]],['b','a']) 

unioned_df = unionAll([df1, df2, df3])
unioned_df.show() 

Bildbeschreibung hier eingeben

Andernfalls würde das folgende Ergebnis generiert.

from functools import reduce  # For Python 3.x
from pyspark.sql import DataFrame

def unionAll(*dfs):
    return reduce(DataFrame.unionAll, dfs) 

unionAll(*[df1, df2, df3]).show()

Bildbeschreibung hier eingeben


2

Wie wäre es mit Rekursion?

def union_all(dfs):
    if len(dfs) > 1:
        return dfs[0].unionAll(union_all(dfs[1:]))
    else:
        return dfs[0]

td = union_all([td1, td2, td3, td4, td5, td6, td7, td8, td9, td10])
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.