模型的保存与加载

PyTorch中的保存(序列化,从内存到硬盘)与反序列化(加载,从硬盘到内存)

torch.save主要参数:obj:对象 、f:输出路径

torch.load 主要参数 :f:文件路径 、map_location:指定存放位置、 cpu or gpu

模型的保存的两种方法:

保存整个Module

torch.save(net, path)

保存模型参数

state_dict = net.state_dict()
torch.save(state_dict , path)

模型的训练过程中保存

checkpoint = {
"net": model.state_dict(),
'optimizer':optimizer.state_dict(),
"epoch": epoch
}

​ 将网络训练过程中的网络的权重,优化器的权重保存,以及epoch 保存,便于继续训练恢复

​ 在训练过程中,可以根据自己的需要,每多少代,或者多少epoch保存一次网络参数,便于恢复,提高程序的鲁棒性。

checkpoint = {
"net": model.state_dict(),
'optimizer':optimizer.state_dict(),
"epoch": epoch
}
if not os.path.isdir("./models/checkpoint"):
os.mkdir("./models/checkpoint")
torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' %(str(epoch)))

​ 通过上述的过程可以在训练过程自动在指定位置创建文件夹,并保存断点文件

模型的断点继续训练

if RESUME:    
path_checkpoint = "./models/checkpoint/ckpt_best_1.pth" # 断点路径
checkpoint = torch.load(path_checkpoint) # 加载断点
model.load_state_dict(checkpoint['net']) # 加载模型可学习参数
optimizer.load_state_dict(checkpoint['optimizer']) # 加载优化器参数
start_epoch = checkpoint['epoch'] # 设置开始的epoch

指出这里的是否继续训练,及训练的checkpoint的文件位置等可以通过argparse从命令行直接读取,也可以通过log文件直接加载,也可以自己在代码中进行修改。

重点在于epoch的恢复

start_epoch = -1


if RESUME:
path_checkpoint = "./models/checkpoint/ckpt_best_1.pth" # 断点路径
checkpoint = torch.load(path_checkpoint) # 加载断点

model.load_state_dict(checkpoint['net']) # 加载模型可学习参数

optimizer.load_state_dict(checkpoint['optimizer']) # 加载优化器参数
start_epoch = checkpoint['epoch'] # 设置开始的epoch



for epoch in range(start_epoch + 1 ,EPOCH):
# print('EPOCH:',epoch)
for step, (b_img,b_label) in enumerate(train_loader):
train_output = model(b_img)
loss = loss_func(train_output,b_label)
# losses.append(loss)
optimizer.zero_grad()
loss.backward()
optimizer.step()

​ 通过定义start_epoch变量来保证继续训练的时候epoch不会变化