Haben die Vorhersagen eines Random Forest-Modells ein Vorhersageintervall?


52

Wenn ich ein randomForestModell ausführe , kann ich anhand des Modells Vorhersagen treffen. Gibt es eine Möglichkeit, ein Vorhersageintervall für jede der Vorhersagen zu erhalten, so dass ich weiß, wie "sicher" das Modell in seiner Antwort ist? Wenn dies möglich ist, basiert es einfach auf der Variabilität der abhängigen Variablen für das gesamte Modell oder hat es breitere und engere Intervalle, abhängig von dem bestimmten Entscheidungsbaum, der für eine bestimmte Vorhersage befolgt wurde?


3
AFAIK, alle RF-Bibliotheken haben irgendeine scoreFunktion zur Bewertung der Leistung. Da die Ausgabe auf der Mehrheit der Stimmen der Bäume im Wald basiert, erhalten Sie im Falle einer Klassifizierung eine Wahrscheinlichkeit, dass dieses Ergebnis wahr ist, basierend auf der Stimmenverteilung. Ich bin mir nicht sicher über die Regression. Welche Bibliothek benutzen Sie?
Sashkello

Antworten:


40

Dies ist zum Teil eine Antwort auf @Sashikanth Dareddy (da es nicht in einen Kommentar passt) und zum Teil eine Antwort auf den ursprünglichen Beitrag.

Denken Sie daran, was ein Vorhersageintervall ist. Es ist ein Intervall oder eine Reihe von Werten, in denen wir vorhersagen, dass zukünftige Beobachtungen liegen werden. Im Allgemeinen hat das Vorhersageintervall 2 Hauptteile, die seine Breite bestimmen, wobei ein Teil die Unsicherheit über den vorhergesagten Mittelwert (oder einen anderen Parameter) darstellt und ein Teil die Variabilität der einzelnen Beobachtungen um diesen Mittelwert darstellt. Das Konfidenzintervall ist aufgrund des zentralen Grenzwertsatzes ziemlich robust, und im Fall einer zufälligen Gesamtstruktur hilft auch das Bootstrapping. Das Vorhersageintervall hängt jedoch vollständig von den Annahmen über die Verteilung der Daten ab, da die Vorhersagevariablen CLT und Bootstrapping keinen Einfluss auf diesen Teil haben.

Das Vorhersageintervall sollte breiter sein, wobei das entsprechende Konfidenzintervall auch breiter wäre. Andere Dinge, die die Breite des Vorhersageintervalls beeinflussen würden, sind Annahmen über die gleiche Varianz oder nicht, dies muss aus dem Wissen des Forschers stammen, nicht aus dem Zufallsforstmodell.

Ein Vorhersageintervall ist für ein kategoriales Ergebnis nicht sinnvoll (Sie könnten ein Vorhersage-Set anstelle eines Intervalls erstellen, aber in den meisten Fällen wäre es wahrscheinlich nicht sehr informativ).

Wir können einige Probleme in Bezug auf Vorhersageintervalle erkennen, indem wir Daten simulieren, bei denen wir die genaue Wahrheit kennen. Betrachten Sie die folgenden Daten:

set.seed(1)

x1 <- rep(0:1, each=500)
x2 <- rep(0:1, each=250, length=1000)

y <- 10 + 5*x1 + 10*x2 - 3*x1*x2 + rnorm(1000)

Diese speziellen Daten folgen den Annahmen für eine lineare Regression und sind für eine zufällige Gesamtstrukturanpassung ziemlich einfach. Wir wissen aus dem "wahren" Modell, dass, wenn beide Prädiktoren 0 sind, der Mittelwert 10 ist, wir auch wissen, dass die einzelnen Punkte einer Normalverteilung mit einer Standardabweichung von 1 folgen. Dies bedeutet, dass das 95% Vorhersageintervall auf perfekter Kenntnis basiert Diese Punkte liegen zwischen 8 und 12 (also eigentlich zwischen 8,04 und 11,96, aber die Rundung macht es einfacher). Jedes geschätzte Vorhersageintervall sollte breiter als dieses sein (da keine perfekte Information vorhanden ist, wird die Breite zum Kompensieren hinzugefügt) und diesen Bereich einschließen.

Schauen wir uns die Intervalle von der Regression an:

fit1 <- lm(y ~ x1 * x2)

newdat <- expand.grid(x1=0:1, x2=0:1)

(pred.lm.ci <- predict(fit1, newdat, interval='confidence'))
#        fit       lwr      upr
# 1 10.02217  9.893664 10.15067
# 2 14.90927 14.780765 15.03778
# 3 20.02312 19.894613 20.15162
# 4 21.99885 21.870343 22.12735

(pred.lm.pi <- predict(fit1, newdat, interval='prediction'))
#        fit      lwr      upr
# 1 10.02217  7.98626 12.05808
# 2 14.90927 12.87336 16.94518
# 3 20.02312 17.98721 22.05903
# 4 21.99885 19.96294 24.03476

Wir können sehen, dass das geschätzte Mittel (Konfidenzintervall) eine gewisse Unsicherheit aufweist und dass wir ein Vorhersageintervall erhalten, das breiter ist (aber den Bereich von 8 bis 12 einschließt).

Schauen wir uns nun das Intervall an, das auf den individuellen Vorhersagen der einzelnen Bäume basiert (wir sollten davon ausgehen, dass diese breiter sind, da der Zufallswald nicht von den Annahmen profitiert (von denen wir wissen, dass sie für diese Daten zutreffen), die die lineare Regression macht):

library(randomForest)
fit2 <- randomForest(y ~ x1 + x2, ntree=1001)

pred.rf <- predict(fit2, newdat, predict.all=TRUE)

pred.rf.int <- apply(pred.rf$individual, 1, function(x) {
  c(mean(x) + c(-1, 1) * sd(x), 
  quantile(x, c(0.025, 0.975)))
})

t(pred.rf.int)
#                           2.5%    97.5%
# 1  9.785533 13.88629  9.920507 15.28662
# 2 13.017484 17.22297 12.330821 18.65796
# 3 16.764298 21.40525 14.749296 21.09071
# 4 19.494116 22.33632 18.245580 22.09904

Die Intervalle sind breiter als die Intervalle für die Regressionsvorhersage, decken jedoch nicht den gesamten Bereich ab. Sie enthalten die wahren Werte und sind daher möglicherweise als Konfidenzintervalle legitim. Sie sagen jedoch nur voraus, wo der Mittelwert (vorhergesagter Wert) liegt, nicht das hinzugefügte Stück für die Verteilung um diesen Mittelwert. Für den ersten Fall, in dem x1 und x2 beide 0 sind, unterschreiten die Intervalle nicht 9,7. Dies unterscheidet sich sehr von dem wahren Vorhersageintervall, das auf 8 abfällt. Wenn wir neue Datenpunkte generieren, gibt es mehrere Punkte (viel mehr) als 5%), die in den Intervallen true und regression liegen, jedoch nicht in die zufälligen Gesamtstrukturintervalle fallen.

Um ein Vorhersageintervall zu generieren, müssen Sie einige starke Annahmen über die Verteilung der einzelnen Punkte um die vorhergesagten Mittelwerte treffen. Anschließend können Sie die Vorhersagen aus den einzelnen Bäumen (das Bootstrap-Konfidenzintervall-Stück) ableiten und dann einen Zufallswert aus den angenommenen Werten generieren Verteilung mit diesem Zentrum. Die Quantile für diese generierten Stücke bilden möglicherweise das Vorhersageintervall (aber ich würde es trotzdem testen, möglicherweise müssen Sie den Vorgang mehrmals wiederholen und kombinieren).

Hier ist ein Beispiel dafür, wie Sie normale Abweichungen zu den Vorhersagen hinzufügen (da wir wissen, dass die ursprünglichen Daten normal verwendet wurden), wobei die Standardabweichung auf der geschätzten MSE von diesem Baum basiert:

pred.rf.int2 <- sapply(1:4, function(i) {
  tmp <- pred.rf$individual[i, ] + rnorm(1001, 0, sqrt(fit2$mse))
  quantile(tmp, c(0.025, 0.975))
})
t(pred.rf.int2)
#           2.5%    97.5%
# [1,]  7.351609 17.31065
# [2,] 10.386273 20.23700
# [3,] 13.004428 23.55154
# [4,] 16.344504 24.35970

Diese Intervalle enthalten diejenigen, die auf perfektem Wissen basieren. Sie hängen jedoch stark von den getroffenen Annahmen ab (die Annahmen sind hier gültig, da wir das Wissen darüber verwendet haben, wie die Daten simuliert wurden, und sie sind möglicherweise in realen Datenfällen nicht so gültig). Ich würde die Simulationen immer noch mehrmals für Daten wiederholen, die eher Ihren realen Daten ähneln (aber simuliert wurden, damit Sie die Wahrheit wissen), bevor ich dieser Methode voll vertraue.


11

Mir ist klar, dass dies ein alter Beitrag ist, aber ich habe einige Simulationen durchgeführt und dachte, ich werde meine Erkenntnisse teilen.

[μ+σ,μσ][μ+1.96σ,μ1.96σ]

Wenn Sie diesen Code in @GregSnow ändern, erhalten Sie die folgenden Ergebnisse

set.seed(1)
x1 <- rep( 0:1, each=500 )
x2 <- rep( 0:1, each=250, length=1000 )
y <- 10 + 5*x1 + 10*x2 - 3*x1*x2 + rnorm(1000)

library(randomForest)
fit2 <- randomForest(y~x1+x2)
pred.rf <- predict(fit2, newdat, predict.all=TRUE)
pred.rf.int <- t(apply( pred.rf$individual, 1, function(x){ 
  c( mean(x) + c(-1.96,1.96)*sd(x), quantile(x, c(0.025,0.975)) )}))

pred.rf.int
                          2.5%    97.5%
1  7.826896 16.05521  9.915482 15.31431
2 11.010662 19.35793 12.298995 18.64296
3 14.296697 23.61657 14.749248 21.11239
4 18.000229 23.73539 18.237448 22.10331

Wenn wir diese mit den Intervallen vergleichen, die durch Hinzufügen einer normalen Abweichung zu Vorhersagen mit Standardabweichung generiert wurden, wie es MSE wie @GregSnow vorschlug,

pred.rf.int2 <- sapply(1:4, function(i) {
   tmp <- pred.rf$individual[i,] + rnorm(1000, 0, sqrt(fit2$mse))
   quantile(tmp, c(0.025, 0.975))
   })
t(pred.rf.int2)
          2.5%    97.5%
[1,]  7.486895 17.21144
[2,] 10.551811 20.50633
[3,] 12.959318 23.46027
[4,] 16.444967 24.57601

Das Intervall dieser beiden Ansätze ist jetzt sehr eng. Das Diagramm des Vorhersageintervalls für die drei Ansätze gegen die Fehlerverteilung sieht in diesem Fall wie folgt aus

Bildbeschreibung hier eingeben

  • Schwarze Linien = Vorhersageintervalle aus linearer Regression,
  • Rote Linien = Zufällige Waldintervalle, berechnet anhand individueller Vorhersagen
  • Blaue Linien = Zufällige Gesamtstrukturintervalle, berechnet durch Hinzufügen einer normalen Abweichung zu den Vorhersagen

Lassen Sie uns nun die Simulation erneut ausführen, diesmal jedoch die Varianz des Fehlerterms erhöhen. Wenn unsere Vorhersageintervallberechnungen gut sind, sollten wir am Ende größere Intervalle haben als oben angegeben.

set.seed(1)
x1 <- rep( 0:1, each=500 )
x2 <- rep( 0:1, each=250, length=1000 )
y <- 10 + 5*x1 + 10*x2 - 3*x1*x2 + rnorm(1000,mean=0,sd=5)

fit1 <- lm(y~x1+x2)
newdat <- expand.grid(x1=0:1,x2=0:1)
predict(fit1,newdata=newdat,interval = "prediction")
      fit       lwr      upr
1 10.75006  0.503170 20.99695
2 13.90714  3.660248 24.15403
3 19.47638  9.229490 29.72327
4 22.63346 12.386568 32.88035

set.seed(1)
fit2 <- randomForest(y~x1+x2,localImp=T)
pred.rf.int <- t(apply( pred.rf$individual, 1, function(x){ 
  c( mean(x) + c(-1.96,1.96)*sd(x), quantile(x, c(0.025,0.975)) )}))
pred.rf.int
                          2.5%    97.5%
1  7.889934 15.53642  9.564565 15.47893
2 10.616744 18.78837 11.965325 18.51922
3 15.024598 23.67563 14.724964 21.43195
4 17.967246 23.88760 17.858866 22.54337

pred.rf.int2 <- sapply(1:4, function(i) {
   tmp <- pred.rf$individual[i,] + rnorm(1000, 0, sqrt(fit2$mse))
   quantile(tmp, c(0.025, 0.975))
   })
t(pred.rf.int2)
         2.5%    97.5%
[1,] 1.291450 22.89231
[2,] 4.193414 25.93963
[3,] 7.428309 30.07291
[4,] 9.938158 31.63777

Bildbeschreibung hier eingeben

Dies macht deutlich, dass die Berechnung der Vorhersageintervalle mit dem zweiten Ansatz weitaus genauer ist und Ergebnisse liefert, die dem Vorhersageintervall der linearen Regression sehr nahe kommen.

μiMSEiN(μi,RMSEi)N(μi/n,RMSEi/n)

mean.rf <- pred.rf$aggregate
sd.rf <- mean(sqrt(fit2$mse))
pred.rf.int3 <- cbind(mean.rf - 1.96*sd.rf, mean.rf + 1.96*sd.rf)
pred.rf.int3
1  1.332711 22.09364
2  4.322090 25.08302
3  8.969650 29.73058
4 10.546957 31.30789

Diese stimmen sehr gut mit den linearen Modellintervallen und auch dem Ansatz von @GregSnow überein. Beachten Sie jedoch, dass die zugrunde liegende Annahme bei allen von uns diskutierten Methoden ist, dass die Fehler einer Normalverteilung folgen.


10

Wenn Sie R verwenden, können Sie leicht Vorhersageintervalle für die Vorhersagen einer zufälligen Waldregression erstellen: Verwenden Sie einfach das Paket quantregForest(bei CRAN erhältlich ) und lesen Sie den Artikel von N. Meinshausen darüber, wie bedingte Quantile mit quantilen Regressionswäldern abgeleitet werden können und wie sie kann verwendet werden, um Vorhersageintervalle zu erstellen. Sehr informativ, auch wenn Sie nicht mit R arbeiten!


Es scheint, dass das Papier hierher verschoben wurde: jmlr.org/papers/volume7/meinshausen06a/meinshausen06a.pdf
Monica am

2
Dies scheint die richtige Antwort zu sein und erfordert keine Verteilungsannahmen in Bezug auf das Vorhersageintervall. Es gibt eine Anleitung, wie man das hier in Python tun: blog.datadive.net/prediction-intervals-for-random-forests
colin

6

Dies ist mit randomForest einfach zu lösen.

Lassen Sie mich zunächst die Regressionsaufgabe behandeln (vorausgesetzt, Ihr Wald hat 1000 Bäume). In der predictFunktion haben Sie die Möglichkeit, Ergebnisse von einzelnen Bäumen zurückzugeben. Dies bedeutet, dass Sie 1000 Spalten erhalten. Wir können den Durchschnitt der 1000 Spalten für jede Zeile nehmen - dies ist die reguläre Ausgabe, die RF auf irgendeine Weise erzeugt hätte. Um das Vorhersageintervall zu erhalten, sagen wir +/- 2 std. abweichungen müssen nur für jede zeile von den 1000 werten berechnet werden +/- 2 std. Abweichungen und machen Sie diese zu Ihrer oberen und unteren Grenze Ihrer Vorhersage.

Denken Sie zweitens bei der Klassifizierung daran, dass jeder Baum entweder 1 oder 0 (standardmäßig) ausgibt und die Summe aller 1000 durch 1000 dividierten Bäume die Klassenwahrscheinlichkeit ergibt (im Fall der binären Klassifizierung). Um ein Vorhersageintervall für die Wahrscheinlichkeit festzulegen, müssen Sie die min. Nodesize-Option (den genauen Namen dieser Option finden Sie in der randomForest-Dokumentation) Wenn Sie einen Wert von >> 1 festlegen, geben die einzelnen Bäume Zahlen zwischen 1 und 0 aus. Von nun an können Sie denselben Vorgang wie oben für beschrieben wiederholen die Regressionsaufgabe.

Ich hoffe das ergibt Sinn.


Ich habe es nicht ausprobiert, aber es scheint sinnvoll zu sein. Vielen Dank für die Beantwortung meiner alten Frage.
Dean MacGregor

1
Ich denke, diese Methode würde eher ein Konfidenzintervall als ein Vorhersageintervall ergeben. Die Ergebnisse sollten mit einem linearen Modell verglichen werden, bei dem die Theorie der Vorhersageintervalle gut etabliert ist. Am besten bei einigen simulierten Daten, bei denen die Wahrheit bekannt ist und alle Annahmen zutreffen.
Greg Snow

1
@ GregSnow: Was Sie von dem bekommen, was ich oben beschrieben habe, ist definitiv das Vorhersageintervall. Es ist zu beachten, dass Vorhersageintervalle im Allgemeinen viel breiter sind als die Konfidenzintervalle, da Konfidenzintervalle tatsächlich angeben, wo die mittlere Statistik einer Menge liegt, wenn sich die Vorhersage auf nur eine Beobachtung bezieht, was zu einer größeren Unsicherheit und damit zu größeren Intervallen führt. Die 1000 Vorhersagen, die Sie von 1000 Bäumen erhalten, können als Bootstrap-Beispiel betrachtet werden, und Sie müssen hier keine Normalitätsannahmen anwenden. Selbst eine einfache Dezilanalyse liefert gute Ergebnisse.

5
@SashikanthDareddy, Was Sie von dem bekommen, was Sie beschreiben, ist definitiv kein Vorhersageintervall. Ein Vorhersageintervall wird dadurch bestimmt, dass es nicht nur breiter ist. Ja, die einzelnen Bäume bilden einen Bootstrap, aber der Bootstrap schätzt Parameter, nicht einzelne Werte. Das Vorhersageintervall ist stark von der Verteilung der einzelnen Punkte abhängig. Die Tatsache, dass Ihre Methode anstelle der Kategorien ein Intervall für die Proportionen mit einem kategorialen Ergebnis angibt, zeigt dies. Siehe mein Beispiel in der hinzugefügten Antwort.
Greg Snow

0

Ich habe einige Optionen ausprobiert (dies alles WIP):

  1. Ich habe die abhängige Variable tatsächlich zu einem Klassifizierungsproblem mit den Ergebnissen als Bereiche anstelle eines einzelnen Werts gemacht. Die Ergebnisse waren schlecht, verglichen mit einem einfachen Wert. Ich habe diesen Ansatz aufgegeben.

  2. Ich habe es dann in mehrere Klassifizierungsprobleme konvertiert, von denen jedes eine Untergrenze für den Bereich darstellte (das Ergebnis des Modells war, ob es die Untergrenze überschreiten würde oder nicht), und dann alle Modelle ausgeführt (~ 20) und dann kombinieren Sie das Ergebnis, um eine endgültige Antwort als Bereich zu erhalten. Dies funktioniert besser als 1 oben, aber nicht so gut, wie ich es brauche. Ich arbeite immer noch daran, diesen Ansatz zu verbessern.

Ich habe OOB und Auslassungsschätzungen verwendet, um zu entscheiden, wie gut / schlecht meine Modelle sind.


0

Das Problem der Erstellung von Vorhersageintervallen für zufällige Waldvorhersagen wurde in der folgenden Abhandlung angesprochen:

Zhang, Haozhe, Joshua Zimmerman, Dan Nettleton und Daniel J. Nordman. "Random Forest Prediction Intervals." Der amerikanische Statistiker, 2019.

Das R-Paket "rfinterval" ist die bei CRAN verfügbare Implementierung.

Installation

So installieren Sie das R-Paket rfinterval :

#install.packages("devtools")
#devtools::install_github(repo="haozhestat/rfinterval")
install.packages("rfinterval")
library(rfinterval)
?rfinterval

Verwendungszweck

Schnellstart:

train_data <- sim_data(n = 1000, p = 10)
test_data <- sim_data(n = 1000, p = 10)

output <- rfinterval(y~., train_data = train_data, test_data = test_data,
                     method = c("oob", "split-conformal", "quantreg"),
                     symmetry = TRUE,alpha = 0.1)

### print the marginal coverage of OOB prediction interval
mean(output$oob_interval$lo < test_data$y & output$oob_interval$up > test_data$y)

### print the marginal coverage of Split-conformal prediction interval
mean(output$sc_interval$lo < test_data$y & output$sc_interval$up > test_data$y)

### print the marginal coverage of Quantile regression forest prediction interval
mean(output$quantreg_interval$lo < test_data$y & output$quantreg_interval$up > test_data$y)

Datenbeispiel:

oob_interval <- rfinterval(pm2.5 ~ .,
                            train_data = BeijingPM25[1:1000, ],
                            test_data = BeijingPM25[1001:2000, ],
                            method = "oob",
                            symmetry = TRUE,
                            alpha = 0.1)
str(oob_interval)

1
Willkommen auf der Seite, @ xiaolongmao. Vielleicht möchten Sie unsere Tour machen . Bitte posten Sie keine identischen Antworten auf mehrere Themen. Versuchen Sie, Ihre Antworten auf die jeweilige Frage in jedem Thread anzupassen. Wenn Sie einen Fall haben, in dem Sie wirklich glauben, dass eine identische Antwort die Frage vollständig beantwortet, bedeutet dies, dass die Frage ein Duplikat ist. Wenn Sie 50 Reputation erreichen, können Sie einen Kommentar an das OP senden. In der Zwischenzeit können Sie das Q zum Schließen als Duplikat kennzeichnen.
gung - Wiedereinsetzung von Monica
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.