参数keep_graph在变量的backward()方法中是什么意思?
0 1257
0

我正在阅读pytorch教程,并对retain_variable(不推荐使用,现在称为retain_graph)的用法感到困惑。代码示例如下:

class ContentLoss(nn.Module):

    def __init__(self, target, weight):
        super(ContentLoss, self).__init__()
        self.target = target.detach() * weight
        self.weight = weight
        self.criterion = nn.MSELoss()

    def forward(self, input):
        self.loss = self.criterion(input * self.weight, self.target)
        self.output = input
        return self.output

    def backward(self, retain_variables=True):
        #Why is retain_variables True??
        self.loss.backward(retain_variables=retain_variables)
        return self.loss

文档中说明 keep_graph(布尔型,可选)–如果为False,将释放用于计算grad的图形。请注意,几乎在所有情况下都不需要将此选项设置为True,并且通常可以用更有效的方式解决它。默认为create_graph的值。

因此,通过设置retain_graph= True,我们不会在向后传递时释放分配给图的内存。保留此内存有什么好处,为什么我们需要它?

收藏
2021-02-22 10:13 更新 anna •  5042
共 1 个回答
高赞 时间
0

本质上,它将保留任何必要的信息以计算某个变量,以便我们可以对其进行反向传递。 举例说明 假设我们有一个如上所示的计算图。变量d和e是输出,a是输入。例如,

import torch
from torch.autograd import Variable
a = Variable(torch.rand(1, 4), requires_grad=True)
b = a**2
c = b*2
d = c.mean()
e = c.sum()

当我们运行d.backward()时,将很好地进行。在此计算之后,默认情况下,计算d的图形部分将被释放以保存内存。 因此,如果我们运行e.backward(),错误信息将弹出。为了运行e.backward(),我们必须在d.backward()中将参数retain_graph设置为True,即,

d.backward(retain_graph=True)

只要在 backward方法中设置retain_graph=True,就可以在任何时候进行 backward 操作:

d.backward(retain_graph=True) # fine
e.backward(retain_graph=True) # fine
d.backward() # also fine
e.backward() # error will occur!

一个实际使用的例子 有一个实际使用的例子是多任务学习,有多个损失,它们可能在不同的层。 假设有2个损失:损失1和损失2,它们在不同的层中。 为了将损失1和损失2的梯度反向传播给网络的可学习权重。 在第一个反向传播损失中,必须在backward() 方法中设置retain_graph=True。

# suppose you first back-propagate loss1, then loss2 (you can also do the reverse)
loss1.backward(retain_graph=True)
loss2.backward() # now the graph is freed, and next process of batch gradient descent is ready
optimizer.step() # update the network parameters

Via:https://stackoverflow.com/a/47174709/14964791

收藏
2021-02-22 11:36 更新 karry •  4540