百度360必应搜狗淘宝本站头条
当前位置:网站首页 > 技术文章 > 正文

通过Unsloth微调Qwen2.5-VL实现复杂数学公式的OCR

nanshan 2025-05-08 20:15 19 浏览 0 评论

【学习目标】

  • 理解Unsloth的核心优化原理与基础实践;
  • 掌握基于Unsloth的高效微调工作流。

【知识储备】

1. Unsloth简介

Unsloth是一个专为大型语言模型(LLM)设计的微调框架,旨在提高微调效率并减少显存占用。 它通过手动推导计算密集型数学步骤并手写 GPU 内核,实现了无需硬件更改即可显著加快训练速度。

主要功能点:

  • 高效微调:Unsloth通过深度优化,使 LLM 的微调速度提高 2-5 倍,显存使用量减少约 80%,且准确度无明显下降。
  • 广泛的模型支持:目前支持的模型包括目前各类主流模型,用户可以根据需求适合的模型进行微

调。

  • 兼容性:Unsloth与HuggingFace生态兼容,用户可以轻松将其与 traformers、peft、trl 等库结合,实现模型的全参微调(full)、监督微调(SFT)和广义强化学习优化(GRPO)、基于人类反馈的奖励建模(包括DPO、ORPO、KTO等方法)、持续预训练(continued pretraining)、文本补全(text completion)以及其他前沿训练方法。
  • 内存优化: 通过 4 位和 16 位的 QLoRA/LoRA 微调,unsloth 显著了显存占用,使得在资源受限的环境中也能大的微调。

Unsloth核心优势:

  • Unsloth简化了整个微调工作流程,包括模型加载、量化、训练、评估、运行、保存、导出,以及与推理引擎(如Ollama、llama.cpp和vLLM)的集成;
  • Unsloth相比传统方法,Unsloth 能够在更短的时间内、更少的显存消耗完成微调任务,节省时间及硬件成本;
  • Unsloth定期与Huggingface、Google和Meta团队合作,以修复LLM训练和模型中的错误(例如,之前有报告有为Gemma 3和Phi-4所做的错误排查工作)。因此,在使用Unsloth进行模型微调时能看到最准确的结果。
  • 开源免费: Unsloth提供开源版本,用户可以在 Google Colab 或 Kaggle Notebooks 上免费试用,方便上手体验。

总的来说,unsloth 为大型语言模型的微调提供了高效、低成本的解决方案,适合希望在有限资源下进行模型微调的开发者和研究人员。

【任务实施】

1. 运行环境要求

1.1、硬件环境

名称

建议配置

1

CPU

Intel I7

2

显卡

NVIDIA GeForce RTX 4090

3

内存

16G

4

系统

Ubuntu20.04 +

注:根据微调的模型参数及量化方法不同,显存要求也会不一样,参考值如下:

参数量

QLoRA (4-bit)

LoRA (16-bit)

3B

3.5 GB

8 GB

7B

5 GB

19 GB

8B

6 GB

22 GB

9B

6.5 GB

24 GB

11B

7.5 GB

29 GB

14B

8.5 GB

33 GB

27B

22 GB

64 GB

32B

26 GB

76 GB

40B

30 GB

96 GB

70B

41 GB

164 GB

81B

48 GB

192 GB

90B

53 GB

212 GB

405B

237 GB

950 GB

1.2、软件环境

名称

版本

1

Python

3.10+

2

CUDA

12.1+

3

JupyterLab

3.5+

2. Unsloth安装

2.1、创建并配置虚拟环境

打开一个新的命令行终端,创建Conda新环境,名称可自定义,这里以"unsloth"为例:

$ conda create -n unsloth python=3.11 ipykernel -y

激活新建的环境:

$ conda activate unsloth

激活后,终端提示符通常会显示环境名称(unsloth),表示您已在该环境当中。

unsloth虚拟环境加入到Jupyterlab的内核中,以便后续.ipynb文档可以选择该环境运行:

$ python -m ipykernel install --user --name=unsloth --display-name "unsloth"

运行后,点击右上角内核切换按钮,进行内核切换,查看是否有出现unsloth内核,如果没有请在菜单栏重启内核再操作:

2.2、Unsloth安装

In [ ]:

import sys
PYTHON_PATH=sys.executable
print(PYTHON_PATH)

In [ ]:

%%capture
!{PYTHON_PATH} -m pip install unsloth modelscope ipywidgets tensorboard
  • %%capture:隐藏命令的输出,避免安装过程中的冗长日志刷屏。但注意观察右上角的运行状态,显示"忙碌",请耐心等待。

如果是开发环境,可以继续运行以下命令,从 GitHub 仓库安装Unsloth的最新开发版(可能包含未发布的修复或功能)。

In [ ]:

!{PYTHON_PATH} -m pip install \
--force-reinstall \
--no-cache-dir \
--no-deps \
git+https://github.com/unslothai/unsloth.git

2.3、验证Unsloth

运行以下命令查看Unsloth的安装情况 ,如果安装成功,会显示版本号等信息。

In [ ]:

!{PYTHON_PATH} -m pip show unsloth

3. 通过Unsloth进行Qwen2.5-VL模型推理

3.1、Qwen多模态模型下载

通过ModelScope SDK将Qwen2.5-VL多模态模型下载到指定目录,使用的是7B经过指令微调后的模型。

In [ ]:

import os
from modelscope import snapshot_download

# 定义基座模型以及模型存放目录
MODEL_NAME_OR_PATH = "models/Qwen2.5-VL-7B-Instruct"
BASE_MODEL = "unsloth/Qwen2.5-VL-7B-Instruct"


# 如目录不存在,则下载模型
if not os.path.exists(MODEL_NAME_OR_PATH):
    snapshot_download(BASE_MODEL, local_dir=MODEL_NAME_OR_PATH)
# 目录已存在,打印文件列表
else:
    print("模型已存在,跳过下载")
    files = [item for item in os.listdir(MODEL_NAME_OR_PATH) if not item.startswith('.')]
    for file in files:
        print(file)

3.2. 导入相关依赖库

In [ ]:

from unsloth import FastVisionModel  
import torch
from PIL import Image, ImageOps
from IPython.display import display
from transformers import TextStreamer

3.3、加载模型和分词器

In [ ]:

model, tokenizer = FastVisionModel.from_pretrained(
    model_name=MODEL_NAME_OR_PATH,
    max_seq_length=2048,
    dtype=None,
    load_in_4bit=True,
    load_in_8bit=False,
    full_finetuning=False,
)

3.4、微调前的模型推理

将推理过程封装成一个函数,方便后续多次调用,代码如下:

In [ ]:

def inference(text, image_file, system_prompt = None):
    """
    推理函数

    Args
        text: 输入的文本
        image_file: 图片文件路径
        system_prompt: 系统提示语,默认为None
    """
    # 显示图片
    image = Image.open(image_file)
    image = ImageOps.exif_transpose(image) 
    display(image)

    # 将模型切换到推理模式(会关闭 dropout 等训练专用层,优化推理速度)
    FastVisionModel.for_inference(model)

    # 构造符合ChatML风格的输入消息
    messages = []
    if system_prompt: 
        messages.append({"role": "system", "content": [{"type":"text", "text": system_prompt}]})
    messages = [
        {"role": "user", "content":
            [
                {"type": "image"},
                {"type": "text", "text": text}
            ]}
    ]

    # 将messages转换为模型所需的对话格式字符串
    input_text = tokenizer.apply_chat_template(
        messages, 
        tokenize=False,
        add_generation_prompt=True)

    # 图像会被编码为视觉特征向量,文本按正常分词流程处理
    # 输出包含input_ids(文本)、pixel_values(图像)等键的字典
    model_inputs = tokenizer(
        text=input_text,
        images=image,
        padding=True,
        add_special_tokens=False,
        return_tensors="pt"
    )

    model_inputs = model_inputs.to(model.device)

    # 通过TextStreamer实现流式输出
    model.generate(
        **model_inputs,
        max_new_tokens=512,
        use_cache = True, 
        temperature = 1.5, 
        min_p = 0.1,
        streamer=TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True),
    )

函数编写完成后,现在我们来先传递一个问题、一张图片给函数,看看Unsloth框架的模型推理效果。

In [ ]:

inference(text="图片表达了什么?" , image_file="assets/candy.jpg")

那么在进行微调前,我们首先验证下原模型Qwen2.5-VL,对含有数学公式的图片识别效果怎么样,同样调用上面的推理函数inference:

In [ ]:

inference(
    text="为图片生成LaTeX表达式", 
    image_file="assets/demo_pic_1.jpg", 
    system_prompt="你是一个LaText OCR助手,目标是读取用户输入的照片,转换成LaTex公式"
)

观察模型生成的结果,将该结果通过LaTeX公式生成器验证下,看是否正确,同时记录起来供后续做对比。

对比两者,可以发现模型虽然在提示词的作用下,发挥作用,但是回答的并不正确。但下来我们需要对它进行微调,使其更适应处理复杂数学公式。

4. 通过Unsloth微调Qwen2.5-VL实现复杂数学公式的OCR

4.1、微调数据集的准备

通过huggingface datasets库下载数据集:

In [ ]:

from datasets import load_dataset

# 定义数据集名称及保存路径
dataset_name = "unsloth/LaTeX_OCR"
dataset_dir = "datasets/LaTeX_OCR"
exist = os.path.exists(dataset_dir)

dataset = load_dataset(dataset_name, split="train", cache_dir=dataset_dir )

if not exist:
    print(f"数据集已下载保存到{dataset_dir},共 {len(dataset)} 条样本")
else:
    print(f"数据集已存在,已从{dataset_dir}加载数据集")
数据集已存在,已从datasets/LaTeX_OCR加载数据集

让我们来简单了解一下这个数据集。我们看一看第三张图片是什么,以及对应的标题是什么。

In [3]:

dataset[2]["image"]

Out[3]:

In [4]:

dataset[2]["text"]

Out[4]:

'H ^ { \\prime } = \\beta N \\int d \\lambda \\biggl \\{ \\frac { 1 } { 2 \\beta ^ { 2 } N ^ { 2 } } \\partial _ { \\lambda } \\zeta ^ { \\dagger } \\partial _ { \\lambda } \\zeta + V ( \\lambda ) \\zeta ^ { \\dagger } \\zeta \\biggr \\} \\ .'

我们运行下一行代码,直接在JupyterLab中渲染上述dataset[2]["text"]的LaTeX表达式,看是否与图片一致:

In [5]:

from IPython.display import display, Math

latex = dataset[2]["text"]
display(Math(latex))

H′=βN∫dλ{12β2N2λζ+λζ+V(λ)ζ+ζ} .H′=βN∫dλ{12β2N2λζ+λζ+V(λ)ζ+ζ} .

可以发现与原图的公式一模型一样。

那么了解完数据集结构之后,我们需要将这些数据格式化成Qwen2.5-VL需要的Json格式(本质上所有视觉微调任务都是类似ChatML格式,ChatML格式仅仅是sharegpt格式的一种特殊情况),如下所示:

[
    { "role": "user",
    "content": [{"type": "text",  "text": Q}, {"type": "image", "image": image} ]
    },
    { "role": "assistant",
    "content": [{"type": "text",  "text": A} ]
    },
]

定义数据预处理函数data_process,目的是处理数据集的每条数据,将其格式化成Qwen2.5-VL需要的Json格式:

In [ ]:

instruction = "为图片生成LaTeX表达式"
def data_process(sample):
    conversation = [
        { "role": "user", "content" : [
            {"type" : "text",  "text"  : instruction},
            {"type" : "image", "image" : sample["image"]} ]
        },
        { "role" : "assistant",
          "content" : [
            {"type" : "text",  "text"  : sample["text"]} ]
        },
    ]
    return { "messages" : conversation }

调用数据处理预函数data_process,批量将所有数据格式化为微调输入格式,返回给新的变量converted_dataset

In [ ]:

converted_dataset = [data_process(sample) for sample in dataset]

我们展示下经过格式化后的首条数据内容:

In [ ]:

converted_dataset[0]

4.2、LoRA微调配置

In [ ]:

model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True,
    finetune_language_layers   = True,
    finetune_attention_modules = True,
    finetune_mlp_modules       = True,
    # target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],

    r = 16,
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing="unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

4.3、训练参数配置

In [ ]:

from trl import SFTTrainer, SFTConfig
from unsloth.trainer import UnslothVisionDataCollator
from unsloth import is_bf16_supported
from datetime import datetime

output_dir = f"outputs/exp_{datetime.now().strftime('%Y%m%d_%H%M')}"


# 将模型切换到训练模式
FastVisionModel.for_training(model)

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    data_collator = UnslothVisionDataCollator(model, tokenizer),
    train_dataset = converted_dataset,
    args = SFTConfig(
        output_dir = output_dir,
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 10,
        # num_train_epochs = 2,
        learning_rate = 2e-4,
        fp16 = not is_bf16_supported(),
        bf16 = is_bf16_supported(),

        report_to = "tensorboard",
        logging_steps = 5,
        logging_dir=output_dir,

        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        dataset_num_proc = 4,
        max_seq_length = 2048,
    ),
)

4.4、启动训练

打印当前GPU显存信息:

In [ ]:

# 获取索引为0的GPU设备的详细属性
gpu_stats = torch.cuda.get_device_properties(0)
# 返回PyTorch当前预留的显存峰值
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
# GPU的物理显存总量
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}.")
print(f"1)最大显存 = {max_memory} GB.")
print(f"2)预留 {start_gpu_memory} GB 的显存.")

调用train()开始训练:

In [ ]:

trainer_stats = trainer.train()

显示最终内存和时间统计:

In [ ]:

used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"训练耗时:{trainer_stats.metrics['train_runtime']}秒.")
print(
    f"训练耗时:{round(trainer_stats.metrics['train_runtime']/60, 2)}分钟."
)
print(f"峰值预留显存 = {used_memory} GB.")
print(f"LoRA训练专用显存峰值 = {used_memory_for_lora} GB.")
print(f"峰值预留显存占总显存比例 = {used_percentage} %.")
print(f"LoRA训练显存占总显存比例 = {lora_percentage} %.")

4.5、微调后结果分析

指定训练日志所在目录,调用tensorboard命令启动,会在--port指定的端口启动一个可视化WEB服务,在浏览器中打开 http://localhost:6006(如果是云服务器的话,根据IP或映射访问) 即可查看可视化结果。

In [ ]:

!tensorboard --logdir {output_dir} --port 6006

运行以上命令后,打开浏览器访问,如果损失率没有稳定下降,需要调整训练参数重新开始训练。

4.6、模型微调后的推理

现在开始运行微调后的模型,使用相同的推理函数、相同的图片以及提示词:

In [ ]:

inference(
    text="为图片生成LaTeX表达式",
    image_file="assets/demo_pic_1.jpg",
    system_prompt="你是一个LaText OCR助手,目标是读取用户输入的照片,转换成LaTex公式"
)

将输出的结果拷贝到LaTeX公式生成器验证下:

继续与推理前的结果对比,可以发现经过微调后的模型,生成的结果更加接近、符合预期。但由于训练步数/轮次太少,因此生成的结果还并不能完全正确,感兴趣的大家可以继续增大训练轮次,但时间会久些。

4.7、保存微调模型

将最终模型保存为LoRA适配器,可以使用Huggingface的save_pretrained方法进行本地保存,同时也要把分词器保存。

In [ ]:

model.save_pretrained(output_dir) 
tokenizer.save_pretrained(output_dir)
print(f"LoRA权重文件已保存在:{output_dir}")

但上述代码只是保存了LoRA适配器,而不是完整的模型,通过以下代码保存为完整的float16精度模型。该精度的模型可以使用vLLM、transformers等工具进行加载推理。

In [ ]:

new_model_dir = "models/Qwen2.5-VL-7B-LaTeXOCR"
model.save_pretrained_merged(
    new_model_dir,
    tokenizer, 
    save_method="merged_16bit",
    )
print(f"模型已合并并保存到:{new_model_dir}")

model.save_pretrained_merged方法会逐层检查基础模型,并去huggingface下载相应的基础模型,所以尽量开启HF国内镜像源或代理,不然会抵账,下载也需要点时间。

合并保存完成后,观察models/Qwen2.5-VL-7B-LaTeXOCR目录,生成了以下文件:

到此,我们使用Qwen2.5-VL多模态基座模型,通过Unsloth的QLoRA微调方法,成功训练了第一个模型,让其可以识别LaTeX公式。让你对Unsloth有个初始的认识,更多其它模型的训练方法,请继续往下实战。

相关推荐

基于 Linux 快速搭建企业级 DNS 服务器(Bind9 ...

一、引言在大型企业网络或自建系统中,搭建一套高可用、自控的DNS解析服务器至关重要。本文将带你基于Linux环境,从零搭建企业级DNS服务平台,采用Bind9实战配置,确保解析稳定、安...

Linux无法解析域名的解决办法(linux无法解析域名的解决办法有哪些)

如果由于误操作,删除了系统原有的dhcp相关设置就无法正常解析域名。  此时,需要手动修改配置文件:  /etc/resolv.conf  将域名解析服务器手动添加到配置文件中  该文件是DNS域名解...

在centos7 创建基于域名的虚拟主机nginx服务器

直接用ip地址访问首先是不安全,其次不太容易记住,如果你的服务器上的项目有很多个,你创建多个基于Ip的虚拟主机,很容易导致公网ip冲突或乱用的情况。这时候我们就可以选择基于域名的虚拟主机。第一步、安装...

Linux之DNS服务(linux dnsserver)

一、学习路线如下二、DNS介绍1.域名的概念域名由特定的格式组成,用来表示互联网中某一台计算机或者计算机组的名称,能够使人更方便的访问互联网,而不用记住能够被机器直接读取的IP地址。2.DNS(dom...

Linux环境下DNS服务器配置图文详细教程

测试环境为vmware虚拟机下,linux系统为RedHatEnterpriseLinuxServer6.0(Santiago),内核版本Linux2.6.32-71.el6.i686...

构建基于 Linux 的高性能 DNS 服务器

在现代网络架构中,DNS(域名解析)是访问互联网的关键环节。搭建一个高性能、低延迟、可缓存加速的私有DNS服务器,不仅可以提升访问速度,还能增强网络隐私和安全性。本文将基于Linux系统,详细...

从运维的角度带你初识neo4j图形数据库的安装及配置

前言随着公司业务架构的改变,以前我部署环境的时候,一般只是部署Mysql,jdk,tomcat即可,现在还要部署一些nosql,如redis,neo4j,在之前从来没了解过,随着学习的深入而做了一些笔...

[超全整理] Java 程序员必备的 100 条 Linux 命令大全

一、基础操作(10条)#1.ls-查看目录内容ls-l#长格式显示文件和目录ls-a#显示隐藏文件ls-lh#带单位显示文件大小#2.cd-切换目录...

软件测试|一文教你轻松搭建docker环境

前言Docker提供轻量的虚拟化,你能够从Docker获得一个额外抽象层,你能够在单台机器上运行多个Docker微容器,而每个微容器里都有一个微服务或独立应用,例如你可以将Tomcat运行在一个Do...

docker基础知识/尚硅谷docker学习笔记

最近看了好多docker的资料,找了一些尚硅谷docker的教学视频,大概总结了一下前前后后的学习笔记。分享给大家。安装Docker的基本组成镜像Docker镜像(Image)就是一个只读的模板。镜...

前端_react项目从windows部署到centos

前言:从工程角度来讲,本地开发完就要把项目部署到生产环境,此过程的快慢也直接影响着整体的效率。所以也有很多人做持续集成的工作,例如:CI/CD/一键部署。但对于个人开发者而言,如果能有工具支撑是最好的...

Springboot项目使用docker部署(docker中运行springboot项目)

环境:SpringBoot2.2.10.RELEASE+Docker+Centos7+JDK8安装配置Dockeryum包更新到最新yumupdate卸载旧版本dockeryumre...

Spring Boot 3.x + Redis 7.x,轻松掌握Redisson分布式锁实战技巧

大家好,我是袁庭新。在分布式环境中,确保数据的一致性和正确性是至关重要的。对于需要高性能、高并发和分布式数据存储的应用程序来说,Redisson是一个很好的选择。同时,Redisson提供的分布式锁功...

Docker篇(二):Docker实战,命令解析

大家好,我是杰哥上周我们通过几个问题,让大家对于Docker有了一个全局的认识。然而,说跟练往往是两个概念。从学习的角度来说,理论知识的学习,往往只是第一步,只有经过实战,才能真正掌握一门技术所以,本...

新手快速入门Docker,轻松掌握Docker安装与使用

安装使用官方安装脚本自动安装curl-fsSLhttps://get.docker.com|bash-sdocker--mirrorAliyun手动安装CentOS7(使用yum进...

取消回复欢迎 发表评论: