Weight Initialization
PyTorch's default parameter initialization lives in each layer's reset_parameters() method. For example, nn.Linear and nn.Conv2d both draw their weights from a uniform distribution over [-limit, limit], where limit is 1 / sqrt(fan_in) and fan_in is the number of input units of the parameter tensor.
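A quick sanity check of that bound (a minimal sketch; the sizes 100 and 10 are arbitrary example values):

import math
import torch.nn as nn

layer = nn.Linear(100, 10)            # fan_in = 100
limit = 1.0 / math.sqrt(100)          # expected bound: 0.1
print(layer.weight.min().item() >= -limit)   # True
print(layer.weight.max().item() <= limit)    # True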
PyTorch applies this default initialization when a model is defined; when we need a custom initialization instead, torch.nn.init provides it. The individual initialization schemes are documented at https://pytorch.org/docs/stable/nn.init.html
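These functions modify the given tensor in place (the trailing underscore is PyTorch's in-place convention); a minimal sketch of calling one directly on a bare tensor:

import torch
import torch.nn as nn

w = torch.empty(3, 5)
nn.init.xavier_normal_(w)    # fill w in place with Xavier-normal values
nn.init.constant_(w, 0.5)    # overwrite: every element becomes 0.5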
Method 1:
1. Define an initialization function outside the model;
2. Apply it with apply().
import torch.nn as nn

class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(True),
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
            nn.Linear(n_hidden_2, out_dim)
        )

    def forward(self, x):
        return self.layer(x)

def weight_init(m):
    # m is each submodule visited by model.apply()
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
model = Net(in_dim, n_hidden_1, n_hidden_2, out_dim)
model.apply(weight_init)
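apply() walks the module tree recursively, including nested children and the model itself, so weight_init is called on every nn.Linear inside nn.Sequential as well.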
Method 2:
Define the initialization inside the model, using self.modules() to loop over the submodules.
class Net(nn.Module):
    def __init__(self, in_dim, n_hidden_1, n_hidden_2, out_dim):
        super().__init__()
        self.layer = nn.Sequential(
            nn.Linear(in_dim, n_hidden_1),
            nn.ReLU(True),
            nn.Linear(n_hidden_1, n_hidden_2),
            nn.ReLU(True),
            nn.Linear(n_hidden_2, out_dim)
        )
        # self.modules() recurses into nn.Sequential, so every submodule is visited
        for m in self.modules():
            if isinstance(m, nn.Linear):  # the branch this model's layers actually hit
                nn.init.xavier_normal_(m.weight)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        return self.layer(x)
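Because the loop runs inside __init__, the custom initialization happens automatically at construction time; a minimal usage sketch (the dimensions here are arbitrary example values):

model = Net(784, 256, 64, 10)  # parameters are re-initialized during construction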
For all of the initialization functions in torch.nn.init, such as nn.init.constant_(m.weight, 1) and nn.init.constant_(m.bias, 0), the first argument is a tensor, i.e. the parameter being initialized. To use Method 2 it also helps to understand the difference between self.modules() and self.children(); see the sketch below and the discussion at https://discuss.pytorch.org/t/module-children-vs-module-modules/4551/6
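A minimal sketch of that difference, assuming a small nested nn.Sequential: children() yields only the direct submodules, while modules() also yields the module itself and recurses into nested ones.

import torch.nn as nn

net = nn.Sequential(
    nn.Sequential(nn.Linear(4, 4), nn.ReLU()),
    nn.Linear(4, 2),
)
print([type(m).__name__ for m in net.children()])
# ['Sequential', 'Linear']
print([type(m).__name__ for m in net.modules()])
# ['Sequential', 'Sequential', 'Linear', 'ReLU', 'Linear']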