Training an image classifier with PyTorch and exporting it to ONNX for testing

-Pastoral- 2020-11-13 08:02:27

What the project does

  • 1. Build a simple image classifier, and complete training and testing
  • 2. Convert the PyTorch pth model to ONNX format, then load and test the ONNX model


Project structure

The images directory stores the training and test datasets. This example uses the cat-and-dog dataset from the Kaggle competition, with every image resized to 120*120. The training and testing code, which also includes the pth-to-ONNX conversion, has been verified on both CPU and GPU. File overview:

TestOnnx.cpp is the loading and testing code for the ONNX model. File overview:

(Note: to keep things one-click simple, operations that could span several files are merged into one file each.)

The network-building and training parts reference JR_Chan's blog, thanks!
The network structure is very simple: three convolution layers and one fully connected layer.

Details of the structure:
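As a sanity check on those layer sizes, the standard convolution output-size formula, out = (in - kernel + 2*padding) / stride + 1, can be traced through the three blocks. This sketch assumes each convolution is followed by a 2x2 max-pool, as the feature-map sizes in the code comments imply:

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    # Standard convolution output-size formula
    return (size - kernel + 2 * padding) // stride + 1

size = 90
size = conv_out(size, padding=0) // 2   # conv1: 90 -> 88, 2x2 max-pool -> 44
size = conv_out(size, padding=1) // 2   # conv2: 44 -> 44, 2x2 max-pool -> 22
size = conv_out(size, padding=0) // 2   # conv3: 22 -> 20, 2x2 max-pool -> 10
print(size, 32 * size * size)           # 10 3200, matching nn.Linear(32 * 10 * 10, ...)
```

The final 32 * 10 * 10 = 3200 flattened features are exactly what the fully connected classifier layer expects.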


Training results

Epoch:1/100 test Loss: 0.6443 Acc: 0.6168 
Epoch:2/100 train Loss: 0.6298 Acc: 0.6421 
Epoch:2/100 test Loss: 0.5762 Acc: 0.6986 
Epoch:99/100 train Loss: 0.2731 Acc: 0.8842 
Epoch:99/100 test Loss: 0.2618 Acc: 0.8936 
Epoch:100/100 train Loss: 0.2757 Acc: 0.8837 
Epoch:100/100 test Loss: 0.2613 Acc: 0.8926

With a learning rate of 0.002 and 100 epochs, the accuracy is about 89%.
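The training code pairs that initial rate with StepLR(step_size=1, gamma=0.98), which multiplies the learning rate by 0.98 once per epoch, so by the end of training the effective rate is far below 0.002:

```python
lr = 0.002
for epoch in range(100):
    lr *= 0.98  # StepLR(step_size=1, gamma=0.98): one decay step per epoch
print(round(lr, 6))  # roughly 0.000265 after 100 epochs
```

This gentle exponential decay is why the loss keeps creeping down even late in training.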


ONNX test results

The network is very small: the pth model file and cat_dog_classify.onnx are each only about 63KB. Loading the ONNX model through OpenCV gives these test results:
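The C++ post-processing below simply takes the argmax of the raw network output with cv::minMaxLoc and maps the index to a class name, treating any index above 1 as "None". An illustrative pure-Python mirror of that logic (the score vector here is made up):

```python
# Illustrative only: mirrors the C++ post-processing (minMaxLoc + class lookup)
classes = ["cat", "dog"]

def postprocess(scores):
    # argmax over the raw network output; indices past the known classes map to "None"
    index = max(range(len(scores)), key=lambda i: scores[i])
    return classes[index] if index <= 1 else "None"

print(postprocess([2.3, -0.7, 0.1]))  # "cat": index 0 has the highest score
```

No softmax is needed for classification alone, since softmax is monotonic and does not change the argmax.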


Here are the .py and .cpp files (slightly long; there is a download link at the end of the article).

# -*- coding: UTF-8 -*-
# Created by - Pastoral - CSDN
import argparse
import copy
import os
import time
from math import ceil

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from PIL import Image
from torchvision import transforms, datasets
from torch.autograd import Variable
from tensorboardX import SummaryWriter


# Define a simple classification network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # Three convolution blocks extract features.
        # (The ReLU/MaxPool layers are reconstructed from the feature-map
        # sizes given in the comments.)
        # 1-channel 90x90 input -> 8-channel 44x44
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2))
        # 8-channel 44x44 -> 16-channel 22x22
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2))
        # 16-channel 22x22 -> 32-channel 10x10
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2))
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(32 * 10 * 10, 3))

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(-1, 32 * 10 * 10)
        x = self.classifier(x)
        return x


# Training entry point
def train(args):
    # Read the data
    dataloders, dataset_sizes, class_names = ImageDataset(args)
    with open(args.class_file, 'w') as f:
        for name in class_names:
            f.writelines(name + '\n')
    # Use GPU or not
    use_gpu = torch.cuda.is_available()
    # Get the model
    model = SimpleNet()
    # Optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            model.load_state_dict(torch.load(args.resume, map_location='cpu'))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    if use_gpu:
        model = torch.nn.DataParallel(model)
        model.to(torch.device('cuda'))
    # Cross-entropy loss function
    criterion = nn.CrossEntropyLoss()
    # Gradient descent; observe that all parameters are being optimized
    optimizer_ft = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4)
    # Decay LR by a factor of 0.98 every epoch
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=1, gamma=0.98)
    model = train_model(args=args,
                        model=model,
                        criterion=criterion,
                        optimizer=optimizer_ft,
                        scheduler=exp_lr_scheduler,
                        num_epochs=args.num_epochs,
                        dataset_sizes=dataset_sizes,
                        use_gpu=use_gpu,
                        dataloders=dataloders)
    torch.save(model.state_dict(), os.path.join(args.save_path, 'best_model.pth'))


# Entry point for testing a single picture (with the pth model)
def test(test_model_path, test_img_path, class_file):
    best_model_path = test_model_path
    model = SimpleNet()
    model.load_state_dict(torch.load(best_model_path, map_location='cpu'))
    model.eval()
    class_names = []
    with open(class_file, 'r') as f:
        lines = f.readlines()
        for line in lines:
            class_names.append(line.strip())
    img_path = test_img_path
    predict_class = class_names[predict_image(model, img_path)]
    print(predict_class)


# Convert the trained PyTorch pth model to an ONNX model
def convert_model_to_ONNX(input_img_size, input_pth_model, output_ONNX):
    dummy_input = torch.randn(3, 1, input_img_size, input_img_size)
    model = SimpleNet()
    # Note: a checkpoint saved from a DataParallel-wrapped model has
    # 'module.'-prefixed keys and may need renaming before load_state_dict
    state_dict = torch.load(input_pth_model, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # Set the model to inference mode (important)
    input_names = ["input_image"]
    output_names = ["output_classification"]
    torch.onnx.export(model, dummy_input, output_ONNX, verbose=True,
                      input_names=input_names, output_names=output_names)


# Main training loop
def train_model(args, model, criterion, optimizer, scheduler, num_epochs, dataset_sizes, use_gpu, dataloders):
    begin = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    device = torch.device('cuda' if use_gpu else 'cpu')
    for epoch in range(args.start_epoch, num_epochs):
        # Each epoch has a training and a validation phase
        for phase in ['train', 'test']:
            if phase == 'train':
                model.train()   # Set model to training mode
            else:
                model.eval()    # Set model to evaluation mode
            running_loss = 0.0
            running_corrects = 0
            tic_batch = time.time()
            # Iterate over data, one batch at a time
            for i, (inputs, labels) in enumerate(dataloders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                # Zero the parameter gradients
                optimizer.zero_grad()
                # Forward pass; track history only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # Backward + optimize only in the training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # Statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
                batch_loss = running_loss / (i * args.batch_size + inputs.size(0))
                batch_acc = running_corrects.double() / (i * args.batch_size + inputs.size(0))
                if phase == 'train' and (i + 1) % args.print_freq == 0:
                    print('[Epoch {}/{}]-[batch:{}/{}] lr:{:.6f} {} Loss: {:.6f} Acc: {:.4f} Time: {:.4f} sec/batch'.format(
                        epoch + 1, num_epochs, i + 1, ceil(dataset_sizes[phase] / args.batch_size),
                        scheduler.get_lr()[0], phase, batch_loss, batch_acc,
                        (time.time() - tic_batch) / args.print_freq))
                    tic_batch = time.time()
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            if epoch == 0 and os.path.exists('result.txt'):
                os.remove('result.txt')
            with open('result.txt', 'a') as f:
                f.write('Epoch:{}/{} {} Loss: {:.4f} Acc: {:.4f} \n'.format(
                    epoch + 1, num_epochs, phase, epoch_loss, epoch_acc))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            writer.add_scalar(phase + '/Loss', epoch_loss, epoch)
            writer.add_scalar(phase + '/Acc', epoch_acc, epoch)
            if (epoch + 1) % args.save_epoch_freq == 0:
                if not os.path.exists(args.save_path):
                    os.makedirs(args.save_path)
                torch.save(model.state_dict(), os.path.join(args.save_path, "epoch_" + str(epoch) + ".pth"))
            # Deep-copy the best model so far
            if phase == 'test' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
    # Save the model as a graph for TensorBoard
    writer.add_graph(model, (inputs,))
    time_elapsed = time.time() - begin
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Accuracy: {:4f}'.format(best_acc))
    # Load the best model weights
    model.load_state_dict(best_model_wts)
    return model


# Main function for testing a single image
def predict_image(model, image_path):
    image = Image.open(image_path).convert('L')
    # Take the central 90x90 crop, as at test time
    transformation1 = transforms.Compose([
        transforms.CenterCrop(90),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])])
    # Preprocess the image
    image_tensor = transformation1(image).float()
    # Add an extra batch dimension, since PyTorch treats all inputs as batches
    image_tensor = image_tensor.unsqueeze_(0)
    if torch.cuda.is_available():
        image_tensor = image_tensor.cuda()
        model = model.cuda()
    # Wrap the input in a Variable
    input = Variable(image_tensor)
    # Predict the image category
    output = model(input)
    index = output.data.cpu().numpy().argmax()
    return index


# Read a picture with PIL and convert it to grayscale
def readImg(path):
    im = Image.open(path)
    return im.convert("L")


# Read the training and test data
def ImageDataset(args):
    # Data augmentation and normalization.
    # All pictures are 120x120: random 90x90 crops during training,
    # the central 90x90 crop during testing.
    data_transforms = {
        'train': transforms.Compose([
            transforms.RandomCrop(90),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])]),
        'test': transforms.Compose([
            transforms.CenterCrop(90),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])]),
    }
    data_dir = args.data_dir
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                              data_transforms[x], loader=readImg)
                      for x in ['train', 'test']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batch_size,
                                                  shuffle=(x == 'train'), num_workers=args.num_workers)
                   for x in ['train', 'test']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'test']}
    class_names = image_datasets['train'].classes
    return dataloaders, dataset_sizes, class_names


# Set parameters
def set_parser():
    parser = argparse.ArgumentParser(description='classification')
    # Root directory of the image data
    parser.add_argument('--data-dir', type=str, default='images')
    parser.add_argument('--class-file', type=str, default='class_names.class')
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--num-epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=0.002)  # those who set lr greater than 0.01 are hooligans!!
    parser.add_argument('--num-workers', type=int, default=8)
    parser.add_argument('--print-freq', type=int, default=100)
    parser.add_argument('--save-epoch-freq', type=int, default=1)
    parser.add_argument('--save-path', type=str, default='output')
    parser.add_argument('--resume', type=str, default='', help='For training from one checkpoint')
    parser.add_argument('--start-epoch', type=int, default=0, help='Corresponding to the epoch of resume')
    return parser.parse_args()


if __name__ == '__main__':
    writer = SummaryWriter(log_dir='log')
    os.environ['CUDA_VISIBLE_DEVICES'] = '0'
    def_img_train_and_test_size = 90  # Network input size
    args = set_parser()  # Set parameters
    train(args)  # Train the model
    test('./output/best_model.pth', './images/test/cat/cat.0.jpg', args.class_file)  # Test the model (a single picture)
    # Convert the PyTorch pth model to an ONNX model
    convert_model_to_ONNX(def_img_train_and_test_size, './output/epoch_99.pth', "./cat_dog_classify.onnx")


// PthONNX.cpp : cat/dog binary classification based on OpenCV dnn and ONNX
// Created by - Pastoral - 2019-10-29
#include <iostream>
#include <fstream>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>
#include <opencv2/dnn.hpp>

// ONNX inference class
class PthONNX {
public:
    //@model_path          ONNX model path
    //@classes_file_path   classification-names file
    //@input_size          network input size
    PthONNX(const std::string &model_path, const std::string &classes_file_path, cv::Size input_size);

    //@input_image             input picture, BGR format
    //@classification_output   class name output by the network: 0:cat 1:dog, otherwise None
    void Classify(const cv::Mat &input_image, std::string &classification_output);

private:
    void ClassifyImplement(const cv::Mat &image, std::string &classification_output);

    cv::Size input_size_;
    cv::dnn::Net net_classify_;
    std::vector<std::string> classes_;
};

// Constructor
PthONNX::PthONNX(const std::string &model_path, const std::string &classes_file_path,
                 cv::Size input_size) : input_size_(input_size) {
    std::ifstream ifs(classes_file_path.c_str());
    std::string line;
    while (getline(ifs, line)) {
        classes_.push_back(line);
    }
    net_classify_ = cv::dnn::readNetFromONNX(model_path);
}

// Inference entry function
void PthONNX::Classify(const cv::Mat &input_image, std::string &classification_results) {
    cv::Mat image = input_image.clone();
    cv::resize(image, image, cv::Size(90, 90));
    cv::cvtColor(image, image, cv::COLOR_BGR2GRAY);
    ClassifyImplement(image, classification_results);
}

// Main inference function
void PthONNX::ClassifyImplement(const cv::Mat &image, std::string &classification_results) {
    //*********** pre-processing ***********
    cv::Scalar mean_value(0, 0, 0);
    cv::Mat input_blob = cv::dnn::blobFromImage(image, 1, input_size_, mean_value, false, false, CV_32F);
    //*********** inference ***********
    net_classify_.setInput(input_blob);
    const std::vector<cv::String> &out_names = net_classify_.getUnconnectedOutLayersNames();
    cv::Mat out_tensor = net_classify_.forward(out_names[0]);
    //*********** post-processing ***********
    double minVal;
    double maxVal;
    cv::Point minIdx;
    cv::Point maxIdx;  // minimum index, maximum index
    cv::minMaxLoc(out_tensor, &minVal, &maxVal, &minIdx, &maxIdx);
    int index_class = maxIdx.x;
    classification_results = (index_class <= 1) ? classes_[index_class] : "None";
}

int main() {
    const std::string img_path = "D:/1/1/SimpleNet-master/images/train/cat/cat.4896.jpg";
    const std::string onnx_model_path = "D:/1/1/pytorch-train-test-onnx/cat_dog_classify.onnx";
    const std::string class_names_file_path = "D:/software/VS2019_Test/PthONNX/x64/class_names.class";
    const cv::Size net_input_size(90, 90);

    cv::Mat img = cv::imread(img_path);
    std::string classify_output;  // classification result
    PthONNX classifier(onnx_model_path, class_names_file_path, net_input_size);
    classifier.Classify(img, classify_output);
    std::cout << "Picture category: " << classify_output << std::endl << std::endl;
    cv::putText(img, classify_output, cv::Point(20, 20), 2, 1.2, cv::Scalar(0, 0, 255));
    cv::imshow("classify", img);
    cv::waitKey(0);
    return 0;
}

Download link for the complete project (with dataset, PyTorch training and testing, pth-to-ONNX conversion, and ONNX loading and testing): pytorch training image classification model, pth to ONNX, and testing

