timm
1 什么是timm库?
timm 是 PyTorch Image Models 的缩写 is a collection of SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations and also training/validating scripts with ability to reproduce ImageNet training results.
timm 库实现了最新的几乎所有的具有影响力的视觉模型,它不仅提供了模型的权重,还提供了一个很棒的分布式训练和评估的代码框架,方便后人开发。timm 库是由 Ross Wightman 开发和维护的
项目地址:https://github.com/huggingface/
2 最快上手使用一个timm模型
问:使用timm搭建一个可以使用的CNN或ViT共需要几步?
答:4步
0.安装 timm
1.import timm
2.创建model
3.运行model
1)安装、导入
conda imstall timm |
import torch |
2)创建、使用模型
创建模型的最简单方法是使用create_model
这是一个可用于在 timm 库中创建任何模型的函数
这个函数各个参数有什么用,内部具体怎么实现的后再讲,先只用它来创建一个CNN用来做分类任务
model_resnet34 = timm.create_model('resnet34', pretrained=True) |
resnet34
是模型架构的名字
pretrained=True则会自动从网上下载训练好的模型权重加载到resnet34上
然后模型就创建好了,可以直接使用了
这里我们使用随机张量表示图像
torch.randn:用来生成随机数字的tensor,这些随机数字满足标准正态分布(0~1)
x = torch.randn([1, 1, 224, 224]) #创建一个tensor 代表 一张3x224x224的图片 |
我们可以看到模型已经处理了图像并返回了预期的输出形状
3)查看模型信息
那么怎么知道timm都可以导入哪些模型来使用呢?
model_list = timm.list_models() #返回一个包含所有模型名称的list |
4)调整模型-创建适合自己的模型
直接导入训练好的模型并不是万能的,经常会有维度不匹配的情况 比如说我的resnet34模型在cifar10和imagenet两个数据集上进行训练,分类类别不一样,输入的图片大小不一样,那我应该怎么创建合适的模型呢?
改变输出类别数目
分类类别数量:num_classes
model的主体提取特征,之后往往会接一个mlp层用作分类
如果设置num_classes,表示重设全连接层,num_classes设置为你需要分类的类别数量即可
import torch |
改变输入通道数
输入通道数:in_chans
对图片的大小,可以在输入model之前进行resize处理到统一大小
但是如果输入的图片不是传统rgb图片,通道不是3怎么办
当然,我们可以复制单通道像素来创建3通道图像,从而将其单通道输入图像转换为3通道图像。但是对于timm,他有一套申请的参数加载模式,我们可以直接改变in_chans 来指定输入图像的通道数
通道数改变后,对应的权重参数会进行相应的处理,此处不作详细说明 可参照:https://fastai.github.io/timmdocs/models或直接查看源代码
x = torch.randn([1, 1, 224, 224]) |
3 特性
用timm有什么好处吗,下面是一些功能特性
所有model都有一个通用的默认配置接口和API
所有模型都支持通过create_model提取中间特征(vit除外)
所有型号都有一个预训练权重加载器,可调整最后一个线性层,也可调整3通道输入为1个通道输入
并且我们还可以直接复用timm的功能模块或者一些训练tricks(Learning rate schedulers/Optimizers/Augment),简直是百宝箱