0
本质上,它将保留任何必要的信息以计算某个变量,以便我们可以对其进行反向传递。 举例说明 假设我们有一个如上所示的计算图。变量d和e是输出,a是输入。例如,
import torch
from torch.autograd import Variable
a = Variable(torch.rand(1, 4), requires_grad=True)
b = a**2
c = b*2
d = c.mean()
e = c.sum()
当我们运行d.backward()时,将很好地进行。在此计算之后,默认情况下,计算d的图形部分将被释放以保存内存。 因此,如果我们运行e.backward(),错误信息将弹出。为了运行e.backward(),我们必须在d.backward()中将参数retain_graph设置为True,即,
d.backward(retain_graph=True)
只要在 backward方法中设置retain_graph=True,就可以在任何时候进行 backward 操作:
d.backward(retain_graph=True) # fine
e.backward(retain_graph=True) # fine
d.backward() # also fine
e.backward() # error will occur!
一个实际使用的例子 有一个实际使用的例子是多任务学习,有多个损失,它们可能在不同的层。 假设有2个损失:损失1和损失2,它们在不同的层中。 为了将损失1和损失2的梯度反向传播给网络的可学习权重。 在第一个反向传播损失中,必须在backward() 方法中设置retain_graph=True。
# suppose you first back-propagate loss1, then loss2 (you can also do the reverse)
loss1.backward(retain_graph=True)
loss2.backward() # now the graph is freed, and next process of batch gradient descent is ready
optimizer.step() # update the network parameters
收藏