
我训练了一个医疗多模态大模型帮家里老人看病 原创
前言
随着多模态大模型的发展,其不仅限于文字处理,更能够在图像、视频、音频方面进行识别与理解。医疗领域中,医生们往往需要对各种医学图像进行处理,以辅助诊断和治疗。如果将多模态大模型与图像诊断相结合,那么这会极大地提升诊断效率。
项目目标
训练一个医疗多模态大模型,用于图像诊断。
刚好家里老爷子近期略感头疼,去医院做了脑部CT,诊断患有垂体瘤,我将尝试使用多模态大模型进行进一步诊断。
实现过程
1. 数据集准备
为了训练模型,需要准备大量的医学图像数据。通过搜索我们找到以下训练数据:
数据名称:MedTrinity-25M
数据地址:https://github.com/UCSC-VLAA/MedTrinity-25M
数据简介:MedTrinity-25M数据集是一个用于医学图像分析和计算机视觉研究的大型数据集。
数据来源:该数据集由加州大学圣克鲁兹分校(UCSC)提供,旨在促进医学图像处理和分析的研究。
数据量:MedTrinity-25M包含约2500万条医学图像数据,涵盖多种医学成像技术,如CT、MRI和超声等。
数据内容:该数据集有两份,分别是 25Mdemo
和 25Mfull
。
25Mdemo
(约162,000条)数据集内容如下:
25Mfull
(约24,800,000条)数据集内容如下:
2. 数据下载
2.1 安装Hugging Face的Datasets库
2.2 下载数据集
执行结果:
说明:
- 以上方法是使用HuggingFace的Datasets库下载数据集,下载的路径为当前脚本所在路径下的cache文件夹。
- 使用HuggingFace下载需要能够访问https://huggingface.co/ 并且在网站上申请数据集读取权限才可以。
如果没有权限访问HuggingFace,可以关注以下公众号后,回复 “MedTrinity”获取百度网盘下载地址。
2.3 预览数据集
运行结果:
使用如下命令对数据集的图片进行可视化查看:
运行结果:
3. 数据预处理
由于后续我们要通过LLama Factory进行多模态大模型微调,所以我们需要对上述的数据集进行预处理以符合LLama Factory的要求。
3.1 LLama Factory数据格式
查看LLama Factory的多模态数据格式要求如下:
3.2 实现数据格式转换脚本
运行结果:
4. 模型下载
本次微调,我们使用阿里最新发布的多模态大模型:Qwen2-VL-2B-Instruct
作为底座模型。
模型说明地址:https://modelscope.cn/models/Qwen/Qwen2-VL-2B-Instruct
使用如下命令下载模型
5. 环境准备
5.1 机器环境
硬件:
- 显卡:4080 Super
- 显存:16GB
软件:
- 系统:Ubuntu 20.04 LTS
- python:3.10
- pytorch:2.1.2 + cuda12.1
5.2 准备虚拟环境
6. 准备训练框架
下载并安装LLamaFactory框架的具体步骤,请见【课程总结】day24(上):大模型三阶段训练方法(LLaMa Factory)中 准备训练框架 部分内容,本章不再赘述。
6.1 修改LLaMaFactory源码以适配transformer
由于Qwen2-VL使用的transformer
的版本为4.47.0.dev0
,LLamaFactory还不支持,所以需要修改LLaMaFactory的代码,具体方法如下:
第一步:在 llamafactory
源码中,找到 check_dependencies()
函数,这个函数位于 src/llamafactory/extras/misc.py
文件的第 82
行。
第二步:修改 check_dependencies()
函数并保存
第三步:重新启动LLaMaFactory服务
这个过程可能会提示 ImportError: accelerate>=0.34.0 is required for a normal functioning of this module, but found accelerate==0.32.0. 如遇到上述问题,可以重新安装accelerate,如下:
7. 测试当前模型
第一步:启动LLaMa Factory后,访问http://0.0.0.0:7860
第二步:在web页面配置模型路径为 4.步骤
下载的模型路径,并点击加载模型
第三步:上传一张CT图片并输入问题:“请使用中文描述下这个图像并给出你的诊断结果”
由上图可以看到,模型能够识别到这是一个CT图像,显示了大概的位置以及相应的器官,但是并不能给出是否存在诊断结果。
8. 模型训练
8.1 数据准备
第一步:将 3.2步骤
生成的mllm_data文件拷贝到LLaMaFactory的data目录下
第二步:将 4.步骤
下载的底座模型Qwen2-VL 拷贝到LLaMaFactory的model目录下
第三步:修改 LLaMaFactory data目录下的dataset_info.json,增加自定义数据集:
8.2 配置训练参数
访问LLaMaFactory的web页面,配置微调的训练参数:
- Model name:
Qwen2-VL-2B-Instruct
- Model path:
models/Qwen2-VL-2B-Instruct
- Finetuning method:
lora
- Stage :
Supervised Fine-Tuning
- Dataset:
mllm_med
- Output dir:
saves/Qwen2-VL/lora/Qwen2-VL-sft-demo1
配置参数中最好将
save_steps
设置大一点,否则训练过程会生成非常多的训练日志,导致硬盘空间不足而训练终止。
点击Preview Command预览命令行无误后,点击Run按钮开始训练。 训练参数:
训练过程:
训练的过程中,可以通过
watch -n 1 nvidia-smi
实时查看GPU显存的消耗情况。
经过35小时的训练,模型训练完成,损失函数如下:
损失函数一般降低至1.2左右,太低会导致模型过拟合。
8.3 合并导出模型
接下来,我们将 Lora补丁
与 原始模型
合并导出:
- 切换到
Expert
标签下 - Model path: 选择Qwen2-VL的基座模型,即:
models/Qwen2-VL-2B-Instruct
- Checkpoint path: 选择lora微调的输出路径,即
saves/Qwen2-VL/lora/Qwen2-VL-sft-demo1
- Export path:设置一个新的路径,例如:
Qwen2-VL-sft-final
- 点击
开始导出
按钮
导出完毕后,会在LLaMaFactory的根目录下生成一个 Qwen2-VL-sft-final
的文件夹。
9. 模型验证
9.1 模型效果对比
第一步:在LLaMa Factory中卸载之前的模型
第二步:在LLaMa Factory中加载导出的模型,并配置模型路径为 Qwen2-VL-sft-final
第三步:加载模型并上传之前的CT图片提问同样的问题
可以看到,经过微调后的模型,可以给出具体区域存在的可能异常问题。
9.2 实际诊断
接下来,我将使用微调后的模型,为家里老爷子的CT片做诊断,看看模型给出的诊断与大夫的异同点。
我总计测试了CT片上的52张局部结果,其中具有代表性的为上述三张,可以看到模型还是比较准确地诊断出:脑部有垂体瘤,可能会影响到眼部。这与大夫给出的诊断和后续检查方案一致。
不足之处
训练集:
- 多模态:本次训练只是采用了MedTrinity-25Mdemo数据集,如果使用MedTrinity-25Mfull数据集,效果应该会更好。
- 中英文:本次训练集中使用的MedTrinity-25Mdemo数据集,只包含了英文数据,如果将英文标注翻译为中文,提供中英文双文数据集,相信效果会更好。
- 对话数据集:本次训练只是使用了多模态数据集,如果增加中文对话(如:中文医疗对话数据-Chinese-medical-dialogue),相信效果会更好。
前端页面:
- 前端页面:本次实践曾使用streamlit构建前端页面,以便图片上传和问题提出,但是在加载微调后的模型时,会出现:
ValueError: No chat template is set for this processor
问题,所以转而使用LLaMaFactory的web页面进行展示。 - 多个图片推理:在Qwen2-VL的官方指导文档中,提供了
Multi image inference
方法,本次未进行尝试,相信将多个图片交给大模型进行推理,效果会更好。
内容小结
- Qwen2-VL-2B作为多模态大模型,具备有非常强的多模态处理能力,除了能够识别图片内容,还可以进行相关的推理。
- 我们可以通过
LLaMaFactory
对模型进行微调,使得其具备医疗方面的处理能力。 - 微调数据集采用开源的MedTrinity-25M数据集,该数据集有两个版本:25Mdemo和25Mfull。
- 训练前需要对数据集进行预处理,使得其适配LLaMaFactory的微调格式。
- 经过微调后的多模态大模型,不但可以详细地描述图片中的内容,还可以给出可能的诊断结果。
本文转载自公众号一起AI技术 作者:Dongming
原文链接:https://mp.weixin.qq.com/s/_NFvbMlbH7N5YSdo2OoiIQ
