torch.autograd.Function用于自定义网络层,自定义前向传播和反向传播

​ 使用到torch.autograd.Function是因为在无线通信中需要把网络输出的float32量化成bit或者星座点,pytorch没有提供任何函数满足要求,因此需要自定义该层。

Function和Module的差异

​ Function和Module都可以对Pytorch进行自定义拓展,使其满足网络的需求,但这两者还是有区别的:

  • Function一般只定义一个操作,因为其无法保存参数,因此适用于激活函数等操作,Module是保存了参数,因此适用于定义一层,如线性层,卷积层,也适用于定义一个网络;
  • Function需要定义三个方法:__init__,forward,backward(需要自己写求导公式)
  • Module只需定义__init__和forward,而backward的计算由自动求导机制构成
  • 可以不严谨的认为,Module是由一系列Function组成,因此其在forward的过程中,Function和Variable组成了计算图,在backward时,只需调用Function的backward就得到结果,因此Module不需要再定义backward。
  • Module不仅包括了Function,还包括了对应的参数,以及其他函数与变量,这是Function所不具备的。

​ 同时实现这两个函数:

  • forward():执行这个操作的代码,需要定义前向传播过程,同时可以保存任何在反向传播中需要使用的变量值。输出类型是Tensor,或者是Tensor组成的tuple。假设运算的输入有a个,输出有b个,那么forward()的输入有a个,输出有b个。forward()函数的输入参数第1个是ctx,第2个是input,其他是可选参数。
  • backward():计算导数的代码。假设运算的输入a个,输出b个,那么backward()的输入有b个,输出有a个。代表输出对a个输入的导数。
class Exp(torch.autograd.Function):

@staticmethod
def forward(ctx, i):
result = i.exp()
ctx.save_for_backward(result) # 表示保存forward的结果,后给backward()使用
return result

@staticmethod
def backward(ctx, grad_output):
result, = ctx.saved_tensors
return grad_output * result

# Use it by calling the apply method:
output = Exp.apply(input)

注释:
​ forward()和backward()都应该是staticmethod。

​ forward()的输入只有2个(ctx, i),ctx必须有,i是input。
​ ctx.save_for_backward(result)表示forward()的结果要存起来,以后给backward()。

​ backward()的输入只有2个(ctx, grad_output),ctx必须有,grad_output是最终object对的forward()输出的导数。
​ result, = ctx.saved_tensors得到之前forward()存的结果。

​ 调用时,直接 **Exp.apply(input)**即可。