- 作者:老汪软件技巧
- 发表时间:2024-08-31 21:01
- 浏览量:
前言
最近在了解机器学习方面的知识,了解了一些分类相关的算法实现,于是乎也想上手玩玩,毕竟实践出真知。看了看网上有很多 pre-trained 模型,可以实现很多强大的功能,不过还是想基于这些模型做一个二次训练做一个自己的模型,这样才更有成就感,自己也能在实践中学习到一些书本上没有的细节。
灵感
现实中很多物体可以做分类,不过我的灵感来自我的2只可爱的狗狗,一只是金毛,一只是只串串。我一直好奇串串是个什么品种,所以我选择自己做一个狗狗分类器,这样我能比较直观的感受到自己训练的成果是不是准确,毕竟如果训练了一个自己不懂的物体,无法直接的判断模型的好坏。
框架选择
YOLOv8 是一个在目标检测,分割领域使用的最多的框架,使用他主要是他封装的比较齐全,使用简单,毕竟作为一个 Android 程序员我的主要领域都集成在了 Android,Kotlin 这块。
数据集准备
分类算法的数据集基本上都类似于下面的格式
数据集格式介绍
cifar-10-/
|
|-- train/
| |-- airplane/
| | |-- 10008_airplane.png
| | |-- 10009_airplane.png
| | |-- ...
| |
| |-- automobile/
| | |-- 1000_automobile.png
| | |-- 1001_automobile.png
| | |-- ...
| |
| |-- bird/
| | |-- 10014_bird.png
| | |-- 10015_bird.png
| | |-- ...
| |
| |-- ...
|
|-- test/
| |-- airplane/
| | |-- 10_airplane.png
| | |-- 11_airplane.png
| | |-- ...
| |
| |-- automobile/
| | |-- 100_automobile.png
| | |-- 101_automobile.png
| | |-- ...
| |
| |-- bird/
| | |-- 1000_bird.png
| | |-- 1001_bird.png
| | |-- ...
| |
| |-- ...
|
|-- val/ (optional)
| |-- airplane/
| | |-- 105_airplane.png
| | |-- 106_airplane.png
| | |-- ...
| |
| |-- automobile/
| | |-- 102_automobile.png
| | |-- 103_automobile.png
| | |-- ...
| |
| |-- bird/
| | |-- 1045_bird.png
| | |-- 1046_bird.png
| | |-- ...
| |
| |-- ...
数据集必须在根目录下以特定的分割目录结构进行组织,以方便正确的训练、测试和可选的验证过程。这种结构包括用于训练(train)和测试(test)阶段的单独目录,以及用于验证(val)的可选目录。
每个目录都应为数据集中的每个类包含一个子目录。子目录以相应的类命名,并包含该类的所有图像。确保每个图像文件都有唯一的名称,并以 JPEG 或 PNG 等通用格式存储。对于 Ultralytics YOLO 分类任务,数据集必须在根目录下以特定的分割目录结构进行组织,以方便正确的训练、测试和可选的验证过程。这种结构包括用于训练(train)和测试(test)阶段的单独目录,以及用于验证(val)的可选目录。
每个目录都应为数据集中的每个类包含一个子目录。子目录以相应的类命名,并包含该类的所有图像。确保每个图像文件都有唯一的名称,并以 JPEG 或 PNG 等通用格式存储。
一般来说上述的分类数据集是比较通用的,不仅仅适用于 YOLOv8 的分类数据集,也适用于 mobilenet,resnet 等分类框架。
数据集获取
深度学习需要海量的数据集来训练,以保住模型的能力,我们可以自己准备数据集,或者使用网上的开源数据集,一般来说对于开发者来说,有两个地方可以免费获得 dataset(数据集),并且轻易的使用。一个是 huggingface.co/, 另一个是 ,本文中以 huggingface 为例。打开 hugging face,选择 dataset 可以看到下面的页面。
搜索框里面我们输入 dog 选择我们需要的数据集,note:请选择 train 数据集 和 test 数据集都存在的dataset,否则将会存在以下风险,因为模型的训练是需要测试数据进行验证回馈,调整损失函数进行下一步训练的。
无法评估泛化能力:你不能验证模型在未见数据上的表现。过拟合风险:可能无法检测到模型是否过拟合。数据集结构展示
以上面的这个数据集为例,一般数据集结构可以使用 viewer 进行查看,通过上面的 split 可以看到当前数据集的完整性,是否存在 train,test,val split。通过切换我们可以知道这个dataset 的 test split 并没有完整分类,是无法直接用于分类任务的,我们需要换一个数据集来实现。
数据集使用
一般来说网站上都会提供完整的数据集的使用方法,示例如下
如果我们点击每个按钮就会有对应的集成代码,本例中以 datasets 库为例,点击之后会出现一个提示框出现类似下面的代码:
from datasets import load_dataset
ds = load_dataset("amaye15/stanford-dogs")
如果在 pycharm 上直接运行,你将会下载上面的数据集。放心数据集只会运行一次。现在你有了完整的数据集了,你需要的就是按照上面的数据集的结构将图片copy一份过来。
训练/ Train
现在万事俱备,只需要我们按照官方文档 copy 出来代码
from ultralytics import YOLO
# Load a model
model = YOLO("yolov8n-cls.pt") # load a pretrained model (recommended for training)
# Train the model
results = model.train(data="path/to/dataset", epochs=100, imgsz=640)
点击开始,训练就正式开始了。如果不出意外的话,意外就要开始了,你会发现你的进度条不动了,这时候你可能察觉不到异常,因为 YOLOv8 默认的训练是使用 cuda 的,是需要 NVIDIA 显卡来支持的,这时候你需要指定使用 cpu 来训练,当然如果你是高贵的 mac pro 用户,你可以指定设备为 mps 来训练,训练能力排序为 cuda>mps>cpu。
# Ensure the model uses CPU
model.to('cpu')
修复问题我们接着训练,当然为了节省我们不需要讲训练次数 epoch 设置成 100,可以设置成 30 或者50 先看看效果。
如果一切顺利的话,控制台会实时展示训练的成果。
模型验证与导出模型预测
from ultralytics import YOLO
# Load a model
checkpoint = 'your_model_path.pt'
model = YOLO(checkpoint) # pretrained YOLOv8n model
# Run batched inference on a list of images
results = model(["im1.jpg", "im2.jpg"]) # return a list of Results objects
# Process results list
for result in results:
boxes = result.boxes # Boxes object for bounding box outputs
masks = result.masks # Masks object for segmentation masks outputs
keypoints = result.keypoints # Keypoints object for pose outputs
probs = result.probs # Probs object for classification outputs
obb = result.obb # Oriented boxes object for OBB outputs
result.show() # display to screen
result.save(filename="result.jpg") # save to disk
我们可以使用上面的代码找几张测试图片预测一下当前模型,体验一下自己的劳动成果。checkpoint 这个参数可以在控制台看到,一般路径为:runs/classify/trainX/weights。
模型导出
YOLOv8 的模型导出非常简单,可以说是傻瓜式。
from ultralytics import YOLO
# Load a model
model = YOLO("path/to/best.pt") # 加载你自己训练数据集
# 导出为你需要的数据集
model.export(format="onnx")
一般来说移动端的部署框架选择很多,例如 腾讯的 ncnn,Facebook的 pytorch ,Google 的 tensorflow-lite,百度的 paddle,微软的 onnx-runtime。在这里我选择的是 onnx runtime 框架进行部署。确定了部署的框架之后我们就需要导出模型文件了,这里我的设置如上面所示。
在Android 手机上跑起来
我们可以基于 onnx runtime 原有 demo 稍微做下改造,将里面模型路径换成我们自己的路径,里面的 classes txt 换成我们自己的模型 labels 这样就可以成功部署啦。部署之后我们会兴冲冲打开自己App 验证一下模型,可是我们可能会发现自己的模型并不能正确识别自己的狗狗。在这里我就不绕弯子了,由于 onnx 输入图片已经做了 归一化 如果不了解的可以看一下相关介绍,在这里我们只需要在 上面的训练代码添加一个参数即可,
model.train(
data='dataset', # Path to your dataset configuration file
epochs=5, # Number of epochs
workers=4, # Number of workers
batch_size=16, # Batch size
augmented=True, # Use augmented data
)
model.export(format='onnx',simplify=True)
重新验证
经过重新部署验证之后,用自己的狗狗 看看自己训练的结果,终于知道自己的狗宝像哪个品种了了,原来是博美呀
结语
好啦,上面就是我的分享啦,祝周末愉快。附上自己的 repo