Wie funktioniert der Parameter class_weight in scikit-learn?


115

Ich habe große Probleme zu verstehen, wie der class_weightParameter in der logistischen Regression von scikit-learn funktioniert.

Die Situation

Ich möchte die logistische Regression verwenden, um eine binäre Klassifizierung für einen sehr unausgeglichenen Datensatz durchzuführen. Die Klassen sind mit 0 (negativ) und 1 (positiv) gekennzeichnet, und die beobachteten Daten liegen in einem Verhältnis von etwa 19: 1 vor, wobei die Mehrzahl der Proben ein negatives Ergebnis aufweist.

Erster Versuch: Manuelles Vorbereiten von Trainingsdaten

Ich teilte die Daten, die ich hatte, in disjunkte Sätze zum Trainieren und Testen auf (ungefähr 80/20). Dann habe ich die Trainingsdaten zufällig von Hand abgetastet, um Trainingsdaten in anderen Proportionen als 19: 1 zu erhalten. von 2: 1 -> 16: 1.

Ich habe dann die logistische Regression für diese verschiedenen Trainingsdaten-Teilmengen trainiert und den Rückruf (= TP / (TP + FN)) als Funktion der verschiedenen Trainingsproportionen aufgezeichnet. Natürlich wurde der Rückruf an den disjunkten TEST-Proben berechnet, die die beobachteten Verhältnisse von 19: 1 hatten. Hinweis: Obwohl ich die verschiedenen Modelle mit unterschiedlichen Trainingsdaten trainiert habe, habe ich den Rückruf für alle Modelle mit denselben (disjunkten) Testdaten berechnet.

Die Ergebnisse waren wie erwartet: Der Rückruf lag bei 2: 1-Trainingsverhältnissen bei etwa 60% und fiel ziemlich schnell ab, als er 16: 1 erreichte. Es gab verschiedene Anteile von 2: 1 -> 6: 1, bei denen der Rückruf anständig über 5% lag.

Zweiter Versuch: Rastersuche

Als nächstes wollte ich verschiedene Regularisierungsparameter testen und habe daher GridSearchCV verwendet und ein Raster aus mehreren Werten des CParameters sowie des class_weightParameters erstellt. Um meine n: m-Anteile an negativen: positiven Trainingsbeispielen in die Wörterbuchsprache von zu übersetzen, class_weightdachte ich, dass ich nur mehrere Wörterbücher wie folgt spezifiziere:

{ 0:0.67, 1:0.33 } #expected 2:1
{ 0:0.75, 1:0.25 } #expected 3:1
{ 0:0.8, 1:0.2 }   #expected 4:1

und ich habe auch Noneund aufgenommen auto.

Diesmal waren die Ergebnisse total verrückt. Alle meine Rückrufe waren winzig (<0,05) für jeden Wert von class_weightaußer auto. Daher kann ich nur davon ausgehen, dass mein Verständnis für das Einstellen des class_weightWörterbuchs falsch ist. Interessanterweise lag der class_weightWert von 'auto' in der Rastersuche für alle Werte von bei 59% C, und ich vermutete, dass er sich auf 1: 1?

Meine Fragen

  1. Wie verwenden Sie es richtig class_weight, um ein anderes Gleichgewicht in den Trainingsdaten zu erreichen, als Sie es tatsächlich geben? Welches Wörterbuch übergebe ich speziell, class_weightum n: m-Anteile von negativen: positiven Trainingsmustern zu verwenden?

  2. Wenn Sie verschiedene class_weightWörterbücher an GridSearchCV übergeben, werden während der Kreuzvalidierung die Daten der Trainingsfalte gemäß dem Wörterbuch neu gewichtet, aber die tatsächlich angegebenen Stichprobenanteile für die Berechnung meiner Bewertungsfunktion für die Testfalte verwendet? Dies ist wichtig, da jede Metrik für mich nur dann nützlich ist, wenn sie aus Daten in den beobachteten Anteilen stammt.

  3. Was macht der autoWert von class_weightin Bezug auf Proportionen? Ich habe die Dokumentation gelesen und gehe davon aus, dass "die Daten umgekehrt proportional zu ihrer Häufigkeit ausgleichen" nur bedeutet, dass sie 1: 1 sind. Ist das richtig? Wenn nicht, kann jemand klarstellen?


Wenn man class_weight verwendet, wird die Verlustfunktion geändert. Beispielsweise wird anstelle der Kreuzentropie die Kreuzentropie gewichtet. Richtung Datascience.com/…
Prashanth

Antworten:


123

Zunächst einmal ist es möglicherweise nicht gut, nur durch Rückruf allein zu gehen. Sie können einfach einen Rückruf von 100% erreichen, indem Sie alles als positive Klasse klassifizieren. Normalerweise schlage ich vor, AUC zur Auswahl von Parametern zu verwenden und dann einen Schwellenwert für den Betriebspunkt (z. B. eine bestimmte Genauigkeitsstufe) zu finden, an dem Sie interessiert sind.

Wie es class_weightfunktioniert: Es bestraft Fehler in Stichproben von class[i]mit class_weight[i]statt 1. Ein höheres Klassengewicht bedeutet also, dass Sie mehr Gewicht auf eine Klasse legen möchten. class_weightNach Ihren Angaben ist Klasse 0 19-mal häufiger als Klasse 1. Sie sollten also die Klasse 1 relativ zur Klasse 0 erhöhen , z. B. {0: .1, 1: .9}. Wenn die class_weightSumme nicht 1 ergibt, ändert sich der Regularisierungsparameter grundsätzlich.

Wie das class_weight="auto"funktioniert, können Sie sich in dieser Diskussion ansehen . In der Dev-Version können Sie verwenden class_weight="balanced", was einfacher zu verstehen ist: Es bedeutet im Grunde, die kleinere Klasse zu replizieren, bis Sie so viele Samples wie in der größeren haben, aber auf implizite Weise.


1
Vielen Dank! Kurze Frage: Ich habe aus Gründen der Klarheit den Rückruf erwähnt und versuche tatsächlich zu entscheiden, welche AUC als Maß verwendet werden soll. Mein Verständnis ist, dass ich entweder die Fläche unter der ROC-Kurve oder die Fläche unter Rückruf vs. Präzisionskurve maximieren sollte, um Parameter zu finden. Nachdem ich die Parameter auf diese Weise ausgewählt habe, glaube ich, dass ich den Schwellenwert für die Klassifizierung durch Gleiten entlang der Kurve wähle. Hast du das gemeint? Wenn ja, welche der beiden Kurven ist am sinnvollsten zu betrachten, wenn mein Ziel darin besteht, so viele TPs wie möglich zu erfassen? Vielen Dank auch für Ihre Arbeit und Beiträge zum Scikit-Lernen !!!
Kilgoretrout

1
Ich denke, ROC wäre der Standardweg, aber ich denke nicht, dass es einen großen Unterschied geben wird. Sie benötigen jedoch ein Kriterium, um den Punkt auf der Kurve auszuwählen.
Andreas Mueller

3
@MiNdFrEaK Ich denke, Andrew meint, dass der Schätzer Stichproben in der Minderheitsklasse repliziert, sodass Stichproben verschiedener Klassen ausgeglichen sind. Es ist nur eine implizite Überabtastung.
Shawn TIAN

8
@MiNdFrEaK und Shawn Tian: SV-basierte Klassifizierer erzeugen nicht mehr Stichproben der kleineren Klassen, wenn Sie "ausgeglichen" verwenden. Es bestraft buchstäblich Fehler, die in den kleineren Klassen gemacht wurden. Anders zu sagen ist ein Fehler und irreführend, insbesondere bei großen Datenmengen, wenn Sie es sich nicht leisten können, mehr Beispiele zu erstellen. Diese Antwort muss bearbeitet werden.
Pablo Rivas

4
scikit-learn.org/dev/glossary.html#term-class-weight Klassengewichte werden je nach Algorithmus unterschiedlich verwendet: Bei linearen Modellen (wie linearem SVM oder logistischer Regression) ändern die Klassengewichte die Verlustfunktion um Gewichtung des Verlusts jeder Probe durch ihr Klassengewicht. Bei baumbasierten Algorithmen werden die Klassengewichte zur Neugewichtung des Aufteilungskriteriums verwendet. Beachten Sie jedoch, dass bei dieser Neuausrichtung nicht das Gewicht der Stichproben in jeder Klasse berücksichtigt wird.
Prashanth

2

Die erste Antwort ist gut, um zu verstehen, wie es funktioniert. Aber ich wollte verstehen, wie ich es in der Praxis anwenden sollte.

ZUSAMMENFASSUNG

  • Bei mäßig unausgeglichenen Daten OHNE Rauschen gibt es keinen großen Unterschied bei der Anwendung von Klassengewichten
  • Für mäßig unausgeglichene Daten MIT Rauschen und stark unausgeglichenen Daten ist es besser, Klassengewichte anzuwenden
  • param class_weight="balanced"funktioniert anständig, wenn Sie nicht manuell optimieren möchten
  • Wenn class_weight="balanced"Sie mehr wahre Ereignisse erfassen (höherer TRUE-Rückruf), erhalten Sie jedoch mit größerer Wahrscheinlichkeit falsche Warnungen (geringere TRUE-Genauigkeit).
    • Infolgedessen kann der Gesamtwert von% TRUE aufgrund aller falsch positiven Ergebnisse höher als der tatsächliche Wert sein
    • AUC kann Sie hier irreführen, wenn die Fehlalarme ein Problem darstellen
  • Es ist nicht erforderlich, die Entscheidungsschwelle auf das Ungleichgewicht in% zu ändern, auch bei starkem Ungleichgewicht. Es ist in Ordnung, 0,5 beizubehalten (oder irgendwo in der Nähe, je nachdem, was Sie benötigen).

NB

Das Ergebnis kann bei Verwendung von RF oder GBM abweichen. sklearn hat nicht class_weight="balanced" für GBM, aber lightgbm hatLGBMClassifier(is_unbalance=False)

CODE

# scikit-learn==0.21.3
from sklearn import datasets
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import pandas as pd

# case: moderate imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.8]) #,flip_y=0.1,class_sep=0.5)
np.mean(y) # 0.2

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.184
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.184 => same as first
LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X).mean() # 0.296 => seems to make things worse?
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.292 => seems to make things worse?

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.83
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:2,1:8}).fit(X,y).predict(X)) # 0.86 => about the same
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.86 => about the same

# case: strong imbalance
X, y = datasets.make_classification(n_samples=50*15, n_features=5, n_informative=2, n_redundant=0, random_state=1, weights=[0.95])
np.mean(y) # 0.06

LogisticRegression(C=1e9).fit(X,y).predict(X).mean() # 0.02
(LogisticRegression(C=1e9).fit(X,y).predict_proba(X)[:,1]>0.5).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:0.5,1:0.5}).fit(X,y).predict(X).mean() # 0.02 => same as first
LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X).mean() # 0.25 => huh??
LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X).mean() # 0.22 => huh??
(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).mean() # same as last

roc_auc_score(y,LogisticRegression(C=1e9).fit(X,y).predict(X)) # 0.64
roc_auc_score(y,LogisticRegression(C=1e9,class_weight={0:1,1:20}).fit(X,y).predict(X)) # 0.84 => much better
roc_auc_score(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)) # 0.85 => similar to manual
roc_auc_score(y,(LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict_proba(X)[:,1]>0.5).astype(int)) # same as last

print(classification_report(y,LogisticRegression(C=1e9).fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9).fit(X,y).predict(X),margins=True,normalize='index') # few prediced TRUE with only 28% TRUE recall and 86% TRUE precision so 6%*28%~=2%

print(classification_report(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X)))
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True)
pd.crosstab(y,LogisticRegression(C=1e9,class_weight="balanced").fit(X,y).predict(X),margins=True,normalize='index') # 88% TRUE recall but also lot of false positives with only 23% TRUE precision, making total predicted % TRUE > actual % TRUE
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.