背景

​ JSCC(Joint Source Channel Coding)联合信源信道编解码主要使用DNN实现E2E的通信系统,为了适应信道,就需要将信道嵌入到网络中训练。如果使用信道冲激响应复数表示信道,那么势必要在网络中涉及复数运算。

​ 而我们知道目前的深度学习框架都不支持复数微分,这很难解决,因此在JSCC等涉及到将信道嵌入到网络中的情况时,都无法使用复数信道。

解决?

​ 在Pytorch中,之前尝试过解决这个问题,使用with torch.no_grad()或者.detach()将复数运算脱离网络,然后再将计算结果通过requires_grad添加到计算图中,但实际效果非常差,这或许完全不能解决问题,因此想摸清原因。

原因

​ 要想知道原因,这就需要知道with torch.no_grad().detach()的工作机制,分别进行介绍。

requires_grad熟悉

官网说:If autograd should record operations on the returned tensor. Default: False.

​ 是否追踪在张量上计算的所有操作,默认值为False
​ 什么意思?直接上代码测试一下吧!

# 测试一些什么都不做,查看计算的梯度
import torch

x = torch.tensor([1.0, 2.0])
y1 = x ** 2
y2 = y1 * 2
y3 = y1 + y2

print(y1, y1.requires_grad)
print(y2, y2.requires_grad)
print(y3, y3.requires_grad)

# 为什么backward里面需要加一个torch.ones(y3.shape)?
# 这是另外一个需要讨论的问题了可以在留言区一起讨论
y3.backward(torch.ones(y3.shape)) # y1.backward() y2.backward()
print(x.grad)
# 结果:
tensor([1., 4.]) False
tensor([2., 8.]) False
tensor([ 3., 12.]) False
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

​ X的requires_grad设置为True之后则如下:

# 设置好requires_grad的值为True
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = x ** 2
y2 = y1 * 2
y3 = y1 + y2

print(y1, y1.requires_grad)
print(y2, y2.requires_grad)
print(y3, y3.requires_grad)

y3.backward(torch.ones(y3.shape)) # y1.backward() y2.backward()
print(x.grad)

"""
结果:
tensor([1., 4.], grad_fn=<PowBackward0>) True
tensor([2., 8.], grad_fn=<MulBackward0>) True
tensor([ 3., 12.], grad_fn=<AddBackward0>) True
tensor([ 6., 12.])
"""

​ 此时的y1、y2、y3输出都多了一个属性参数值:
​ 例如:y1的grad_fn = <PowBackward0>,就表示y1的上一次计算操作为pow,即指数运算
​ 再回到我们的y1 = x ** 2 ,果然,正是如此。

结论:

  1. 当grad_fn设置为Fasle或者默认时:计算梯度会出现如下错误
    RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
    因为并没有追踪到任何计算历史,所以就不存在梯度的计算了

  2. 因此在最开始定义x张量的时候,就应当设置好是否计算追踪历史计算记录

    detach()方法

官网又说:Returns a new Tensor, detached from the current graph.The result will never require gradient.

​ 就是返回了一个新的张量,该张量与当前计算图完全分离。且该张量的计算将不会记录到梯度当中。

​ 上代码看看那啥意思吧!

# 设置好requires_grad的值为True
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = x ** 2
y2 = y1.detach() * 2 # 注意这里在计算y2的时候对y1进行了detach()
y3 = y1 + y2

print(y1, y1.requires_grad)
print(y2, y2.requires_grad)
print(y3, y3.requires_grad)

y3.backward(torch.ones(y3.shape)) # y1.backward() y2.backward()
print(x.grad)
# 结果
tensor([1., 4.], grad_fn=<PowBackward0>) True
tensor([2., 8.]) False
tensor([ 3., 12.], grad_fn=<AddBackward0>) True
tensor([2., 4.])

"""
根据结果可知y2所计算出来的张量,grad_fn属性没有被输出(其实为None),即不具有追踪能力了
而y1和y3都仍然显示出各自上一次的计算操作,但是最终计算出的x的梯度发生了变化
"""

​ 对比一下使用detach()前后的梯度值tensor([ 6., 12.])tensor([2., 4.])

​ tensor([ 6., 12.])

y3 = y2 + y1,根据 y2 = y1*2, 而y1 = x ** 2
所以y3 = 3x**2, y3对xi的偏导则为6xi
针对x = [1, 2]
所以,对应的梯度(偏导)则为:[6, 12]

​ tensor([ 2., 4.])

y3 = y2 + y1,因为y2是根据y1.detach()得到的;
根据定义,所以计算梯度的时候不考虑y2,但是实际计算y3的值还是按原公式
因此计算梯度时。y3 = y1 + (y2不考虑),所以y3 = x ** 2
y3对xi的偏导则为2xi
针对x = [1, 2]
所以,对应的梯度(偏导)则为:[2, 4]

​ 总结一下detach()吧:
​ 当我们在计算到某一步时,不需要在记录某一个张量的时,就可以使用detach()将其从追踪记录当中分离出来,这样一来该张量对应计算产生的梯度就不会被考虑了。

with torch.no_grad()

官方还说:Disabling gradient calculation is useful for inference, when you are sure that you will not call :meth:Tensor.backward(). It will reduce memory consumption for computations that would otherwise have requires_grad=True.

​ 先理解为也是类似取消梯度计算的一种方式,可以减少内存消耗,还是看代码结果吧!

# 设置好requires_grad的值为True
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y1 = x ** 2

with torch.no_grad(): # 这里使用了no_grad()包裹不需要被追踪的计算过程
y2 = y1 * 2

y3 = y1 + y2

print(y1, y1.requires_grad)
print(y2, y2.requires_grad)
print(y3, y3.requires_grad)

y3.backward(torch.ones(y3.shape)) # y1.backward() y2.backward()
print(x.grad)

​ 计算结果:

tensor([1., 4.], grad_fn=<PowBackward0>) True
tensor([2., 8.]) False
tensor([ 3., 12.], grad_fn=<AddBackward0>) True
tensor([2., 4.])

"""
结果和detach()方法一致,就不在分析了
"""

​ 可想而知,实际上torch.no_grad()功能和detach()方法作用是一致的。
​ 有差区别?
detach()是考虑将单个张量从追踪记录当中脱离出来;
​ 而torch.no_grad()是一个warper,可以将多个计算步骤的张量计算脱离出去,本质上没啥区别。

总结

  • requires_grad:在最开始创建Tensor时候可以设置的属性,用于表明是否追踪当前Tensor的计算操作。后面也可以通过requires_grad_()方法设置该参数,但是只有叶子节点才可以设置该参数。
  • detach()方法:则是用于将某一个Tensor从计算图中分离出来。返回的是一个内存共享的Tensor,一变都变。
  • torch.no_grad():对所有包裹的计算操作进行分离。
    但是torch.no_grad()将会使用更少的内存,因为从包裹的开始,就表明不需要计算梯度了,因此就不需要保存中间结果

其实detach()torch.no_grad()都可以这么理解:torch网络都是在动态计算图中完成,这两个方法都是将变量取出计算图,复制一份到新的地方计算,这个新的变量和之前的变量之间的唯一关系就是由之前变量通过公式计算而来,其他没有任何联系。非计算图中的变量在网络中计算也不会保存梯度,因此如果直接在网络中使用torch.no_grad()则会导致整个梯度直接断连,即从with no_grad开始往后都和前面断开,虽然前面的梯度仍然保存,但是和后面再恢复出来的梯度毫无关系,这就直接变成两个网络,梯度下降时,也是两个完全独立的网络分别在计算梯度和更新参数。如果使用.detach(),只会将被detach的变量分离出来,如果前后网络之间还有某个变量连接,那么仍然是统一的网络,只不过梯度计算会受到影响。在JSCC中,如果是复数信道,因为每个输入都需要和复数信道作用,因此我们必须将所有输入信号都detach,这就导致前后网络一定不再有任何连接,因此这种方式实现JSCC是完全不行的,而在网络中嵌入信道也只能使用考虑实信道。