Pytorch classification
利用pytorch实现图像分类的一个完整的代码,训练,预测,TTA,模型融合,模型部署,cnn提取特征,svm或者随机森林等进行分类,模型蒸馏,一个完整的代码
* 2022.10.29更新,进行代码重构,基本的功能基本一致。 - 支持pytorch ddp的训练 - 支持c++ libtorch的模型推理 - 支持script脚本一键运行 - 添加日志模块 The project is written primarily in Jupyter Notebook, distributed under the MIT License license, first published in 2020. It has gained significant community traction with 1,466 stars and 344 forks on GitHub. Key topics include: cnn, densenet, deployment, flask, image-classification.
简介
基于torchision实现的pytorch图像分类功能。
近期更新
-
2022.11.05更新
- 新添加tensorrt c++的推理方案
-
2022.10.29更新,进行代码重构,基本的功能基本一致。
- 支持pytorch ddp的训练
- 支持c++ libtorch的模型推理
- 支持script脚本一键运行
- 添加日志模块
习惯之前版本的请看v1版本的代码:V1版本。
主要功能:
利用pytorch实现图像分类,基于torchision可以扩展使用densenet,resnext,mobilenet,efficientnet,swin transformer等图像分类网络
如果有用欢迎star
实现功能
- 基础功能利用pytorch实现图像分类
- 包含带有warmup的cosine学习率调整
- warmup的step学习率优调整
- 多模型融合预测,加权与投票融合
- 利用flask + redis实现模型云端api部署(tag v1)
- c++ libtorch的模型部署
- 使用tta测试时增强进行预测(tag v1)
- 添加label smooth的pytorch实现(标签平滑)(tag v1)
- 添加使用cnn提取特征,并使用SVM,RF,MLP,KNN等分类器进行分类(tag v1)。
- 可视化特征层
运行环境
- python3.7
- pytorch 1.8.1
- torchvision 0.9.1
- opencv(libtorch cpp推理使用, 版本3.4.6)(可选)
- libtorch cpp推理使用(可选)
快速开始
数据集形式
数据集的组织形式,参考sample_files/imgs/listfile.txt
训练 测试
修改run.sh中的参数,直接运行run.sh即可运行
主要修改的参数:
OUTPUT_PATH 模型保存和log文件的路径
TRAIN_LIST 训练数据集的list文件
VAL_LIST 测试集合的list文件
model_name 默认是resnet50
lr 学习率
epochs 训练总的epoch
batch-size batch的大小
j dataloader的num_workers的大小
num_classes 类别数
libtorch inference
代码存储在cpp_inference文件夹中。
-
利用cpp_inference/traced_model/trace_model.py将训练好的模型导出。
-
编译所需的opencv和libtorch代码到
cpp_inference/third_party_library -
编译
sh compile.sh
- 可执行文件测试
./bin/imgCls imgpath
Contributors
Showing top 1 contributor by commit count.
