Python-Bibliothek für segmentierte Regression (auch stückweise Regression genannt)


16


Diese Frage gibt eine Methode zum Ausführen einer stückweisen Regression, indem eine Funktion definiert und Standard-Python-Bibliotheken verwendet werden. stackoverflow.com/questions/29382903/…

Eine ähnliche Frage ( stackoverflow.com/questions/29382903/… ) und eine hilfreiche Bibliothek für stückweise Regression ( pypi.org/project/pwlf )
prashanth

Antworten:


7

numpy.piecewise kann dies tun.

stückweise (x, condlist, funclist, * args, ** kw)

Bewerten Sie eine stückweise definierte Funktion.

Bewerten Sie bei einer Reihe von Bedingungen und entsprechenden Funktionen jede Funktion an den Eingabedaten, wo immer ihre Bedingung wahr ist.

Ein Beispiel ist auf SO gegeben hier . Der Vollständigkeit halber hier ein Beispiel:

from scipy import optimize
import matplotlib.pyplot as plt
import numpy as np

x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10 ,11, 12, 13, 14, 15], dtype=float)
y = np.array([5, 7, 9, 11, 13, 15, 28.92, 42.81, 56.7, 70.59, 84.47, 98.36, 112.25, 126.14, 140.03])

def piecewise_linear(x, x0, y0, k1, k2):
    return np.piecewise(x, [x < x0, x >= x0], [lambda x:k1*x + y0-k1*x0, lambda x:k2*x + y0-k2*x0])

p , e = optimize.curve_fit(piecewise_linear, x, y)
xd = np.linspace(0, 15, 100)
plt.plot(x, y, "o")
plt.plot(xd, piecewise_linear(xd, *p))

4

Die von Vito MR Muggeo [1] vorgeschlagene Methode ist relativ einfach und effizient. Es funktioniert für eine bestimmte Anzahl von Segmenten und für eine kontinuierliche Funktion. Die Positionen der Haltepunkte werden iterativ geschätzt, indem für jede Iteration eine segmentierte lineare Regression durchgeführt wird, die Sprünge an den Haltepunkten ermöglicht. Aus den Werten der Sprünge werden die nächsten Haltepunktpositionen abgeleitet, bis es keine Diskontinuität mehr gibt (Sprünge).

"Der Prozess wird bis zu einer möglichen Konvergenz iteriert, die im Allgemeinen nicht garantiert ist."

Insbesondere kann die Konvergenz oder das Ergebnis von der ersten Schätzung der Haltepunkte abhängen.

Dies ist die Methode, die im R Segmented-Paket verwendet wird .

Hier ist eine Implementierung in Python:

import numpy as np
from numpy.linalg import lstsq

ramp = lambda u: np.maximum( u, 0 )
step = lambda u: ( u > 0 ).astype(float)

def SegmentedLinearReg( X, Y, breakpoints ):
    nIterationMax = 10

    breakpoints = np.sort( np.array(breakpoints) )

    dt = np.min( np.diff(X) )
    ones = np.ones_like(X)

    for i in range( nIterationMax ):
        # Linear regression:  solve A*p = Y
        Rk = [ramp( X - xk ) for xk in breakpoints ]
        Sk = [step( X - xk ) for xk in breakpoints ]
        A = np.array([ ones, X ] + Rk + Sk )
        p =  lstsq(A.transpose(), Y, rcond=None)[0] 

        # Parameters identification:
        a, b = p[0:2]
        ck = p[ 2:2+len(breakpoints) ]
        dk = p[ 2+len(breakpoints): ]

        # Estimation of the next break-points:
        newBreakpoints = breakpoints - dk/ck 

        # Stop condition
        if np.max(np.abs(newBreakpoints - breakpoints)) < dt/5:
            break

        breakpoints = newBreakpoints
    else:
        print( 'maximum iteration reached' )

    # Compute the final segmented fit:
    Xsolution = np.insert( np.append( breakpoints, max(X) ), 0, min(X) )
    ones =  np.ones_like(Xsolution) 
    Rk = [ c*ramp( Xsolution - x0 ) for x0, c in zip(breakpoints, ck) ]

    Ysolution = a*ones + b*Xsolution + np.sum( Rk, axis=0 )

    return Xsolution, Ysolution

Beispiel:

import matplotlib.pyplot as plt

X = np.linspace( 0, 10, 27 )
Y = 0.2*X  - 0.3* ramp(X-2) + 0.3*ramp(X-6) + 0.05*np.random.randn(len(X))
plt.plot( X, Y, 'ok' );

initialBreakpoints = [1, 7]
plt.plot( *SegmentedLinearReg( X, Y, initialBreakpoints ), '-r' );
plt.xlabel('X'); plt.ylabel('Y');

Graph

[1]: Muggeo, VM (2003). Schätzung von Regressionsmodellen mit unbekannten Haltepunkten. Statistics in Medicine, 22 (19), 3055-3071.


3

Ich habe das Gleiche gesucht, und leider scheint es derzeit keine zu geben. Einige Vorschläge zur weiteren Vorgehensweise finden Sie in dieser vorherigen Frage .

Alternativ können Sie sich einige R-Bibliotheken ansehen, z. B. segmented, SiZer, strucchange, und wenn dort etwas für Sie funktioniert, versuchen Sie, den R-Code mit rpy2 in Python einzubetten .

Bearbeiten, um einen Link zu Py-Earth hinzuzufügen : "Eine Python-Implementierung von Jerome Friedmans multivariaten adaptiven Regressionssplines".


2

Es gibt einen Blog-Beitrag mit einer rekursiven Implementierung der stückweisen Regression. Diese Lösung passt zur diskontinuierlichen Regression.

Wenn Sie mit dem diskontinuierlichen Modell nicht zufrieden sind und eine kontinuierliche Einstellung wünschen, würde ich vorschlagen, Ihre Kurve auf der Grundlage von kL-förmigen Kurven zu suchen und Lasso für die Sparsamkeit zu verwenden:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import Lasso
# generate data
np.random.seed(42)
x = np.sort(np.random.normal(size=100))
y_expected = 3 + 0.5 * x + 1.25 * x * (x>0)
y = y_expected + np.random.normal(size=x.size, scale=0.5)
# prepare a basis
k = 10
thresholds = np.percentile(x, np.linspace(0, 1, k+2)[1:-1]*100)
basis = np.hstack([x[:, np.newaxis],  np.maximum(0,  np.column_stack([x]*k)-thresholds)]) 
# fit a model
model = Lasso(0.03).fit(basis, y)
print(model.intercept_)
print(model.coef_.round(3))
plt.scatter(x, y)
plt.plot(x, y_expected, color = 'b')
plt.plot(x, model.predict(basis), color='k')
plt.legend(['true', 'predicted'])
plt.xlabel('x')
plt.ylabel('y')
plt.title('fitting segmented regression')
plt.show()

Dieser Code gibt einen Vektor mit geschätzten Koeffizienten an Sie zurück:

[ 0.57   0.     0.     0.     0.     0.825  0.     0.     0.     0.     0.   ]

Aufgrund des Lasso-Ansatzes ist es spärlich: Das Modell hat genau einen Haltepunkt unter 10 möglichen gefunden. Die Zahlen 0,57 und 0,825 entsprechen 0,5 und 1,25 im wahren DGP. Obwohl sie nicht sehr eng sind, sind die angepassten Kurven:

Bildbeschreibung hier eingeben

Mit diesem Ansatz können Sie den Haltepunkt nicht genau schätzen. Wenn Ihr Datensatz jedoch groß genug ist, können Sie mit anderen kDaten spielen (möglicherweise durch Kreuzvalidierung optimieren) und den Haltepunkt genau genug schätzen.

Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.