Pytorch: Richtige Verwendung benutzerdefinierter Gewichtskarten in unet-Architekturen


8

Es gibt einen berühmten Trick in der U-Net-Architektur, benutzerdefinierte Gewichtskarten zu verwenden, um die Genauigkeit zu erhöhen.

Geben Sie hier die Bildbeschreibung ein

Wenn ich hier und an mehreren anderen Stellen frage, lerne ich zwei Ansätze kennen. Ich möchte wissen, welcher richtig ist, oder gibt es einen anderen richtigen Ansatz, der korrekter ist?

1) Zuerst ist die torch.nn.FunctionalMethode in der Trainingsschleife zu verwenden.

loss = torch.nn.functional.cross_entropy(output, target, w) Dabei ist w das berechnete benutzerdefinierte Gewicht.

2) Zweitens ist reduction='none'beim Aufrufen der Verlustfunktion außerhalb der Trainingsschleife zu verwenden criterion = torch.nn.CrossEntropy(reduction='none')

und dann in der Trainingsschleife mit dem benutzerdefinierten Gewicht multiplizieren-

gt # Ground truth, format torch.long
pd # Network output
W # per-element weighting based on the distance map from UNet
loss = criterion(pd, gt)
loss = W*loss # Ensure that weights are scaled appropriately
loss = torch.sum(loss.flatten(start_dim=1), axis=0) # Sums the loss per image
loss = torch.mean(loss) # Average across a batch

Jetzt bin ich ein bisschen verwirrt, welches richtig ist oder gibt es einen anderen Weg oder beide sind richtig?

Antworten:


3

Der Gewichtungsabschnitt sieht aus wie eine einfach gewichtete Kreuzentropie, die für die Anzahl der Klassen (2 im folgenden Beispiel) so ausgeführt wird.

weights = torch.FloatTensor([.3, .7])
loss_func = nn.CrossEntropyLoss(weight=weights)

BEARBEITEN:

Haben Sie diese Implementierung von Patrick Black gesehen?

# Set properties
batch_size = 10
out_channels = 2
W = 10
H = 10

# Initialize logits etc. with random
logits = torch.FloatTensor(batch_size, out_channels, H, W).normal_()
target = torch.LongTensor(batch_size, H, W).random_(0, out_channels)
weights = torch.FloatTensor(batch_size, 1, H, W).random_(1, 3)

# Calculate log probabilities
logp = F.log_softmax(logits)

# Gather log probabilities with respect to target
logp = logp.gather(1, target.view(batch_size, 1, H, W))

# Multiply with weights
weighted_logp = (logp * weights).view(batch_size, -1)

# Rescale so that loss is in approx. same interval
weighted_loss = weighted_logp.sum(1) / weights.view(batch_size, -1).sum(1)

# Average over mini-batch
weighted_loss = -1. * weighted_loss.mean()

Die Sache ist Gewicht wird durch eine bestimmte Funktion hier berechnet und ist nicht diskret. Für weitere Informationen, hier ist ein Papier - arxiv.org/abs/1505.04597
Mark

1
@ Mark oh ich sehe jetzt. Es handelt sich also um eine pixelweise Verlustausgabe. Und die Ränder werden unter Verwendung einer Bibliothek wie opencvoder so vorberechnet , und dann werden diese Pixelpositionen für jedes Bild gespeichert und später während des Trainings mit den Verlusttensoren multipliziert, so dass sich der Algorithmus auf die Reduzierung des Verlusts in diesen Bereichen konzentriert.
Jchaykow

Vielen Dank. Diese Legitimität sieht aus wie eine Antwort. Ich werde versuchen, sie genauer zu überprüfen und umzusetzen, und werde Ihre Antwort danach akzeptieren.
Mark

Können Sie die Intuition hinter dieser Zeile erklärenlogp = logp.gather(1, target.view(batch_size, 1, H, W))
Mark

0

Beachten Sie, dass torch.nn.CrossEntropyLoss () eine Klasse ist, die torch.nn.functional aufruft. Siehe https://pytorch.org/docs/stable/_modules/torch/nn/modules/loss.html#CrossEntropyLoss

Sie können die Gewichte verwenden, wenn Sie die Kriterien definieren. Wenn man sie funktional vergleicht, sind beide Methoden gleich.

Nun, ich verstehe Ihre Idee, Verluste innerhalb der Trainingsschleife in Methode 1 und außerhalb der Trainingsschleife in Methode 2 zu berechnen, nicht. Wenn Sie Verluste außerhalb der Schleife berechnen, wie werden Sie dann rückpropagieren?


Ich war nicht verwirrt zwischen torch.nn.CrossEntropyLoss() und torch.nn.functional.cross_entropy(output, target, w), ich war verwirrt, wie man benutzerdefinierte Gewichtskarten im Verlust verwendet. Bitte lesen Sie dieses Papier - arxiv.org/abs/1505.04597 und lassen Sie mich wissen, wenn Sie immer noch nicht in der Lage sind, herauszufinden, was ich bin Fragen
Mark

1
Wenn ich es richtig verstehe, denke ich, dass Methode 2 die richtige ist. Die Gewichte (w) innerhalb des Verlustbrenners.nn.functional.cross_entropy (Ausgabe, Ziel, w) sind Gewichte für Klassen, die nicht w (x) in der Formel sind. Wir können es leicht mit einem kleinen Skript testen.
Devansh Bisla

Ja, auch ich komme zu dem gleichen Schluss. Ich werde mich wieder bei Ihnen melden, wenn mein Netzwerk wie erwartet läuft und die Antwort als akzeptiert markiert.
Mark

Okay, es funktioniert nicht. grad can be implicitly created only for scalar outputsIch bekomme, wenn ich verlust = verlust * w-Methode laufen lasse
Mark

Sind Sie sicher, dass Sie sie zusammenfassen oder den Mittelwert nehmen?
Devansh Bisla
Durch die Nutzung unserer Website bestätigen Sie, dass Sie unsere Cookie-Richtlinie und Datenschutzrichtlinie gelesen und verstanden haben.
Licensed under cc by-sa 3.0 with attribution required.