Was ist die Verwendung von torch.no_grad in Pytorch?


18

Ich bin neu in Pytorch und habe mit diesem Github-Code begonnen. Ich verstehe den Kommentar in Zeile 60-61 im Code nicht "because weights have requires_grad=True, but we don't need to track this in autograd". Ich habe verstanden, dass wir requires_grad=Truedie Variablen erwähnen , die wir zur Berechnung der Gradienten für die Verwendung von Autograd benötigen, aber was bedeutet das "tracked by autograd"?

Antworten:


21

Der Wrapper "with torch.no_grad ()" setzt vorübergehend alle require_grad-Flags auf false. Ein Beispiel aus dem offiziellen PyTorch-Tutorial ( https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#gradients) ):

x = torch.randn(3, requires_grad=True)
print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)

Aus:

True
True
False

Ich empfehle Ihnen, alle Tutorials von der obigen Website zu lesen.

In Ihrem Beispiel: Ich denke, der Autor möchte nicht, dass PyTorch die Gradienten der neu definierten Variablen w1 und w2 berechnet, da er nur deren Werte aktualisieren möchte.


5
with torch.no_grad()

Damit haben alle Operationen im Block keine Farbverläufe.

In Pytorch können Sie nicht die Inplacement-Änderung von w1 und w2 durchführen, bei denen es sich um zwei Variablen handelt require_grad = True . Ich denke, dass das Vermeiden der Inplacement-Änderung von w1 und w2 darauf zurückzuführen ist, dass es Fehler bei der Berechnung der Rückwärtsausbreitung verursacht. Da Inplacement-Änderung wird w1 und w2 total ändern.

Wenn Sie dies verwenden no_grad(), können Sie jedoch steuern, dass das neue w1 und das neue w2 keine Verläufe aufweisen, da sie durch Operationen generiert werden. Dies bedeutet, dass Sie nur den Wert von w1 und w2 ändern und nicht den Verlaufsteil. Sie haben weiterhin zuvor definierte variable Verlaufsinformationen und die Rückausbreitung kann fortgesetzt werden.

Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.