MLX Text2SQL

Fine-Tuning Text2SQL with LoRA / QLoRA on MLX (Part 5): Comparing the Results of LoRA and QLoRA Fine-Tuning on Mistral-7B

Experiments on fine-tuning Mistral-7B with LoRA and QLoRA

LoRA vs. QLoRA

Fine-Tuning

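| Iteration | LoRA Train Loss | LoRA Val Loss | LoRA Tokens/sec | QLoRA Train Loss | QLoRA Val Loss | QLoRA Tokens/sec |
|---|---|---|---|---|---|---|
| 1 | - | 2.343 | - | - | 2.420 | - |
| 100 | 1.204 | - | 221.348 | 1.216 | - | 166.377 |
| 200 | 1.091 | 1.111 | 207.353 | 1.095 | 1.130 | 187.795 |
| 300 | 0.818 | - | 234.182 | 1.065 | - | 194.826 |
| 400 | 0.837 | 1.076 | 207.763 | 0.998 | 1.006 | 170.072 |
| 500 | 0.774 | - | 223.036 | 0.726 | - | 189.288 |
| 600 | 0.609 | 1.001 | 218.118 | 0.607 | 1.015 | 186.397 |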

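Number of fine-tuned parameters

  • LoRA fine-tunes 2.35 per ten thousand (about 0.0235%) of the model's parameters (1.704M / 7243.436M * 10000).
  • QLoRA fine-tunes 13.70 per ten thousand (about 0.137%) of the model's parameters (1.704M / 1244.041M * 10000).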
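
Both ratios follow directly from the parameter counts. As a quick check, the arithmetic can be reproduced in a few lines of Python; the function name below is illustrative and not part of MLX.

```python
def trainable_per_ten_thousand(trainable_m: float, total_m: float) -> float:
    """Return the number of trainable parameters per 10,000 total parameters."""
    return trainable_m / total_m * 10_000

# Parameter counts in millions, as reported above.
lora = trainable_per_ten_thousand(1.704, 7243.436)
qlora = trainable_per_ten_thousand(1.704, 1244.041)

print(f"LoRA:  {lora:.2f} per ten thousand")
print(f"QLoRA: {qlora:.2f} per ten thousand")
```

The number of trainable parameters is the same in both runs (1.704M); only the reported total differs.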

Fine-tuning time

  • LoRA: 600 iterations took 20 minutes 26 seconds.
  • QLoRA: 600 iterations took 23 minutes 40 seconds.

Memory usage during fine-tuning

  • LoRA: 46 GB
  • QLoRA: 46 GB

Evaluation

Compute the perplexity (PPL) and cross-entropy loss (Loss) on the test set.

| Iteration | LoRA Test Loss | LoRA Test PPL | QLoRA Test Loss | QLoRA Test PPL |
|---|---|---|---|---|
| 600 | 1.351 | 3.863 | 1.396 | 4.040 |
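
Perplexity and cross-entropy loss are two views of the same quantity. Assuming the reported loss is the mean per-token cross-entropy in nats, PPL = exp(loss). The minimal sketch below recomputes PPL from the losses in the table; small differences in the last digit are expected because the reported losses are rounded.

```python
import math

def perplexity(cross_entropy_loss: float) -> float:
    """Perplexity is the exponential of the mean per-token cross-entropy (in nats)."""
    return math.exp(cross_entropy_loss)

# Test losses at iteration 600, taken from the table above.
for name, loss in [("LoRA", 1.351), ("QLoRA", 1.396)]:
    print(f"{name}: loss={loss:.3f}, PPL={perplexity(loss):.3f}")
```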

Memory usage during evaluation

  • LoRA: 26 GB
  • QLoRA: 15 GB

Fusing

Model size

  • LoRA: 13 GB
  • QLoRA: 4 GB
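
These sizes are roughly what a back-of-the-envelope estimate predicts. The sketch below assumes 16-bit weights for the fused LoRA model and, for the QLoRA model, 4-bit weights plus one 16-bit scale and one 16-bit bias per group of 64 weights. These storage details are assumptions rather than facts taken from the experiment, and the estimate ignores any layers that are left unquantized.

```python
GIB = 1024 ** 3

def estimated_size_gib(num_params: float, bits_per_weight: float) -> float:
    """Approximate model size in GiB for a given average number of bits per weight."""
    return num_params * bits_per_weight / 8 / GIB

NUM_PARAMS = 7243.436e6  # total parameters reported for the LoRA run

# Assumption: the fused LoRA model is stored in 16-bit precision.
fp16_size = estimated_size_gib(NUM_PARAMS, 16)

# Assumption: 4-bit weights, group size 64, a 16-bit scale and a 16-bit bias per group.
GROUP_SIZE = 64
quant_bits = 4 + (16 + 16) / GROUP_SIZE  # 4.5 bits per weight on average
q4_size = estimated_size_gib(NUM_PARAMS, quant_bits)

print(f"16-bit estimate: {fp16_size:.1f} GiB")
print(f"4-bit estimate:  {q4_size:.1f} GiB")
```

Under these assumptions both estimates land close to the reported sizes, which suggests the difference comes almost entirely from the weight precision.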

Generating SQL

What is Wang Junjian's name?

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: What is Wang Junjian's name?
A: 
  • LoRA
SELECT Name FROM students WHERE School = 'Wang Junjian'
Prompt: 88.790 tokens-per-sec
Generation: 16.811 tokens-per-sec
  • QLoRA
SELECT Name FROM students WHERE School = 'Wang Junjian'
Prompt: 154.798 tokens-per-sec
Generation: 107.496 tokens-per-sec
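
Every example in this section uses the same prompt layout: a table line, a columns line, an optional hint, the question, and an empty answer slot. A minimal sketch of a helper that assembles it is shown below; `build_prompt` and its parameters are hypothetical names introduced for illustration, not part of MLX.

```python
def build_prompt(table: str, columns: list[str], question: str, hint: str = "") -> str:
    """Assemble a Text2SQL prompt in the layout used by the examples."""
    lines = [f"table: {table}", f"columns: {', '.join(columns)}"]
    if hint:
        lines.append(hint)  # optional extra context, placed before the question
    lines.append(f"Q: {question}")
    lines.append("A: ")  # the model continues from here with the SQL statement
    return "\n".join(lines)

prompt = build_prompt(
    table="students",
    columns=["Name", "Age", "School", "Grade", "Height", "Weight"],
    question="What is Wang Junjian's name?",
)
print(prompt)
```

The optional `hint` argument covers prompts that add a sentence such as "The value for ninth grade is 9." before the question.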

How old is Wang Junjian?

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: How old is Wang Junjian?
A: 
  • LoRA
SELECT Age FROM Students WHERE Name = 'Wang Junjian'
Prompt: 84.460 tokens-per-sec
Generation: 16.801 tokens-per-sec
  • QLoRA
11 SELECT Age FROM students WHERE Name = 'Wang Junjian'
Prompt: 156.206 tokens-per-sec
Generation: 101.284 tokens-per-sec

Which school did Wang Junjian come from?

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: 
  • LoRA
SELECT School FROM Students WHERE Name = 'Wang Junjian'
Prompt: 89.124 tokens-per-sec
Generation: 16.718 tokens-per-sec
  • QLoRA
SELECT School FROM students WHERE Name = 'Wang Junjian'
Prompt: 160.011 tokens-per-sec
Generation: 121.895 tokens-per-sec

Query Wang Junjian's name, age, and school information.

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query Wang Junjian’s name, age, and school information.
A: 
  • LoRA
SELECT Name, Age, School FROM Students WHERE Name = 'Wang Junjian'
Prompt: 100.919 tokens-per-sec
Generation: 17.139 tokens-per-sec
  • QLoRA
SELECT Name FROM students WHERE Age = 13 AND School = 'Hangzhou Foreign Language School'
Prompt: 183.456 tokens-per-sec
Generation: 121.376 tokens-per-sec

Query all information about Wang Junjian.

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query all information about Wang Junjian.
A: 
  • LoRA
SELECT Name FROM students WHERE Name = 'Wang Junjian'
Prompt: 88.225 tokens-per-sec
Generation: 16.781 tokens-per-sec
  • QLoRA
SELECT Name FROM students WHERE School = 'Wang Junjian'
Prompt: 151.962 tokens-per-sec
Generation: 122.067 tokens-per-sec

Count how many students there are in ninth grade.

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.
A: 
  • LoRA
SELECT COUNT Name FROM Students WHERE Grade = '9th'
Prompt: 93.829 tokens-per-sec
Generation: 16.546 tokens-per-sec
  • QLoRA
SELECT COUNT Name FROM students WHERE Grade = 9
Prompt: 164.480 tokens-per-sec
Generation: 115.851 tokens-per-sec

Count how many students there are in ninth grade (the value for ninth grade is 9).

  • LoRA
Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
The value for ninth grade is 9.
Q: Count how many students there are in ninth grade.
A: 
SELECT COUNT Name FROM students WHERE Grade = 9
Prompt: 117.893 tokens-per-sec
Generation: 16.298 tokens-per-sec
  • QLoRA
Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
The value for ninth grade is 9th.
Q: Count how many students there are in ninth grade.
A: 
 SELECT COUNT Name FROM students WHERE Grade = '9th'
Prompt: 216.152 tokens-per-sec
Generation: 114.300 tokens-per-sec

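With the same number of iterations, QLoRA gives worse results than LoRA. For example, when asked for Wang Junjian's name, age, and school, LoRA returns the expected statement while QLoRA filters on unrelated Age and School values.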
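
One way to make such a comparison more systematic is to run each generated statement against an empty in-memory SQLite table. The sketch below is not part of the original experiment, and the column types are assumptions. It only checks that a statement executes, not that it answers the question, so a query that filters on the wrong column still passes.

```python
import sqlite3

def is_executable(sql: str) -> bool:
    """Return True if SQLite can run the statement against an empty students table."""
    conn = sqlite3.connect(":memory:")
    try:
        # Assumed schema: column names come from the prompts, types are guesses.
        conn.execute(
            "CREATE TABLE students (Name TEXT, Age INTEGER, School TEXT, "
            "Grade TEXT, Height REAL, Weight REAL)"
        )
        conn.execute(sql).fetchall()
        return True
    except sqlite3.Error:
        return False
    finally:
        conn.close()

generated = [
    "SELECT Age FROM Students WHERE Name = 'Wang Junjian'",
    "SELECT COUNT Name FROM students WHERE Grade = 9",
]
for sql in generated:
    print(is_executable(sql), sql)
```

SQLite reads `COUNT Name` as a column named COUNT with the alias Name, so statements in that form fail until they are rewritten as `COUNT(Name)`.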

Generation speed

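| Fine-Tuning | Prompt tokens/sec | Generation tokens/sec | Fine-Tuning | Prompt tokens/sec | Generation tokens/sec |
|---|---|---|---|---|---|
| LoRA | 88.790 | 16.811 | QLoRA | 154.798 | 107.496 |
| LoRA | 84.460 | 16.801 | QLoRA | 156.206 | 101.284 |
| LoRA | 89.124 | 16.718 | QLoRA | 160.011 | 121.895 |
| LoRA | 100.919 | 17.139 | QLoRA | 183.456 | 121.376 |
| LoRA | 88.225 | 16.781 | QLoRA | 151.962 | 122.067 |
| LoRA | 93.829 | 16.546 | QLoRA | 164.480 | 115.851 |
| LoRA | 117.893 | 16.298 | QLoRA | 216.152 | 114.300 |
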
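  • Prompt tokens/sec: averaged over the seven runs above, QLoRA is about 1.79 times as fast as LoRA.
  • Generation tokens/sec: averaged over the seven runs above, QLoRA is about 6.87 times as fast as LoRA.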
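
The speedups can be recomputed from the per-run figures in the table. The sketch below assumes each one is the ratio of mean QLoRA throughput to mean LoRA throughput across the seven runs.

```python
from statistics import mean

# (prompt tokens/sec, generation tokens/sec) for each of the seven runs in the table.
lora = [(88.790, 16.811), (84.460, 16.801), (89.124, 16.718), (100.919, 17.139),
        (88.225, 16.781), (93.829, 16.546), (117.893, 16.298)]
qlora = [(154.798, 107.496), (156.206, 101.284), (160.011, 121.895), (183.456, 121.376),
         (151.962, 122.067), (164.480, 115.851), (216.152, 114.300)]

def speedup(column: int) -> float:
    """Ratio of mean QLoRA throughput to mean LoRA throughput for one column."""
    return mean(row[column] for row in qlora) / mean(row[column] for row in lora)

print(f"Prompt speedup:     {speedup(0):.2f}x")
print(f"Generation speedup: {speedup(1):.2f}x")
```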

