torch.autograd.Function
torch.autograd.Function用于自定义网络层,自定义前向传播和反向传播
使用到torch.autograd.Function
是因为在无线通信中需要把网络输出的float32量化成bit或者星座点,pytorch没有提供任何函数满足要求,因此需要自定义该层。
Function和Module的差异
Function和Module都可以对Pytorch进行自定义拓展,使其满足网络的需求,但这两者还是有区别的:
- Function一般只定义一个操作,因为其无法保存参数,因此适用于激活函数等操作,Module是保存了参数,因此适用于定义一层,如线性层,卷积层,也适用于定义一个网络;
- Function需要定义三个方法:__init__,forward,backward(需要自己写求导公式)
- Module只需定义__init__和forward,而backward的计算由自动求导机制构成
- 可以不严谨的认为,Module是由一系列Function组成,因此其在forward的过程中,Function和Variable组成了计算图,在backward时,只需调用Function的backward就得到结果,因此Module不需要再定义backward。
- Module不仅包括了Function,还包括了对应的参数,以及其他函数与变量,这是Function所不具备的。
同时实现这两个函数:
- forward():执行这个操作的代码,需要定义前向传播过程,同时可以保存任何在反向传播中需要使用的变量值。输出类型是Tensor,或者是Tensor组成的tuple。假设运算的输入有a个,输出有b个,那么forward()的输入有a个,输出有b个。forward()函数的输入参数第1个是ctx,第2个是input,其他是可选参数。
- backward():计算导数的代码。假设运算的输入a个,输出b个,那么backward()的输入有b个,输出有a个。代表输出对a个输入的导数。
class Exp(torch.autograd.Function): |
注释:
forward()和backward()都应该是staticmethod。
forward()的输入只有2个(ctx, i),ctx必须有,i是input。
ctx.save_for_backward(result)表示forward()的结果要存起来,以后给backward()。
backward()的输入只有2个(ctx, grad_output),ctx必须有,grad_output是最终object对的forward()输出的导数。
result, = ctx.saved_tensors得到之前forward()存的结果。
调用时,直接 **Exp.apply(input)**即可。
本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 许可协议。转载请注明来自 JrunDing!
评论