1 什么是timm库?

​ timm 是 PyTorch Image Models 的缩写 is a collection of SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations and also training/validating scripts with ability to reproduce ImageNet training results.

​ timm 库实现了最新的几乎所有的具有影响力的视觉模型,它不仅提供了模型的权重,还提供了一个很棒的分布式训练和评估的代码框架,方便后人开发。timm 库是由 Ross Wightman 开发和维护的

​ 项目地址:https://github.com/huggingface/

2 最快上手使用一个timm模型

​ 问:使用timm搭建一个可以使用的CNN或ViT共需要几步?

​ 答:4步

​ 0.安装 timm

​ 1.import timm

​ 2.创建model

​ 3.运行model

1)安装、导入

conda imstall timm
#or
pip install timm
#or
git clone https://github.com/rwightman/pytorch-image-models
cd pytorch-image-models && pip install -e .
import torch
import timm

2)创建、使用模型

​ 创建模型的最简单方法是使用create_model

​ 这是一个可用于在 timm 库中创建任何模型的函数

​ 这个函数各个参数有什么用,内部具体怎么实现的后再讲,先只用它来创建一个CNN用来做分类任务

model_resnet34 = timm.create_model('resnet34', pretrained=True)

resnet34是模型架构的名字

​ pretrained=True则会自动从网上下载训练好的模型权重加载到resnet34上

​ 然后模型就创建好了,可以直接使用了

​ 这里我们使用随机张量表示图像

​ torch.randn:用来生成随机数字的tensor,这些随机数字满足标准正态分布(0~1)

x = torch.randn([1, 1, 224, 224]) #创建一个tensor 代表 一张3x224x224的图片
model_resnet50.eval()
with torch.no_grad():
out = model_resnet50(x) #out就是x所对应的表示类别的一个tensor
print(out.shape) # Results: torch.Size([1, 1000])代表1000个类别

​ 我们可以看到模型已经处理了图像并返回了预期的输出形状

3)查看模型信息

​ 那么怎么知道timm都可以导入哪些模型来使用呢?

model_list = timm.list_models() #返回一个包含所有模型名称的list
print(len(model_list)) #964
pretrain_model_list = timm.list_models(pretrained = True)#筛选出带预训练模型的
print(len(pretrain_model_list))#770
##使用通配符字符串来列出可用的不同 ResNet 变体
resnet_model_list = timm.list_models('*resnet*')
pretrain_resnet_model_list = timm.list_models('*resnet*' , pretrained = True)

4)调整模型-创建适合自己的模型

​ 直接导入训练好的模型并不是万能的,经常会有维度不匹配的情况 比如说我的resnet34模型在cifar10和imagenet两个数据集上进行训练,分类类别不一样,输入的图片大小不一样,那我应该怎么创建合适的模型呢?

改变输出类别数目

​ 分类类别数量:num_classes

​ model的主体提取特征,之后往往会接一个mlp层用作分类

​ 如果设置num_classes,表示重设全连接层,num_classes设置为你需要分类的类别数量即可

import torch
x = torch.randn([1, 3, 224, 224])
model_resnet34_out10 = timm.create_model('resnet34', pretrained=True, num_classes=10)
out = model_resnet34_out10 (x)
print(out.shape) # Results: torch.Size([1, 10])

改变输入通道数

​ 输入通道数:in_chans

​ 对图片的大小,可以在输入model之前进行resize处理到统一大小

​ 但是如果输入的图片不是传统rgb图片,通道不是3怎么办

​ 当然,我们可以复制单通道像素来创建3通道图像,从而将其单通道输入图像转换为3通道图像。但是对于timm,他有一套申请的参数加载模式,我们可以直接改变in_chans 来指定输入图像的通道数

​ 通道数改变后,对应的权重参数会进行相应的处理,此处不作详细说明 可参照:https://fastai.github.io/timmdocs/models或直接查看源代码

x = torch.randn([1, 1, 224, 224])
model_resnet34_in1 = timm.create_model('resnet50',pretrained=True, in_chans=1)

3 特性

用timm有什么好处吗,下面是一些功能特性

​ 所有model都有一个通用的默认配置接口和API

​ 所有模型都支持通过create_model提取中间特征(vit除外)

​ 所有型号都有一个预训练权重加载器,可调整最后一个线性层,也可调整3通道输入为1个通道输入

​ 并且我们还可以直接复用timm的功能模块或者一些训练tricks(Learning rate schedulers/Optimizers/Augment),简直是百宝箱