Streudiagramme in Pandas / Pyplot: So zeichnen Sie nach Kategorien


88

Ich versuche, mit einem Pandas DataFrame-Objekt ein einfaches Streudiagramm in Pyplot zu erstellen, möchte aber eine effiziente Methode zum Zeichnen von zwei Variablen, wobei die Symbole durch eine dritte Spalte (Schlüssel) vorgegeben werden. Ich habe verschiedene Möglichkeiten mit df.groupby ausprobiert, aber nicht erfolgreich. Ein Beispiel für ein df-Skript finden Sie unten. Dies färbt die Markierungen gemäß 'key1', aber ich würde gerne eine Legende mit 'key1'-Kategorien sehen. Bin ich nah dran Vielen Dank.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
plt.show()

Antworten:


116

Sie können dies verwenden, scatterdies erfordert jedoch numerische Werte für Ihre key1und Sie haben keine Legende, wie Sie bemerkt haben.

Es ist besser, nur plotfür solche diskreten Kategorien zu verwenden. Beispielsweise:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
fig, ax = plt.subplots()
ax.margins(0.05) # Optional, just adds 5% padding to the autoscaling
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend()

plt.show()

Geben Sie hier die Bildbeschreibung ein

Wenn Sie möchten, dass die Dinge wie der Standardstil aussehen pandas, aktualisieren Sie sie einfach rcParamsmit dem Pandas-Stylesheet und verwenden Sie den Farbgenerator. (Ich optimiere auch die Legende leicht):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

groups = df.groupby('label')

# Plot
plt.rcParams.update(pd.tools.plotting.mpl_stylesheet)
colors = pd.tools.plotting._get_standard_colors(len(groups), color_type='random')

fig, ax = plt.subplots()
ax.set_color_cycle(colors)
ax.margins(0.05)
for name, group in groups:
    ax.plot(group.x, group.y, marker='o', linestyle='', ms=12, label=name)
ax.legend(numpoints=1, loc='upper left')

plt.show()

Geben Sie hier die Bildbeschreibung ein


Warum wird im obigen RGB-Beispiel das Symbol in der Legende zweimal angezeigt? Wie nur einmal zeigen?
Steve Schulist

1
@SteveSchulist - Verwenden Sie diese Option ax.legend(numpoints=1), um nur einen Marker anzuzeigen . Es gibt zwei, wie bei a Line2D, es gibt oft eine Linie, die die beiden Markierungen verbindet.
Joe Kington

Dieser Code funktionierte nur für mich, nachdem er plt.hold(True)nach dem ax.plot()Befehl hinzugefügt wurde . Irgendeine Idee warum?
Yuval Atzmon

set_color_cycle() wurde in matplotlib 1.5 veraltet. Das gibt es set_prop_cycle()jetzt.
Ale

51

Dies ist einfach mit Seaborn ( pip install seaborn) als Oneliner zu tun

sns.pairplot(x_vars=["one"], y_vars=["two"], data=df, hue="key1", size=5) ::

import seaborn as sns
import pandas as pd
import numpy as np
np.random.seed(1974)

df = pd.DataFrame(
    np.random.normal(10, 1, 30).reshape(10, 3),
    index=pd.date_range('2010-01-01', freq='M', periods=10),
    columns=('one', 'two', 'three'))
df['key1'] = (4, 4, 4, 6, 6, 6, 8, 8, 8, 8)

sns.pairplot(x_vars=["one"], y_vars=["two"], data=df, hue="key1", size=5)

Geben Sie hier die Bildbeschreibung ein

Hier ist der Datenrahmen als Referenz:

Geben Sie hier die Bildbeschreibung ein

Da Ihre Daten drei variable Spalten enthalten, möchten Sie möglicherweise alle paarweisen Dimensionen mit folgenden Elementen zeichnen:

sns.pairplot(vars=["one","two","three"], data=df, hue="key1", size=5)

Geben Sie hier die Bildbeschreibung ein

https://rasbt.github.io/mlxtend/user_guide/plotting/category_scatter/ ist eine weitere Option.


19

Mit plt.scatterkann ich mir nur eines vorstellen: einen Proxy-Künstler verwenden:

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
fig1 = plt.figure(1)
ax1 = fig1.add_subplot(111)
x=ax1.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)

ccm=x.get_cmap()
circles=[Line2D(range(1), range(1), color='w', marker='o', markersize=10, markerfacecolor=item) for item in ccm((array([4,6,8])-4.0)/4)]
leg = plt.legend(circles, ['4','6','8'], loc = "center left", bbox_to_anchor = (1, 0.5), numpoints = 1)

Und das Ergebnis ist:

Geben Sie hier die Bildbeschreibung ein


10

Sie können df.plot.scatter verwenden und ein Array an das Argument c = übergeben, das die Farbe jedes Punkts definiert:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), index = pd.date_range('2010-01-01', freq = 'M', periods = 10), columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)
colors = np.where(df["key1"]==4,'r','-')
colors[df["key1"]==6] = 'g'
colors[df["key1"]==8] = 'b'
print(colors)
df.plot.scatter(x="one",y="two",c=colors)
plt.show()

Geben Sie hier die Bildbeschreibung ein


4

Sie können auch Altair oder ggpot ausprobieren, die sich auf deklarative Visualisierungen konzentrieren.

import numpy as np
import pandas as pd
np.random.seed(1974)

# Generate Data
num = 20
x, y = np.random.random((2, num))
labels = np.random.choice(['a', 'b', 'c'], num)
df = pd.DataFrame(dict(x=x, y=y, label=labels))

Altair-Code

from altair import Chart
c = Chart(df)
c.mark_circle().encode(x='x', y='y', color='label')

Geben Sie hier die Bildbeschreibung ein

ggplot code

from ggplot import *
ggplot(aes(x='x', y='y', color='label'), data=df) +\
geom_point(size=50) +\
theme_bw()

Geben Sie hier die Bildbeschreibung ein


3

Ab matplotlib 3.1 können Sie verwenden .legend_elements(). Ein Beispiel finden Sie unter Automatische Legendenerstellung . Der Vorteil ist, dass ein einzelner Scatter-Aufruf verwendet werden kann.

In diesem Fall:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = (4,4,4,6,6,6,8,8,8,8)


fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = df['key1'], alpha = 0.8)
ax.legend(*sc.legend_elements())
plt.show()

Geben Sie hier die Bildbeschreibung ein

Falls die Schlüssel nicht direkt als Zahlen angegeben wurden, würde dies so aussehen

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.normal(10,1,30).reshape(10,3), 
                  index = pd.date_range('2010-01-01', freq = 'M', periods = 10), 
                  columns = ('one', 'two', 'three'))
df['key1'] = list("AAABBBCCCC")

labels, index = np.unique(df["key1"], return_inverse=True)

fig, ax = plt.subplots()
sc = ax.scatter(df['one'], df['two'], marker = 'o', c = index, alpha = 0.8)
ax.legend(sc.legend_elements()[0], labels)
plt.show()

Geben Sie hier die Bildbeschreibung ein


Ich habe eine Fehlermeldung erhalten, dass das Objekt 'PathCollection' kein Attribut 'legends_elements' hat. Mein Code lautet wie folgt. fig, ax = plt.subplots(1, 1, figsize = (4,4)) scat = ax.scatter(rand_jitter(important_dataframe["workout_type_int"], jitter = 0.04), important_dataframe["distance"], c = color_list, marker = 'o', alpha = 0.9) print(scat.legends_elements()) #ax.legend(*scat.legend_elements())
Nandish Patel

1
@NandishPatel Überprüfen Sie den ersten Satz dieser Antwort. Auch stellen Sie sicher , nicht zu verwirren legends_elementsund legend_elements.
ImportanceOfBeingErnest

Ja Dankeschön. Das war ein Tippfehler (Legenden / Legende). Ich habe seit den letzten 6 Stunden an etwas gearbeitet, sodass mir die Matplotlib-Version nicht in den Sinn kam. Ich dachte, ich würde die neueste verwenden. Ich war verwirrt, dass die Dokumentation besagt, dass es eine solche Methode gibt, aber der Code einen Fehler gab. Danke nochmal. Ich kann jetzt schlafen
Nandish Patel

Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.