Der beste Weg, um ein trainiertes Modell in PyTorch zu speichern?


187

Ich suchte nach alternativen Möglichkeiten, um ein trainiertes Modell in PyTorch zu speichern. Bisher habe ich zwei Alternativen gefunden.

  1. torch.save () zum Speichern eines Modells und torch.load () zum Laden eines Modells.
  2. model.state_dict () zum Speichern eines trainierten Modells und model.load_state_dict () zum Laden des gespeicherten Modells.

Ich bin auf diese Diskussion gestoßen, bei der Ansatz 2 gegenüber Ansatz 1 empfohlen wird.

Meine Frage ist, warum der zweite Ansatz bevorzugt wird. Liegt es nur daran, dass torch.nn- Module diese beiden Funktionen haben und wir ermutigt werden, sie zu verwenden?


2
Ich denke, das liegt daran, dass torch.save () auch alle Zwischenvariablen speichert, wie Zwischenausgaben für die Verwendung der Rückübertragung. Sie müssen jedoch nur die Modellparameter wie Gewicht / Abweichung usw. speichern. Manchmal kann die erstere viel größer sein als die letztere.
Dawei Yang

2
Ich habe getestet torch.save(model, f)und torch.save(model.state_dict(), f). Die gespeicherten Dateien haben die gleiche Größe. Jetzt bin ich verwirrt. Außerdem fand ich die Verwendung von pickle zum Speichern von model.state_dict () extrem langsam. Ich denke, der beste Weg ist die Verwendung, torch.save(model.state_dict(), f)da Sie die Erstellung des Modells übernehmen und die Taschenlampe das Laden der Modellgewichte übernimmt, wodurch mögliche Probleme beseitigt werden. Referenz: diskutieren.pytorch.org/t/saving-torch-models/838/4
Dawei Yang

Scheint, als hätte PyTorch dies in seinem Tutorial-Abschnitt etwas expliziter angesprochen - dort gibt es viele gute Informationen, die in den Antworten hier nicht aufgeführt sind, einschließlich des Speicherns von mehr als einem Modell gleichzeitig und Warmstartmodellen.
whlteXbread

Antworten:


207

Ich habe diese Seite auf ihrem Github-Repo gefunden. Ich füge den Inhalt einfach hier ein.


Empfohlener Ansatz zum Speichern eines Modells

Es gibt zwei Hauptansätze zum Serialisieren und Wiederherstellen eines Modells.

Der erste (empfohlene) speichert und lädt nur die Modellparameter:

torch.save(the_model.state_dict(), PATH)

Dann später:

the_model = TheModelClass(*args, **kwargs)
the_model.load_state_dict(torch.load(PATH))

Der zweite speichert und lädt das gesamte Modell:

torch.save(the_model, PATH)

Dann später:

the_model = torch.load(PATH)

In diesem Fall sind die serialisierten Daten jedoch an die spezifischen Klassen und die genaue verwendete Verzeichnisstruktur gebunden, sodass sie bei Verwendung in anderen Projekten oder nach einigen schwerwiegenden Refaktoren auf verschiedene Weise beschädigt werden können.


6
Laut @smth wird das Modell standardmäßig neu geladen, um das Modell zu trainieren. Sie müssen also the_model.eval () nach dem Laden manuell aufrufen, wenn Sie es zur Schlussfolgerung laden und das Training nicht fortsetzen.
WillZ

Die zweite Methode gibt stackoverflow.com/questions/53798009/… Fehler unter Windows 10. konnte es nicht lösen
Gulzar

Gibt es eine Option zum Speichern, ohne dass ein Zugriff für die Modellklasse erforderlich ist?
Michael D

Wie können Sie mit diesem Ansatz die * Argumente und ** Warge verfolgen, die Sie für den Lastfall übergeben müssen?
Mariano Kamp

141

Es hängt davon ab, was Sie tun möchten.

Fall 1: Speichern Sie das Modell, um es selbst als Rückschluss zu verwenden : Sie speichern das Modell, stellen es wieder her und ändern das Modell in den Bewertungsmodus. Dies geschieht , weil Sie in der Regel BatchNormund DropoutSchichten , die in Zug - Modus auf dem Bau von Standard sind:

torch.save(model.state_dict(), filepath)

#Later to restore:
model.load_state_dict(torch.load(filepath))
model.eval()

Fall 2: Modell speichern, um das Training später fortzusetzen : Wenn Sie das Modell, das Sie speichern möchten, weiter trainieren müssen, müssen Sie mehr als nur das Modell speichern. Sie müssen auch den Status des Optimierers, der Epochen, der Partitur usw. speichern. Sie würden dies folgendermaßen tun:

state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    ...
}
torch.save(state, filepath)

Um das Training fortzusetzen, würden Sie Dinge tun wie: state = torch.load(filepath) und dann, um den Zustand jedes einzelnen Objekts wiederherzustellen, so etwas wie:

model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])

Da Sie das Training wieder aufnehmen, Rufen model.eval()Sie NICHT an , wenn Sie den Status beim Laden wiederherstellen, da Sie das wieder aufnehmen.

Fall 3: Modell, das von einer anderen Person ohne Zugriff auf Ihren Code verwendet werden soll : In Tensorflow können Sie eine .pbDatei erstellen , die sowohl die Architektur als auch die Gewichte des Modells definiert. Dies ist besonders bei der Verwendung sehr praktischTensorflow serve . Der äquivalente Weg, dies in Pytorch zu tun, wäre:

torch.save(model, filepath)

# Then later:
model = torch.load(filepath)

Dieser Weg ist immer noch nicht kugelsicher und da Pytorch immer noch viele Änderungen durchläuft, würde ich es nicht empfehlen.


1
Gibt es eine empfohlene Dateiende für die 3 Fälle? Oder ist es immer .pth?
Verena Haunschmid

1
In Fall 3 wird torch.loadnur ein OrderedDict zurückgegeben. Wie erhalten Sie das Modell, um Vorhersagen zu treffen?
Alber8295

Hallo, darf ich wissen, wie man den erwähnten "Fall 2: Modell speichern, um das Training später fortzusetzen" ausführt? Ich habe es geschafft, den Prüfpunkt in das Modell zu laden, dann konnte ich das Modell wie "model.to (device) model = train_model_epoch (Modell, Kriterium, Optimierer, Zeitplan, Epochen)" nicht ausführen oder nicht mehr trainieren
dnez

1
Hallo, für den Fall, dass es sich um einen Rückschluss handelt, heißt es im offiziellen Pytorch-Dokument, dass der Optimierer state_dict entweder für den Rückschluss oder für den Abschluss des Trainings gespeichert werden muss. "Wenn Sie einen allgemeinen Prüfpunkt speichern, um ihn entweder für den Rückschluss oder die Wiederaufnahme des Trainings zu verwenden, müssen Sie mehr als nur den state_dict des Modells speichern. Es ist wichtig, auch den state_dict des Optimierers zu speichern, da dieser Puffer und Parameter enthält, die beim Modellzug aktualisiert werden . "
Mohammed Awney

In Fall 3 sollte die Modellklasse irgendwo definiert werden.
Michael D

10

Die pickle Python-Bibliothek implementiert binäre Protokolle zum Serialisieren und De-Serialisieren eines Python-Objekts.

Wenn Sie import torch(oder wenn Sie PyTorch verwenden), wird es import picklefür Sie und Sie müssen nicht anrufen pickle.dump()undpickle.load() direkt, was die Methoden zu speichern und das Objekt zu laden.

In der Tat torch.save()und torch.load()wird wickeln pickle.dump()und pickle.load()für Sie.

EIN state_dict andere Antwort verdient nur noch ein paar Anmerkungen.

Was state_dicthaben wir in PyTorch? Es gibt tatsächlich zweistate_dict s.

Das PyTorch-Modell torch.nn.Modulemuss model.parameters()aufgerufen werden, um lernbare Parameter (w und b) zu erhalten. Diese lernbaren Parameter werden nach dem Zufallsprinzip im Laufe der Zeit aktualisiert, sobald wir lernen. Lernbare Parameter sind die erstenstate_dict .

Das zweite state_dictist das Optimierungsstatus-Diktat. Sie erinnern sich, dass der Optimierer verwendet wird, um unsere lernbaren Parameter zu verbessern. Aber der Optimiererstate_dict ist behoben. Dort gibt es nichts zu lernen.

Da es sich bei state_dictObjekten um Python-Wörterbücher handelt, können sie einfach gespeichert, aktualisiert, geändert und wiederhergestellt werden, wodurch PyTorch-Modelle und -Optimierer erheblich modularisiert werden.

Lassen Sie uns ein super einfaches Modell erstellen, um dies zu erklären:

import torch
import torch.optim as optim

model = torch.nn.Linear(5, 2)

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print("Model weight:")    
print(model.weight)

print("Model bias:")    
print(model.bias)

print("---")
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Dieser Code gibt Folgendes aus:

Model's state_dict:
weight   torch.Size([2, 5])
bias     torch.Size([2])
Model weight:
Parameter containing:
tensor([[ 0.1328,  0.1360,  0.1553, -0.1838, -0.0316],
        [ 0.0479,  0.1760,  0.1712,  0.2244,  0.1408]], requires_grad=True)
Model bias:
Parameter containing:
tensor([ 0.4112, -0.0733], requires_grad=True)
---
Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [140695321443856, 140695321443928]}]

Beachten Sie, dass dies ein Minimalmodell ist. Sie können versuchen, einen Stapel von sequentiellen hinzuzufügen

model = torch.nn.Sequential(
          torch.nn.Linear(D_in, H),
          torch.nn.Conv2d(A, B, C)
          torch.nn.Linear(H, D_out),
        )

Beachten Sie, dass nur Schichten mit lernbaren Parametern (Faltungsschichten, lineare Schichten usw.) und registrierten Puffern (Batchnorm-Schichten) Einträge in den Modellen haben state_dict.

Nicht lernbare Dinge gehören zum Optimierungsobjekt state_dict , das Informationen über den Status des Optimierers sowie die verwendeten Hyperparameter enthält.

Der Rest der Geschichte ist der gleiche; in der Inferenzphase (dies ist eine Phase, in der wir das Modell nach dem Training verwenden) zur Vorhersage; Wir sagen voraus, basierend auf den Parametern, die wir gelernt haben. Für die Schlussfolgerung müssen wir nur die Parameter speichern model.state_dict().

torch.save(model.state_dict(), filepath)

Und um später model.load_state_dict (torch.load (Dateipfad)) model.eval () zu verwenden

Hinweis: Vergessen Sie nicht die letzte Zeile, die model.eval()nach dem Laden des Modells von entscheidender Bedeutung ist.

Versuchen Sie auch nicht zu speichern torch.save(model.parameters(), filepath). Das model.parameters()ist nur das Generatorobjekt.

Auf der anderen Seite wird torch.save(model, filepath)das Modellobjekt selbst gespeichert. Beachten Sie jedoch, dass das Modell nicht über das Optimierungsobjekt verfügt state_dict. Überprüfen Sie die andere ausgezeichnete Antwort von @Jadiel de Armas, um das Statusdiktat des Optimierers zu speichern.


Obwohl es keine einfache Lösung ist, wird die Essenz des Problems gründlich analysiert! Upvote.
Jason Young

6

Eine übliche PyTorch-Konvention besteht darin, Modelle mit der Dateierweiterung .pt oder .pth zu speichern.

Ganzes Modell speichern / laden Speichern:

path = "username/directory/lstmmodelgpu.pth"
torch.save(trainer, path)

Belastung:

Die Modellklasse muss irgendwo definiert werden

model = torch.load(PATH)
model.eval()

1

Wenn Sie das Modell speichern und das Training später fortsetzen möchten:

Einzelne GPU: Speichern:

state = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Belastung:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

Mehrere GPUs: Speichern

state = {
        'epoch': epoch,
        'state_dict': model.module.state_dict(),
        'optimizer': optimizer.state_dict(),
}
savepath='checkpoint.t7'
torch.save(state,savepath)

Belastung:

checkpoint = torch.load('checkpoint.t7')
model.load_state_dict(checkpoint['state_dict'])
optimizer.load_state_dict(checkpoint['optimizer'])
epoch = checkpoint['epoch']

#Don't call DataParallel before loading the model otherwise you will get an error

model = nn.DataParallel(model) #ignore the line if you want to load on Single GPU
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.