模型的保存与加载
PyTorch中的保存(序列化,从内存到硬盘)与反序列化(加载,从硬盘到内存)
torch.save主要参数:obj:对象 、f:输出路径
torch.load 主要参数 :f:文件路径 、map_location:指定存放位置、 cpu or gpu
模型的保存的两种方法:
保存整个Module
保存模型参数
state_dict = net.state_dict() torch.save(state_dict , path)
|
模型的训练过程中保存
checkpoint = { "net": model.state_dict(), 'optimizer':optimizer.state_dict(), "epoch": epoch }
|
将网络训练过程中的网络的权重,优化器的权重保存,以及epoch 保存,便于继续训练恢复
在训练过程中,可以根据自己的需要,每多少代,或者多少epoch保存一次网络参数,便于恢复,提高程序的鲁棒性。
checkpoint = { "net": model.state_dict(), 'optimizer':optimizer.state_dict(), "epoch": epoch } if not os.path.isdir("./models/checkpoint"): os.mkdir("./models/checkpoint") torch.save(checkpoint, './models/checkpoint/ckpt_best_%s.pth' %(str(epoch)))
|
通过上述的过程可以在训练过程自动在指定位置创建文件夹,并保存断点文件
模型的断点继续训练
if RESUME: path_checkpoint = "./models/checkpoint/ckpt_best_1.pth" checkpoint = torch.load(path_checkpoint) model.load_state_dict(checkpoint['net']) optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch']
|
指出这里的是否继续训练,及训练的checkpoint的文件位置等可以通过argparse从命令行直接读取,也可以通过log文件直接加载,也可以自己在代码中进行修改。
重点在于epoch的恢复
start_epoch = -1
if RESUME: path_checkpoint = "./models/checkpoint/ckpt_best_1.pth" checkpoint = torch.load(path_checkpoint)
model.load_state_dict(checkpoint['net'])
optimizer.load_state_dict(checkpoint['optimizer']) start_epoch = checkpoint['epoch']
for epoch in range(start_epoch + 1 ,EPOCH): for step, (b_img,b_label) in enumerate(train_loader): train_output = model(b_img) loss = loss_func(train_output,b_label) optimizer.zero_grad() loss.backward() optimizer.step()
|
通过定义start_epoch变量来保证继续训练的时候epoch不会变化