Weight Initialization

​ In PyTorch, each layer's default initialization lives in its reset_parameters() method. For example, both nn.Linear and nn.Conv2d draw their weights from a uniform distribution over [-limit, limit], where limit is 1. / sqrt(fan_in) and fan_in is the number of input units of the parameter tensor.
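A quick sketch to confirm that bound empirically (the 100×10 layer here is just an arbitrary example, not from the original):

import math
import torch.nn as nn

layer = nn.Linear(100, 10)                # fan_in = 100
limit = 1. / math.sqrt(100)               # 0.1
assert layer.weight.abs().max() <= limit  # weights drawn from U(-limit, limit)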

​ PyTorch applies this default initialization when a model is defined, but sometimes we want a custom initialization, which is what torch.nn.init is for. The available initializers are listed in the official docs: https://pytorch.org/docs/stable/nn.init.html
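A minimal illustration of the torch.nn.init API (the layer shape is arbitrary): each initializer takes the tensor as its first argument and mutates it in place, which is what the trailing underscore signals.

import torch.nn as nn

layer = nn.Linear(20, 30)
nn.init.xavier_uniform_(layer.weight)  # re-initialize the weight in place
nn.init.zeros_(layer.bias)             # zero the bias in place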

​ Method 1:

​ 1. Define an initialization function;

​ 2. Apply it to the model with apply().

import torch.nn as nn

class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(True),
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
            nn.Linear(n_hidden_2, out_dim)
        )

    def forward(self, x):
        # all layers live in one nn.Sequential, so forward just calls it
        x = self.layer(x)
        return x

# 1. Choose an initialization per layer type
def weight_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)
    # likewise, check for Conv2d and use a suitable scheme
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    # batch-normalization layers
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

# 2. Build the network (in_dim etc. are whatever your task needs)
model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)
# 3. Apply weight_init to every submodule
model.apply(weight_init)
# apply() recursively visits every submodule of the model; internally it is a
# depth-first traversal (children first, then the module itself)
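A quick way to see that traversal order (the dimensions below are arbitrary placeholders):

model = Net(10, 20, 20, 2)
model.apply(lambda m: print(type(m).__name__))
# prints: Linear, ReLU, Linear, ReLU, Linear, Sequential, Net
# each child is visited before its parent, so the hook on Net itself runs last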

​ Method 2:

​ Do the initialization inside the model itself, looping over self.modules().

class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(True),
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
            nn.Linear(n_hidden_2, out_dim)
        )
        # initialize right here in __init__ by looping over self.modules()
        for m in self.modules():
            if isinstance(m, nn.Linear):  # the only layer type this Net actually contains
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.layer(x)
        return x

​ Across the initializers in torch.nn.init, e.g. nn.init.constant_(m.weight, 1) and nn.init.constant_(m.bias, 0), the first argument is the tensor being initialized, i.e. the parameter itself. For Method 2 it also helps to know the difference between self.modules() and self.children(); see https://discuss.pytorch.org/t/module-children-vs-module-modules/4551/6
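To make that difference concrete, a small sketch using the Net above (dimensions are arbitrary): modules() recurses through the whole module tree, while children() yields only the immediate submodules.

model = Net(10, 20, 20, 2)
print([type(m).__name__ for m in model.modules()])
# ['Net', 'Sequential', 'Linear', 'ReLU', 'Linear', 'ReLU', 'Linear']
print([type(m).__name__ for m in model.children()])
# ['Sequential'] -- only the top-level child, no recursion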
