"好菇毒" — A Mushroom Identification System

Project Overview

Introduction

Poisonings from eating wild mushrooms occur regularly, and accidental ingestion of toxic mushrooms is the leading cause of death among food-poisoning incidents in China. Mushroom morphology varies enormously, so a non-specialist cannot reliably tell toxic species from edible ones by appearance, shape or colour, and there is no single simple rule that separates the two.
This project, the "好菇毒" mushroom identification system, performs classification and recognition of mushroom images.

Project Design

System Architecture

[Figure: system architecture diagram]

Database Design

MySQL is used as the database. Since the system is simple, there is only one table: mushroom.
Fields of the mushroom table: id (primary key), name (common name), scientific_name (scientific name), species (family/genus), toxicity (toxicity), feature (detailed description), img_path (image path).
To speed up lookups, an index is created on the scientific_name column.

-- create the database
create database if not exists mushroom;

-- switch to it
use mushroom;

-- mushroom table
create table if not exists mushroom
(
    id              bigint auto_increment comment 'id' primary key,
    name            varchar(256)  not null comment 'common name',
    scientific_name varchar(256)  not null comment 'scientific name',
    species         varchar(256)  null comment 'family/genus',
    toxicity        varchar(512)  null comment 'toxicity',
    feature         varchar(1024) null comment 'detailed description',
    img_path        varchar(512)  null comment 'image path',
    index idx_unionId (scientific_name)
) comment 'mushroom table' collate = utf8mb4_unicode_ci;


INSERT INTO mushroom(name,scientific_name,species,toxicity,feature,img_path) VALUES ('伞菌属(蘑菇属)','Agaricus','伞菌属,蘑菇属','部分有毒','伞菌属(Agaricus),又名蘑菇属,是蘑菇科下的一个大型及重要的属,包括了可吃菇及有毒菇,在全世界合共超过300个物种。其担子果通常有白色、褐色或者灰褐色的肉质菌盖,腹面又有辐射状的菌褶,在其内形成担子和担孢子,菌柄极易与菌盖分开,孢子为卵圆形或者椭圆形,大部分蘑菇属都可食用,少数有毒。','https://img.wpixiu.cn/i/Agaricus.jpg');
INSERT INTO mushroom(name,scientific_name,species,toxicity,feature,img_path) VALUES ('鹅膏菌','Amanita','真菌界、担子菌门、伞菌目、鹅膏科','部分有毒','鹅膏属隶属于真菌界、担子菌门、伞菌目、鹅膏科。全世界报道500多种,广布世界各大洲。该属中有些种是美味食用菌,如红黄鹅膏;而另一些种则是剧毒蘑菇。蘑菇中毒事件大多数是由鹅膏(如玫瑰红鹅膏、灰花纹鹅膏、绿盖鹅膏、和鳞柄鹅膏等)引起的。
百余年来,人们对鹅膏所含的毒素作了大量的研究,发现毒鹅膏中含有两大类毒素即肽类毒素和非肽类毒素。
鹅膏中主要的毒素为肽类毒素,分为鹅膏毒肽类、鬼笔毒肽类、和毒伞毒肽类三大类群22种,属于环肽化合物,绝大多数的化学结构稳定、耐高温,一般烹调加工不能改变其结构,进入人体后对肝脏和肾脏有强烈的毁坏作用。鹅膏毒肽对真核细胞的RNA聚合酶II具有专一性抑制作用,鬼笔毒肽对肌动蛋白具有束缚作用,因此它们在生命科学研究中具有重要的应用价值。
非肽类毒素有毒蝇碱、异噁唑衍生物、色胺衍生物,主要为神经性毒素。在生物防冶及其它方面毒蝇鹅膏、春生鹅膏等所含毒素对昆虫或农业害虫都有一定的诱杀作用。
鹅膏属大多数与针科或壳斗科形成外生菌根菌,尚无人工栽培,科学研究中需要的鹅膏肽类毒素只能从野外采集的子实体中提取,因此鹅膏肽类毒素的价格非常昂贵。','https://img.wpixiu.cn/i/Amanita.jpg');
INSERT INTO mushroom(name,scientific_name,species,toxicity,feature,img_path) VALUES ('牛肝菌','Boletus','牛肝菌科、乳牛肝菌科','部分有毒','牛肝菌是牛肝菌科和松塔牛肝菌科等真菌的统称,是野生而可以食用的菇菌类,其中除少数品种有毒或味苦而不能食用外,大部分品种均可食用。主要有白、黄、黑牛肝菌。
白牛肝菌味道鲜美,营养丰富。该菌菌体较大,肉肥厚,柄粗壮,食味香甜可口,营养丰富,是一种世界性著名食用菌。西欧各国也有广泛食用白牛肝菌的习惯,除新鲜的作菜外,大部分切片干燥,加工成各种小包装,用来配制汤料或做成酱油浸膏,也有制成盐腌品食用。
菌盖扁半球形,光滑、不粘、淡褐色,菌肉白色,有酱香味,可入药。大腿蘑营养丰富,味道香美,是极富美味的野生食用菌之一,可出口欧美、日本等国,深受外商欢迎。该菌菌体较大,肉肥厚,柄粗壮,食味香甜可口,营养丰富,是一种世界性著名食用菌。','https://img.wpixiu.cn/i/Boletus.jpg');
INSERT INTO mushroom(name,scientific_name,species,toxicity,feature,img_path) VALUES ('丝膜菌','Cortinarius','丝膜菌属','部分有毒','丝膜菌属Cortinarius (Pers.) Gray隶属于担子菌门Basidiomycota、蘑菇纲Agaricomycetes、蘑菇目Agaricales、丝膜菌科Cortinariaceae,该属是蘑菇目中最大的属,目前已描述的物种超过2 000种,分布全球(Niskanen et al. 2016),在我国,2018年编制的《中国生物多样性红色名录-大型真菌卷》记录了丝膜菌属物种200种,分布于我国大部分地区。丝膜菌属真菌具有重要的生态价值和经济价值,一方面,该属真菌可以与一些乔木、灌木形成外生菌根,在植物生长和森林生态系统中发挥着重要作用(Bödeker et al. 2014;Defrenne et al. 2019);另一方面,丝膜菌属的一些种类具有食药用价值。此外,还有一些种类含有奥来毒素(orellanine),是剧毒的,会导致急性肾衰和死亡;近年研究表明,丝膜菌的奥来毒素可用于肾小管上皮转移性肾癌的治疗,可望开发成治疗药物(Buvall et al. 2017)。','https://img.wpixiu.cn/i/Cortinarius.jpg');
INSERT INTO mushroom(name,scientific_name,species,toxicity,feature,img_path) VALUES ('粉褶菌','Entoloma','粉褶菌属','部分有毒','粉褶菌属 Entoloma (Fr.) P. Kumm.由 Kummer 于 1871 年建立,隶属于担子菌门(Basidiomycota),层菌纲(Hymenomycet),伞菌目(Agaricales),粉褶菌科(Entolomataceae)。作为伞菌目第二大属,粉褶菌属物种繁多,分布范围广泛,从高寒山地到盆地,从寒带到热带都有粉褶菌属物种的分布。','https://img.wpixiu.cn/i/Entoloma.jpg');
INSERT INTO mushroom(name,scientific_name,species,toxicity,feature,img_path) VALUES ('湿伞菌','Hygrocybe','蜡伞科,湿伞属','部分有毒','菌盖呈伞状,有时稍微凸起,通常具有明亮且鲜艳的颜色,如红色、橙色、黄色或绿色。它们的菌盖表面光滑或有时有些粉状。菌柄通常是细长且硬实的,有时会有纤维状的环带。','https://img.wpixiu.cn/i/Hygrocybe.jpg');
INSERT INTO mushroom(name,scientific_name,species,toxicity,feature,img_path) VALUES ('毛头乳菇','Lactarius','乳菇属','有毒','毛头乳菇又称疝疼乳菇。子实体中等。菌盖深蛋壳色至暗土黄色,具同心环纹,边缘白色长绒毛,乳汁白色,不变色,味苦。菌盖扁半球形,中部下凹呈漏斗状,这缘内卷 。菌肉白色。菌褶直生至延生,较密。夏秋季在林中地上单生或散生。','https://img.wpixiu.cn/i/Lactarius.jpg');
INSERT INTO mushroom(name,scientific_name,species,toxicity,feature,img_path) VALUES ('红菇','Russula','红菇属','可食用','红菇是一类大型菌根真菌,属担子菌亚门、弹子菌纲、伞菌目、红菇科、红菇属。是一种名贵的野生食(药)用菌。在世界范围内分布广泛,我国主要分布于福建、云南、江西、辽宁、河南、四川、广西等省区,海拔为300~1000m山林地带,植被和土壤垂直分布明显,坡度10°~45°的缓坡地至斜坡地。
红菇属大多数种类是可食用的,且营养丰富,味道鲜美,有“菇中之王”的美称,系天然营养佳品,具有较高的营养价值。红菇含有丰富的必需氨基酸、多糖、有机酸、维生素、脂肪酸和甾类化合物、色素和抗生素等。','https://img.wpixiu.cn/i/Russula.jpg');
INSERT INTO mushroom(name,scientific_name,species,toxicity,feature,img_path) VALUES ('乳牛肝菌','Suillus','乳牛肝菌属','可食用','乳牛肝菌,野生食用菌,子实体伞状,口感滑爽,肉质细嫩,营养丰富,富含人体必需的多种氨基酸,对增强人体免疫系统有较好的作用,可分为红乳牛肝菌和白乳牛肝菌两种。','https://img.wpixiu.cn/i/Suillus.jpg');

[Figure: contents of the mushroom table after the seed data is inserted]

Project Implementation

Environment Configuration

# python == 3.8
# torch >= 2.0

albumentations==1.3.1
lxml==5.0.0
matplotlib==3.7.4
numpy==1.24.3
onnxruntime==1.16.3
opencv_python==4.9.0.80
opencv_python_headless==4.9.0.80
pandas==2.0.3
Pillow==8.0.0
PyQt5==5.15.10
PyQt5_sip==12.13.0
Requests==2.31.0
scikit_learn==1.3.2
timm==0.9.12
torch==2.0.0+cu117
torchstat==0.0.7
torchvision==0.15.1
tqdm==4.66.1
backgroundremover

Data Acquisition

  1. Kaggle dataset: image folders for the 9 most common northern European mushroom genera, each containing 300 to 1500 curated images of its genus. https://www.kaggle.com/datasets/maysee/mushrooms-classification-common-genuss-images/data

  2. Crawling Baidu Images

# -*- coding: utf-8 -*-
# @Time : 2024/1/3 19:51
# @Author : wpixiu
# @File : img_bing.py
# @Software: PyCharm
# Crawl images from Baidu Images

import requests
import re
import os

headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/84.0.4147.125 Safari/537.36'}
name = input('Enter the image category to crawl: ')
num = 0      # images downloaded
num_1 = 0    # duplicates skipped
num_2 = 0    # failed downloads
x = input('How many pages to crawl? (1 = 60 images, 2 = 120 images): ')
list_1 = []
for i in range(int(x)):
    name_1 = os.getcwd()
    name_2 = os.path.join(name_1, 'data/baidu/' + name)
    url = 'https://image.baidu.com/search/flip?tn=baiduimage&ie=utf-8&word=' + name + '&pn=' + str(i * 30)
    res = requests.get(url, headers=headers)
    htlm_1 = res.content.decode()
    a = re.findall('"objURL":"(.*?)",', htlm_1)
    if not os.path.exists(name_2):
        os.makedirs(name_2)
    for b in a:
        try:
            b_1 = re.findall('https:(.*?)&', b)
            b_2 = ''.join(b_1)
            if b_2 not in list_1:
                num = num + 1
                img = requests.get(b)
                f = open(os.path.join(name_1, 'data/baidu/' + name, name + str(num) + '.jpg'), 'ab')
                print('---------Downloading image No.' + str(num) + '----------')
                f.write(img.content)
                f.close()
                list_1.append(b_2)
            elif b_2 in list_1:
                num_1 = num_1 + 1
                continue
        except Exception as e:
            print('---------Image No.' + str(num) + ' could not be downloaded----------')
            num_2 = num_2 + 1
            continue
# bad images are filtered out later, during data cleaning
print('Finished: {} attempted in total, {} downloaded, {} duplicates skipped, {} failed'.format(num + num_1 + num_2, num, num_1, num_2))
  3. Crawling Bing Images
# -*- coding: utf-8 -*-
# @Time : 2024/1/3 19:51
# @Author : wpixiu
# @File : img_bing.py
# @Software: PyCharm
# Crawl images from Bing

import requests
from lxml import etree
import os
from multiprocessing.dummy import Pool
import json
from time import time
import re


class BingImagesSpider:
    thread_amount = 1000  # thread-pool size; the pool overlaps the many IO-bound HTTP requests to cut total time
    per_page_images = 30  # number of images requested from Bing per page
    count = 0  # image counter
    success_count = 0
    # characters stripped from image titles
    ignore_chars = ['|', '.', ',', ',', '“', '”', '/', '@', ':', ':', ';', ';', '[', ']', '+']
    # accepted image types
    image_types = ['bmp', 'jpg', 'png', 'tif', 'gif', 'pcx', 'tga', 'exif', 'fpx', 'svg', 'psd', 'cdr', 'pcd', 'dxf',
                   'ufo', 'eps', 'ai', 'raw', 'WMF', 'webp']
    # request headers
    headers = {
        'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/101.0.4951.54 Safari/537.36',
        'referer': 'https://www.bing.com/'}
    # Bing image search url
    bing_image_url_pattern = 'https://www.bing.com/images/async?q={}&first={}&count={}&mmasync=1'

    def __init__(self, keyword, amount, path='./'):
        # keyword: search keyword to crawl
        # amount: number of images to crawl
        # path: directory in which images are saved
        self.keyword = keyword
        self.amount = amount
        self.path = path
        self.thread_pool = Pool(self.thread_amount)

    def __del__(self):
        self.thread_pool.close()
        self.thread_pool.join()

    # request one Bing result page
    def request_homepage(self, url):
        # url: url of a Bing image result page
        return requests.get(url, headers=self.headers)

    # parse a Bing result page and return all image records as a list;
    # each record is a dict with keys image_title, image_type, image_md5, image_url
    def parse_homepage_response(self, response):
        # response: response of a Bing result page

        # grab the json string m that holds each image's metadata
        tree = etree.HTML(response.text)
        m_list = tree.xpath('//*[@class="imgpt"]/a/@m')

        # handle each image separately
        info_list = []
        for m in m_list:
            dic = json.loads(m)

            # drop characters that are not wanted in file names
            image_title = dic['t']
            for char in self.ignore_chars:
                image_title = image_title.replace(char, ' ')
            image_title = image_title.strip()

            # some records carry no image format; fall back to jpg
            image_type = dic['murl'].split('.')[-1]
            if image_type not in self.image_types:
                image_type = 'jpg'

            # store the image's metadata as a dict
            info = dict()
            info['image_title'] = image_title
            info['image_type'] = image_type
            info['image_md5'] = dic['md5']
            info['image_url'] = dic['murl']

            info_list.append(info)
        return info_list

    # download one image and save it to the path given at construction time
    def request_and_save_image(self, info):
        # info: one image's metadata dict with keys image_title, image_type, image_md5, image_url
        filename = '{}_{}.{}'.format(info['image_title'], self.count, info['image_type'])
        filename = re.sub(r'[\/:*?"<>|]', '_', filename)  # replace characters Windows forbids in file names
        filename = filename.replace('\\', '_')  # strings are immutable, so backslashes are replaced here as well
        filepath = os.path.join(self.path, filename)

        try:
            # fetch the image
            response = requests.get(info['image_url'], headers=self.headers, timeout=1.5)
            # save the image
            with open(filepath, 'wb') as fp:
                fp.write(response.content)
            # log progress
            self.count += 1
            self.success_count += 1
            print('{}: saving {} done.'.format(self.count, filepath))

        except requests.exceptions.RequestException as e:
            self.count += 1
            print('{}: saving {} failed. url: {}'.format(self.count, filepath, info['image_url']))
            print('\t tip:', e)

    # deduplicate the image list so the same picture is not downloaded twice
    def deduplication(self, info_list):
        result = []

        # the image md5 serves as the unique identifier
        md5_set = set()
        for info in info_list:
            if info['image_md5'] not in md5_set:
                result.append(info)
                md5_set.add(info['image_md5'])
        return result

    # run the spider
    def run(self):
        # create the directory the images are saved in
        if not os.path.exists(self.path):
            os.makedirs(self.path)

        # build the list of Bing result pages to request from the keyword and amount
        homepage_urls = []
        for i in range(int(self.amount / self.per_page_images * 1.5) + 1):  # request 1.5x pages since some images repeat
            url = self.bing_image_url_pattern.format(self.keyword, i * self.per_page_images, self.per_page_images)
            homepage_urls.append(url)
        print('homepage_urls len {}'.format(len(homepage_urls)))

        # request all result pages through the thread pool
        homepage_responses = self.thread_pool.map(self.request_homepage, homepage_urls)

        # parse every page into image records (image_title, image_type, image_md5, image_url)
        info_list = []
        for response in homepage_responses:
            result = self.parse_homepage_response(response)
            info_list += result
        print('info amount before deduplication', len(info_list))

        # drop duplicates to avoid downloading the same image twice
        info_list = self.deduplication(info_list)
        print('info amount after deduplication', len(info_list))
        info_list = info_list[: self.amount]
        print('info amount after split', len(info_list))

        # download and save all images
        self.thread_pool.map(self.request_and_save_image, info_list)
        print('all done. {} successfully downloaded, {} failed.'.format(self.success_count,
                                                                         self.count - self.success_count))


if __name__ == '__main__':
    # keyword, amount, path
    BingImagesSpider('Agaricus', 2000, path='data/bing/' + 'Agaricus').run()
    BingImagesSpider('Amanita', 2000, path='data/bing/' + 'Amanita').run()
    BingImagesSpider('Boletus', 2000, path='data/bing/' + 'Boletus').run()
    BingImagesSpider('Cortinarius', 2000, path='data/bing/' + 'Cortinarius').run()
    BingImagesSpider('Entoloma', 2000, path='data/bing/' + 'Entoloma').run()
    BingImagesSpider('Hygrocybe', 2000, path='data/bing/' + 'Hygrocybe').run()
    BingImagesSpider('Lactarius', 2000, path='data/bing/' + 'Lactarius').run()
    BingImagesSpider('Russula', 2000, path='data/bing/' + 'Russula').run()
    BingImagesSpider('Suillus', 2000, path='data/bing/' + 'Suillus').run()

Data Cleaning

  1. Remove crawled images that cannot be opened

[Figure: examples of corrupted images found in the crawled data]

# -*- coding: utf-8 -*-
# @Time : 2024/1/3 19:51
# @Author : wpixiu
# @File : img_bing.py
# @Software: PyCharm
# Data cleaning


import shutil
import cv2
import os
import os.path as osp
import numpy as np
from tqdm import tqdm


# Saved file names and paths may contain Chinese characters, so imdecode/imencode
# are used instead of cv2.imread/imwrite to handle such paths correctly.
# handle chinese path
def cv_imread(file_path, type=-1):
    cv_img = cv2.imdecode(np.fromfile(file_path, dtype=np.uint8), -1)
    if type == 0:
        cv_img = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
    return cv_img


def cv_imwrite(file_path, cv_img, is_gray=True):
    if len(cv_img.shape) == 3 and is_gray:
        cv_img = cv_img[:, :, 0]
    cv2.imencode(file_path[-4:], cv_img)[1].tofile(file_path)


def data_clean(src_folder, english_name, path):
    clean_folder = "data/data_clean/" + path + "/" + english_name
    print(clean_folder)
    if os.path.isdir(clean_folder):
        print("Target directory already exists, recreating it")
        shutil.rmtree(clean_folder)
    os.makedirs(clean_folder)
    # Cleaning simply re-reads every image with OpenCV and re-saves the ones that load correctly.
    # This guarantees that every kept file is readable, and it also renames Chinese file names
    # to English ones so that later scripts can read them without encoding issues.
    image_names = os.listdir(src_folder)
    with tqdm(total=len(image_names)) as pabr:
        for i, image_name in enumerate(image_names):
            image_path = osp.join(src_folder, image_name)
            try:
                img = cv_imread(image_path)
                img_channel = img.shape[-1]
                if img_channel == 3:
                    save_image_name = english_name + "_" + str(i) + ".jpg"
                    save_path = osp.join(clean_folder, save_image_name)
                    cv_imwrite(file_path=save_path, cv_img=img, is_gray=False)
            except Exception:
                print("{} is a corrupted image".format(image_name))
            pabr.update(1)

if __name__ == '__main__':
    # baidu: the src_folder names must match the Chinese keywords that were entered when running the Baidu crawler
    data_clean(src_folder="data/baidu/伞菌", english_name="Agaricus", path="baidu")
    data_clean(src_folder="data/baidu/鹅膏菌", english_name="Amanita", path="baidu")
    data_clean(src_folder="data/baidu/牛肝菌", english_name="Boletus", path="baidu")
    data_clean(src_folder="data/baidu/丝膜菌", english_name="Cortinarius", path="baidu")
    data_clean(src_folder="data/baidu/粉褶菌", english_name="Entoloma", path="baidu")
    data_clean(src_folder="data/baidu/湿伞菌", english_name="Hygrocybe", path="baidu")
    data_clean(src_folder="data/baidu/毛头乳菇", english_name="Lactarius", path="baidu")
    data_clean(src_folder="data/baidu/红菇", english_name="Russula", path="baidu")
    data_clean(src_folder="data/baidu/乳牛肝菌", english_name="Suillus", path="baidu")
    # bing
    data_clean(src_folder="data/bing/Agaricus", english_name="Agaricus", path="bing")
    data_clean(src_folder="data/bing/Amanita", english_name="Amanita", path="bing")
    data_clean(src_folder="data/bing/Boletus", english_name="Boletus", path="bing")
    data_clean(src_folder="data/bing/Cortinarius", english_name="Cortinarius", path="bing")
    data_clean(src_folder="data/bing/Entoloma", english_name="Entoloma", path="bing")
    data_clean(src_folder="data/bing/Hygrocybe", english_name="Hygrocybe", path="bing")
    data_clean(src_folder="data/bing/Lactarius", english_name="Lactarius", path="bing")
    data_clean(src_folder="data/bing/Russula", english_name="Russula", path="bing")
    data_clean(src_folder="data/bing/Suillus", english_name="Suillus", path="bing")

  2. Manually remove irrelevant content from the crawled data

[Figure: examples of irrelevant images removed by hand]

  3. Merge the cleaned datasets: the data/Mushrooms, data/data_clean/baidu and data/data_clean/bing datasets are merged into data/merge
# -*- coding: utf-8 -*-
# @Time : 2024/1/14 22:33
# @Author : wpixiu
# @File : data_merge.py
# @Software: PyCharm
# Merge the cleaned datasets

import os
import shutil

# directory layout
data_dir = "data"
data_clean_dir = os.path.join(data_dir, "data_clean")
baidu_dir = os.path.join(data_clean_dir, "baidu")
bing_dir = os.path.join(data_clean_dir, "bing")
moom_dir = os.path.join(data_dir, "Mushrooms")  # the Kaggle dataset lives here
merge_dir = os.path.join(data_dir, "merge")

# class folders under the baidu, bing and Mushrooms directories
baidu_folders = [folder for folder in os.listdir(baidu_dir) if os.path.isdir(os.path.join(baidu_dir, folder))]
bing_folders = [folder for folder in os.listdir(bing_dir) if os.path.isdir(os.path.join(bing_dir, folder))]
moom_folders = [folder for folder in os.listdir(moom_dir) if os.path.isdir(os.path.join(moom_dir, folder))]


def copy_files_to_directory(source_dir, target_dir):
    for filename in os.listdir(source_dir):
        source_file = os.path.join(source_dir, filename)
        target_file = os.path.join(target_dir, filename)
        if os.path.isfile(source_file):
            shutil.copy2(source_file, target_file)  # copy the file, keeping its metadata


# merge folders with the same class name into the merge directory
for folder_name in set(baidu_folders + bing_folders + moom_folders):
    merge_folder_dir = os.path.join(merge_dir, folder_name)
    try:
        os.makedirs(merge_folder_dir, exist_ok=True)  # create the merged class directory
    except Exception as e:
        print(f"Error creating directory: {e}")
        continue

    # copy the class folder from baidu, bing and Mushrooms into the merged directory
    for source_dir in [baidu_dir, bing_dir, moom_dir]:
        source_folder_dir = os.path.join(source_dir, folder_name)
        print(source_folder_dir)
        if os.path.isdir(source_folder_dir):
            try:
                # shutil.copytree(source_folder_dir, merge_folder_dir)
                copy_files_to_directory(source_folder_dir, merge_folder_dir)
            except Exception as e:
                print(f"Error copying directory: {e}")
                continue

print("Folder merge finished")

Dataset Splitting

The dataset is split into training, validation and test sets at a ratio of 6:2:2.

# -*- coding: utf-8 -*-
# @Time : 2024/1/14 22:33
# @Author : wpixiu
# @File : data_merge.py
# @Software: PyCharm
# Split the merged dataset into train/val/test


import os
import random
import shutil
from shutil import copy2
import os.path as osp


def data_set_split(src_data_folder, target_data_folder, train_scale=0.6, val_scale=0.2, test_scale=0.2):
    '''
    Read the source data folder and create train, val and test folders holding the split data.
    :param src_data_folder: source folder
    :param target_data_folder: target folder
    :param train_scale: training set ratio
    :param val_scale: validation set ratio
    :param test_scale: test set ratio
    :return:
    '''
    print("Starting dataset split")
    class_names = os.listdir(src_data_folder)
    # create the split folders in the target directory
    split_names = ['train', 'val', 'test']
    for split_name in split_names:
        split_path = os.path.join(target_data_folder, split_name)
        if os.path.isdir(split_path):
            pass
        else:
            os.mkdir(split_path)
        # then create one folder per class inside split_path
        for class_name in class_names:
            class_split_path = os.path.join(split_path, class_name)
            if os.path.isdir(class_split_path):
                pass
            else:
                os.mkdir(class_split_path)

    # split the data according to the ratios and copy the images
    # iterate over the classes first
    for class_name in class_names:
        current_class_data_path = os.path.join(src_data_folder, class_name)
        current_all_data = os.listdir(current_class_data_path)
        current_data_length = len(current_all_data)
        current_data_index_list = list(range(current_data_length))
        random.shuffle(current_data_index_list)

        train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name)
        val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name)
        test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name)
        train_stop_flag = current_data_length * train_scale
        val_stop_flag = current_data_length * (train_scale + val_scale)
        current_idx = 0
        train_num = 0
        val_num = 0
        test_num = 0
        for i in current_data_index_list:
            src_img_path = os.path.join(current_class_data_path, current_all_data[i])
            if current_idx <= train_stop_flag:
                copy2(src_img_path, train_folder)
                train_num = train_num + 1
            elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):
                copy2(src_img_path, val_folder)
                val_num = val_num + 1
            else:
                copy2(src_img_path, test_folder)
                test_num = test_num + 1

            current_idx = current_idx + 1

        print("*********************************{}*************************************".format(class_name))
        print(
            "Class {} split at {}:{}:{}, {} images in total".format(class_name, train_scale, val_scale, test_scale,
                                                                     current_data_length))
        print("train {}: {} images".format(train_folder, train_num))
        print("val {}: {} images".format(val_folder, val_num))
        print("test {}: {} images".format(test_folder, test_num))


if __name__ == '__main__':
    src_data_folder = "data/merge"  # todo path to the merged dataset
    target_data_folder = "data/" + "split"
    if osp.isdir(target_data_folder):
        print("target folder already exists, removing it...")
        shutil.rmtree(target_data_folder)
    os.mkdir(target_data_folder)
    print("Target folder created")

    data_set_split(src_data_folder, target_data_folder)
    print("*****************************************************************")
    print("Dataset split finished; see the {} directory".format(target_data_folder))

Model Training

Training uses the resnet50d pretrained model (loaded through timm).

import numpy as np
import torch
from torch.utils.data import DataLoader

from torchutils import *
from torchvision import datasets
import os.path as osp
import os
import timm
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm


if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')
# fix the random seeds so experiments are reproducible
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True
data_path = "data/split"  # todo dataset path

# Note: split the dataset before running this script
# hyperparameters
params = {
    # 'model': 'vit_tiny_patch16_224',  # alternative pretrained model
    'model': 'resnet50d',  # pretrained model to use
    # 'model': 'efficientnet_b3a',  # alternative pretrained model
    "img_size": 224,  # input image size
    "train_dir": osp.join(data_path, "train"),  # todo training set path
    "val_dir": osp.join(data_path, "val"),  # todo validation set path
    'device': device,  # device
    'lr': 1e-3,  # learning rate
    'batch_size': 4,  # batch size
    'num_workers': 0,  # dataloader workers
    'epochs': 10,  # number of epochs
    "save_dir": "checkpoints",  # todo checkpoint directory
    "pretrained": True,
    "num_classes": len(os.listdir(osp.join(data_path, "train"))),  # number of classes, derived from the train folder
    'weight_decay': 1e-5  # weight decay
}


# model definition
class SELFMODEL(nn.Module):
    def __init__(self, model_name=params['model'], out_features=params['num_classes'],
                 pretrained=True):
        super().__init__()
        self.model = timm.create_model(model_name, pretrained=pretrained)  # load the backbone from timm
        # self.model = timm.create_model(model_name, pretrained=pretrained, checkpoint_path="pretrained/resnet50d_ra2-464e36ba.pth")  # load from a local checkpoint instead
        # classifier
        if model_name[:3] == "res":
            n_features = self.model.fc.in_features  # width of the original fully connected layer
            self.model.fc = nn.Linear(n_features, out_features)  # replace it with one sized for this task
        elif model_name[:3] == "vit":
            n_features = self.model.head.in_features
            self.model.head = nn.Linear(n_features, out_features)
        else:
            n_features = self.model.classifier.in_features
            self.model.classifier = nn.Linear(n_features, out_features)
        # for resnet the final fully connected layer has now been replaced
        print(self.model)

    def forward(self, x):  # forward pass
        x = self.model(x)
        return x


# training loop for one epoch
def train(train_loader, model, criterion, optimizer, epoch, params):
    metric_monitor = MetricMonitor()  # metric monitor
    model.train()  # switch the model to training mode
    nBatch = len(train_loader)
    stream = tqdm(train_loader)
    for i, (images, target) in enumerate(stream, start=1):  # start training
        images = images.to(params['device'], non_blocking=True)  # move images to the device
        target = target.to(params['device'], non_blocking=True)  # move labels to the device
        output = model(images)  # forward pass
        loss = criterion(output, target.long())  # compute the loss
        f1_macro = calculate_f1_macro(output, target)  # compute the f1 score
        recall_macro = calculate_recall_macro(output, target)  # compute the recall score
        acc = accuracy(output, target)  # compute the accuracy
        metric_monitor.update('Loss', loss.item())  # update loss
        metric_monitor.update('F1', f1_macro)  # update f1
        metric_monitor.update('Recall', recall_macro)  # update recall
        metric_monitor.update('Accuracy', acc)  # update accuracy
        optimizer.zero_grad()  # clear accumulated gradients
        loss.backward()  # backpropagate the loss
        optimizer.step()  # optimizer step
        lr = adjust_learning_rate(optimizer, epoch, params, i, nBatch)  # adjust the learning rate
        stream.set_description(  # update the progress bar
            "Epoch: {epoch}. Train. {metric_monitor}".format(
                epoch=epoch,
                metric_monitor=metric_monitor)
        )
    return metric_monitor.metrics['Accuracy']["avg"], metric_monitor.metrics['Loss']["avg"]  # return epoch averages


# validation loop for one epoch
def validate(val_loader, model, criterion, epoch, params):
    metric_monitor = MetricMonitor()  # metric monitor
    model.eval()  # switch the model to evaluation mode
    stream = tqdm(val_loader)  # progress bar
    with torch.no_grad():  # inference only
        for i, (images, target) in enumerate(stream, start=1):
            images = images.to(params['device'], non_blocking=True)  # move images to the device
            target = target.to(params['device'], non_blocking=True)  # move labels to the device
            output = model(images)  # forward pass
            loss = criterion(output, target.long())  # compute the loss
            f1_macro = calculate_f1_macro(output, target)  # compute the f1 score
            recall_macro = calculate_recall_macro(output, target)  # compute the recall score
            acc = accuracy(output, target)  # compute the accuracy
            metric_monitor.update('Loss', loss.item())  # the rest updates the progress bar
            metric_monitor.update('F1', f1_macro)
            metric_monitor.update("Recall", recall_macro)
            metric_monitor.update('Accuracy', acc)
            stream.set_description(
                "Epoch: {epoch}. Validation. {metric_monitor}".format(
                    epoch=epoch,
                    metric_monitor=metric_monitor)
            )
    return metric_monitor.metrics['Accuracy']["avg"], metric_monitor.metrics['Loss']["avg"]


# plot the training curves
def show_loss_acc(acc, loss, val_acc, val_loss, save_dir):
    # plot train/val accuracy on top and train/val loss below
    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 1])
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')
    # save the figure under save_dir
    save_path = osp.join(save_dir, "results.png")
    plt.savefig(save_path, dpi=100)


if __name__ == '__main__':
    accs = []
    losss = []
    val_accs = []
    val_losss = []
    data_transforms = get_torch_transforms(img_size=params["img_size"])  # image preprocessing pipelines
    train_transforms = data_transforms['train']  # training transforms
    valid_transforms = data_transforms['val']  # validation transforms
    train_dataset = datasets.ImageFolder(params["train_dir"], train_transforms)  # load the training set
    valid_dataset = datasets.ImageFolder(params["val_dir"], valid_transforms)  # load the validation set
    if params['pretrained']:
        save_dir = osp.join(params['save_dir'], params['model'] + "_pretrained_" + str(params["img_size"]))  # checkpoint directory
    else:
        save_dir = osp.join(params['save_dir'],
                            params['model'] + "_nopretrained_" + str(params["img_size"]))  # checkpoint directory
    if not osp.isdir(save_dir):  # create the checkpoint directory if it does not exist
        os.makedirs(save_dir)
        print("save dir {} created".format(save_dir))
    train_loader = DataLoader(  # batched training loader
        train_dataset, batch_size=params['batch_size'], shuffle=True,
        num_workers=params['num_workers'], pin_memory=True,
    )
    val_loader = DataLoader(  # batched validation loader
        valid_dataset, batch_size=params['batch_size'], shuffle=False,
        num_workers=params['num_workers'], pin_memory=True,
    )
    print(train_dataset.classes)
    model = SELFMODEL(model_name=params['model'], out_features=params['num_classes'],
                      pretrained=params['pretrained'])  # build the model
    # model = nn.DataParallel(model)  # optional: parallelize across GPUs
    model = model.to(params['device'])  # move the model to the device
    criterion = nn.CrossEntropyLoss().to(params['device'])  # loss function
    optimizer = torch.optim.AdamW(model.parameters(), lr=params['lr'], weight_decay=params['weight_decay'])  # optimizer
    # the loss function and optimizer can be swapped out here
    best_acc = 0.0  # best validation accuracy so far
    # only checkpoints that improve on the best validation accuracy are kept
    for epoch in range(1, params['epochs'] + 1):  # start training
        acc, loss = train(train_loader, model, criterion, optimizer, epoch, params)
        val_acc, val_loss = validate(val_loader, model, criterion, epoch, params)
        accs.append(acc)
        losss.append(loss)
        val_accs.append(val_acc)
        val_losss.append(val_loss)
        if val_acc >= best_acc:
            # saving on every improvement keeps the best model; a fixed saving interval would not
            save_path = osp.join(save_dir, f"{params['model']}_{epoch}epochs_accuracy{acc:.5f}_weights.pth")
            torch.save(model.state_dict(), save_path)
            best_acc = val_acc
    show_loss_acc(accs, losss, val_accs, val_losss, save_dir)
    print("Training finished; model weights and logs are saved under: {}".format(save_dir))

[Figure: results.png — training and validation accuracy/loss curves]

Model Testing


import torch
from torch.utils.data import DataLoader
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

# keeping the configuration in one place makes it easier to review
from torchutils import *
from torchvision import datasets, models, transforms
import os.path as osp
import os
from train import SELFMODEL
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt


if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')
# fix the random seeds so experiments are reproducible
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

data_path = "data/split"  # todo dataset root directory
model_path = "checkpoints/resnet50d_pretrained_224/resnet50d_10epochs_accuracy0.99479_weights.pth"  # todo model weights
model_name = 'resnet50d'  # todo model name
img_size = 224  # todo input size used during training
# Note: split the dataset before running this script
# hyperparameters
params = {
    # 'model': 'vit_tiny_patch16_224',  # alternative pretrained model
    # 'model': 'efficientnet_b3a',  # alternative pretrained model
    'model': model_name,  # model to evaluate
    "img_size": img_size,  # input image size
    "test_dir": osp.join(data_path, "test"),  # todo test set subdirectory
    'device': device,  # device
    'batch_size': 4,  # batch size
    'num_workers': 0,  # dataloader workers
    "num_classes": len(os.listdir(osp.join(data_path, "train"))),  # number of classes, derived from the train folder
}


def test(val_loader, model, params, class_names):
    metric_monitor = MetricMonitor()  # metric monitor
    model.eval()  # switch the model to evaluation mode
    stream = tqdm(val_loader)  # progress bar

    # collect predictions for the whole test set
    test_real_labels = []
    test_pre_labels = []
    with torch.no_grad():  # inference only
        for i, (images, target) in enumerate(stream, start=1):
            images = images.to(params['device'], non_blocking=True)  # move images to the device
            target = target.to(params['device'], non_blocking=True)  # move labels to the device
            output = model(images)  # forward pass
            target_numpy = target.cpu().numpy()
            y_pred = torch.softmax(output, dim=1)
            y_pred = torch.argmax(y_pred, dim=1).cpu().numpy()
            test_real_labels.extend(target_numpy)
            test_pre_labels.extend(y_pred)
            f1_macro = calculate_f1_macro(output, target)  # compute the f1 score
            recall_macro = calculate_recall_macro(output, target)  # compute the recall score
            acc = accuracy(output, target)  # compute the accuracy
            metric_monitor.update('F1', f1_macro)
            metric_monitor.update("Recall", recall_macro)
            metric_monitor.update('Accuracy', acc)
            stream.set_description(
                "mode: {epoch}. {metric_monitor}".format(
                    epoch="test",
                    metric_monitor=metric_monitor)
            )
    class_names_length = len(class_names)
    heat_maps = np.zeros((class_names_length, class_names_length))
    for test_real_label, test_pre_label in zip(test_real_labels, test_pre_labels):
        heat_maps[test_real_label][test_pre_label] = heat_maps[test_real_label][test_pre_label] + 1

    # normalise each row so the heatmap shows per-class proportions
    heat_maps_sum = np.sum(heat_maps, axis=1).reshape(-1, 1)
    heat_maps_float = heat_maps / heat_maps_sum
    # title, x_labels, y_labels, harvest
    show_heatmaps(title="heatmap", x_labels=class_names, y_labels=class_names, harvest=heat_maps_float,
                  save_name="record/heatmap_{}.png".format(model_name))
    # the model name is appended to the file name

    return metric_monitor.metrics['Accuracy']["avg"], metric_monitor.metrics['F1']["avg"], \
        metric_monitor.metrics['Recall']["avg"]


def show_heatmaps(title, x_labels, y_labels, harvest, save_name):
    # create the figure
    fig, ax = plt.subplots()
    # cmap reference: https://blog.csdn.net/ztf312/article/details/102474190
    im = ax.imshow(harvest, cmap="OrRd")
    # show all ticks and label them with the class names
    ax.set_xticks(np.arange(len(y_labels)))
    ax.set_yticks(np.arange(len(x_labels)))
    ax.set_xticklabels(y_labels)
    ax.set_yticklabels(x_labels)

    # the x labels are long, so rotate them for readability
    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # annotate every cell with its value
    for i in range(len(x_labels)):
        for j in range(len(y_labels)):
            text = ax.text(j, i, round(harvest[i, j], 2),
                           ha="center", va="center", color="black")
    ax.set_xlabel("Predict label")
    ax.set_ylabel("Actual label")
    ax.set_title(title)
    fig.tight_layout()
    plt.colorbar(im)
    plt.savefig(save_name, dpi=100)
    # plt.show()


if __name__ == '__main__':
    data_transforms = get_torch_transforms(img_size=params["img_size"])  # image preprocessing pipelines
    valid_transforms = data_transforms['val']  # the validation transforms are reused for testing
    test_dataset = datasets.ImageFolder(params["test_dir"], valid_transforms)
    class_names = test_dataset.classes
    print(class_names)
    test_loader = DataLoader(  # batched test loader
        test_dataset, batch_size=params['batch_size'], shuffle=True,
        num_workers=params['num_workers'], pin_memory=True,
    )

    # load the model
    model = SELFMODEL(model_name=params['model'], out_features=params['num_classes'],
                      pretrained=False)  # build the architecture only; the fine-tuned weights are loaded next
    weights = torch.load(model_path)
    model.load_state_dict(weights)
    model.eval()
    model.to(device)
    os.makedirs("record", exist_ok=True)  # make sure the heatmap output directory exists
    # the test reports acc, f1 and recall, plus a confusion-matrix heatmap
    acc, f1, recall = test(test_loader, model, params, class_names)
    print("Test results:")
    print(f"acc: {acc}, F1: {f1}, recall: {recall}")
    print("Testing finished; the heatmap is saved under {}".format("record"))

Test results:

  • acc: 0.8940298507462686
  • F1: 0.8430395640843399
  • recall: 0.8469320066334989

[Figure: heatmap_resnet50d.png — confusion-matrix heatmap on the test set]

Project Optimization

Model Optimization

Evaluation Metrics

$accuracy = \frac{TP+TN}{TP+TN+FP+FN}$

$precision = \frac{TP}{TP+FP}$

$recall = \frac{TP}{TP+FN}$

$F1 = \frac{2 \times precision \times recall}{precision + recall}$

The meaning of TP, TN, FP and FN in these formulas is illustrated in the figure below.
[Figure: confusion-matrix quadrants (TP, FP, FN, TN)]
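As a side note (not part of the project's test script), the same four metrics can be computed directly from lists of predicted and true labels with scikit-learn, which is already in the requirements; the labels below are made-up values for illustration only.

# Illustrative only: computing the reported metrics with scikit-learn.
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

y_true = [0, 2, 1, 1, 0, 2]   # hypothetical ground-truth class indices
y_pred = [0, 2, 1, 0, 0, 1]   # hypothetical model predictions

print("accuracy :", accuracy_score(y_true, y_pred))
print("precision:", precision_score(y_true, y_pred, average="macro", zero_division=0))
print("recall   :", recall_score(y_true, y_pred, average="macro", zero_division=0))
print("F1       :", f1_score(y_true, y_pred, average="macro", zero_division=0))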

Model Comparison

Performance of different pretrained models on the test set:

Model | Accuracy | F1 | Recall
resnet50d | 0.8940298507462686 | 0.8430395640843399 | 0.8469320066334989
vit_tiny_patch16_224 | 0.4246268656716418 | 0.3227683013503913 | 0.3438379530916846
efficientnet_b3a | 0.8649253731343284 | 0.801660743899549 | 0.8065588723051409
resnet14t | 0.8671641791044776 | 0.805491589670693 | 0.8111774461028192
Model Fusion

After training with the resnet50d, resnet14t and efficientnet_b3a pretrained models, the three trained models are fused by averaging their outputs (average "pooling" across models).

# -*- coding: utf-8 -*-
# @Time : 2024/1/18 0:10
# @Author : wpixiu
# @File : test_model_fusion.py
# @Software: PyCharm

import torch
from torch.utils.data import DataLoader
from PIL import ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

# keeping the configuration in one place makes it easier to review
from torchutils import *
from torchvision import datasets
import os.path as osp
import os
from train import SELFMODEL
import numpy as np
from tqdm import tqdm

if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')
# fix the random seeds so experiments are reproducible
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

data_path = "data/split"  # todo dataset root directory

img_size = 224  # todo input size used during training
# Note: split the dataset before running this script
# hyperparameters
params = {
    "img_size": img_size,  # input image size
    "test_dir": osp.join(data_path, "test"),  # todo test set subdirectory
    'device': device,  # device
    'batch_size': 4,  # batch size
    'num_workers': 0,  # dataloader workers
    "num_classes": len(os.listdir(osp.join(data_path, "train"))),  # number of classes, derived from the train folder
}


def test_with_model_fusion(val_loader, models, params, class_names):
    metric_monitor = MetricMonitor()  # metric monitor
    stream = tqdm(val_loader)  # progress bar

    # fused predictions for every batch
    fused_preds = []

    with torch.no_grad():  # inference only
        for i, (images, target) in enumerate(stream, start=1):
            images = images.to(params['device'], non_blocking=True)  # move images to the device
            target = target.to(params['device'], non_blocking=True)  # move labels to the device

            # run every model on the batch
            model_preds = [model(images) for model in models]

            # fuse the models (here by simply averaging their outputs)
            fused_pred = torch.mean(torch.stack(model_preds), dim=0)
            fused_preds.append(fused_pred)

            # compute the metrics on the fused output
            f1_macro = calculate_f1_macro(fused_pred, target)  # compute the f1 score
            recall_macro = calculate_recall_macro(fused_pred, target)  # compute the recall score
            acc = accuracy(fused_pred, target)  # compute the accuracy

            metric_monitor.update('F1', f1_macro)
            metric_monitor.update("Recall", recall_macro)
            metric_monitor.update('Accuracy', acc)
            stream.set_description(
                "mode: {epoch}. {metric_monitor}".format(
                    epoch="test",
                    metric_monitor=metric_monitor)
            )

    return metric_monitor.metrics['Accuracy']["avg"], metric_monitor.metrics['F1']["avg"], \
        metric_monitor.metrics['Recall']["avg"]


if __name__ == '__main__':
    data_transforms = get_torch_transforms(img_size=params["img_size"])  # image preprocessing pipelines
    valid_transforms = data_transforms['val']  # the validation transforms are reused for testing
    test_dataset = datasets.ImageFolder(params["test_dir"], valid_transforms)
    class_names = test_dataset.classes
    print(class_names)
    test_loader = DataLoader(  # batched test loader
        test_dataset, batch_size=params['batch_size'], shuffle=True,
        num_workers=params['num_workers'], pin_memory=True,
    )

    # load the three trained models; pretrained=False because the fine-tuned weights are loaded below
    model_1 = SELFMODEL(model_name="resnet50d", out_features=params['num_classes'],
                        pretrained=False)
    weights_1 = torch.load("checkpoints/resnet50d_pretrained_224/resnet50d_10epochs_accuracy0.99479_weights.pth")
    model_1.load_state_dict(weights_1)

    model_2 = SELFMODEL(model_name="efficientnet_b3a", out_features=params['num_classes'],
                        pretrained=False)
    weights_2 = torch.load(
        "checkpoints/efficientnet_b3a_pretrained_224/efficientnet_b3a_8epochs_accuracy0.98735_weights.pth")
    model_2.load_state_dict(weights_2)

    model_3 = SELFMODEL(model_name="resnet14t", out_features=params['num_classes'],
                        pretrained=False)
    weights_3 = torch.load("checkpoints/resnet14t_pretrained_224/resnet14t_9epochs_accuracy0.97718_weights.pth")
    model_3.load_state_dict(weights_3)

    model_1.eval()
    model_1.to(device)

    model_2.eval()
    model_2.to(device)

    model_3.eval()
    model_3.to(device)

    models = [model_1, model_2, model_3]

    acc, f1, recall = test_with_model_fusion(test_loader, models, params, class_names)

    print("Test results:")
    print(f"acc: {acc}, F1: {f1}, recall: {recall}")


After model fusion, performance on the test set improves.
[Figure: fused-model metrics on the test set]

Data Optimization

The backgroundremover library is used to strip the background from each image, reducing distractions for the classifier.
Example result:
[Figure: a mushroom image before and after background removal]

# -*- coding: utf-8 -*-
# @Time : 2024/1/15 0:08
# @Author : wpixiu
# @File : data_enhance.py
# @Software: PyCharm

import os


# backgroundremover downloads its u2net weights on first use, e.g. to C:\Users\wj\.u2net\u2net.pth
# os.system('backgroundremover -i "data/Mushrooms/Agaricus/000_ePQknW8cTp8.jpg" -o "test.png"')


def background_removal(data_dir, enhance_dir):
    # iterate over every class folder under data_dir
    for foldername in os.listdir(data_dir):
        folder_path = os.path.join(data_dir, foldername)
        if os.path.isdir(folder_path):
            # class folder name without the leading path
            base_folder_name = os.path.basename(folder_path)
            # create the output folder if it does not exist
            enhance_sub_dir = os.path.join(enhance_dir, base_folder_name)
            os.makedirs(enhance_sub_dir, exist_ok=True)
            # iterate over every file in the class folder
            for filename in os.listdir(folder_path):
                file_path = os.path.join(folder_path, filename)
                # only process files with common image extensions
                if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.gif')):
                    # build the backgroundremover command line
                    command = f'backgroundremover -i "{file_path}" -o "{enhance_sub_dir}/{filename}"'
                    print(command)
                    # run it
                    os.system(command)


if __name__ == '__main__':
    background_removal("data/merge", "data/data_enhance/Mushrooms")


Comparison of the resnet50d model's test-set performance on the different datasets

GUI Construction

The graphical interface is built with PyQt5.
Main window
[Figure: main window]
Mushroom recognition page
[Figure: mushroom recognition page]
Image upload page
[Figure: image upload page]
Recognition result page
[Figure: recognition result page]
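The project's window.py is not reproduced in this write-up. The sketch below only illustrates the general shape of such a PyQt5 window: an upload button, an image preview and a recognise button wired to a prediction callback. The class, widget and parameter names are assumptions made for illustration, not the project's actual interface code.

# Minimal PyQt5 sketch of an upload-and-classify window (illustrative, not the project's window.py).
import sys
from PyQt5.QtWidgets import (QApplication, QWidget, QPushButton, QLabel,
                             QVBoxLayout, QFileDialog)
from PyQt5.QtGui import QPixmap


class MushroomWindow(QWidget):
    def __init__(self, predict_fn=None):
        super().__init__()
        self.predict_fn = predict_fn  # callable: image path -> class name (could wrap the trained resnet50d)
        self.image_path = None
        self.setWindowTitle("Mushroom Identification")
        self.preview = QLabel("Please upload a mushroom image")
        self.result = QLabel("")
        upload_btn = QPushButton("Upload image")
        upload_btn.clicked.connect(self.choose_image)
        predict_btn = QPushButton("Recognise")
        predict_btn.clicked.connect(self.run_predict)
        layout = QVBoxLayout(self)
        for widget in (self.preview, upload_btn, predict_btn, self.result):
            layout.addWidget(widget)

    def choose_image(self):
        # open a file dialog and show the chosen image in the preview label
        path, _ = QFileDialog.getOpenFileName(self, "Choose image", "", "Images (*.png *.jpg *.jpeg)")
        if path:
            self.image_path = path
            self.preview.setPixmap(QPixmap(path).scaledToWidth(300))

    def run_predict(self):
        # call the prediction callback and display its result
        if self.image_path and self.predict_fn:
            self.result.setText("Result: " + self.predict_fn(self.image_path))


if __name__ == "__main__":
    app = QApplication(sys.argv)
    win = MushroomWindow(predict_fn=lambda p: "Amanita")  # stub predictor for demonstration
    win.show()
    sys.exit(app.exec_())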

Project Deployment

Project Packaging

Export the environment dependencies:

pip install pipreqs

pipreqs ./ --encoding=utf8 --force

Use PyInstaller to package the project into an .exe:

pip install pyinstaller

# Run PyInstaller from the project's own interpreter (e.g. activate the conda environment
# at D:\anaconda3\envs\flowers first); it packages whichever Python it was installed into.
# -D (--onedir) builds a folder; --noconsole hides the console window for the GUI app.
pyinstaller -D --noconsole window.py

The packaged project is written to the dist directory.
Copy the images and checkpoints folders into dist/window.
[Figure: contents of the dist/window folder]

Deployment

Upload dist/window to wherever it is needed and run window.exe.