Fine-tuning experiments on Mistral-7B with LoRA and QLoRA

LoRA vs. QLoRA comparison

Fine-tuning

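| Iteration | LoRA Train Loss | LoRA Val Loss | LoRA Tokens/sec | QLoRA Train Loss | QLoRA Val Loss | QLoRA Tokens/sec |
|-----------|-----------------|---------------|-----------------|------------------|----------------|------------------|
| 1         | -               | 2.343         | -               | -                | 2.420          | -                |
| 100       | 1.204           | -             | 221.348         | 1.216            | -              | 166.377          |
| 200       | 1.091           | 1.111         | 207.353         | 1.095            | 1.130          | 187.795          |
| 300       | 0.818           | -             | 234.182         | 1.065            | -              | 194.826          |
| 400       | 0.837           | 1.076         | 207.763         | 0.998            | 1.006          | 170.072          |
| 500       | 0.774           | -             | 223.036         | 0.726            | -              | 189.288          |
| 600       | 0.609           | 1.001         | 218.118         | 0.607            | 1.015          | 186.397          |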

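Number of fine-tuned parameters

  • LoRA fine-tunes 2.35 out of every 10,000 model parameters (1.704M / 7243.436M * 10000), about 0.0235%.
  • QLoRA fine-tunes 13.70 out of every 10,000 model parameters (1.704M / 1244.041M * 10000), about 0.137%.

The number of trainable parameters (1.704M) is the same in both runs; only the reported total parameter count differs.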
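The two ratios above can be reproduced with a few lines of arithmetic. This is only a sketch that plugs in the parameter counts reported in this post; the function name `per_ten_thousand` is made up for illustration.

```python
# Parameter counts reported above, in millions.
TRAINABLE_M = 1.704
LORA_TOTAL_M = 7243.436
QLORA_TOTAL_M = 1244.041


def per_ten_thousand(trainable_m: float, total_m: float) -> float:
    """Return the number of trainable parameters per 10,000 total parameters."""
    return trainable_m / total_m * 10_000


lora_share = per_ten_thousand(TRAINABLE_M, LORA_TOTAL_M)
qlora_share = per_ten_thousand(TRAINABLE_M, QLORA_TOTAL_M)

print(f"LoRA:  {lora_share:.2f} per 10,000")
print(f"QLoRA: {qlora_share:.2f} per 10,000")
```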

Fine-tuning time

  • LoRA: 600 iterations took 20 min 26 s.
  • QLoRA: 600 iterations took 23 min 40 s.

Memory usage during fine-tuning

  • LoRA: 46G
  • QLoRA: 46G

Evaluation

Compute the perplexity (PPL) and cross-entropy loss (Loss) on the test set.

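| Iteration | LoRA Test Loss | LoRA Test PPL | QLoRA Test Loss | QLoRA Test PPL |
|-----------|----------------|---------------|-----------------|----------------|
| 600       | 1.351          | 3.863         | 1.396           | 4.040          |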
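Perplexity and cross-entropy loss are two views of the same quantity. A minimal sketch, assuming the reported loss is the mean per-token cross-entropy in nats, so that PPL = exp(Loss); the small gap to the reported PPL values would then come from the loss being rounded to three decimals.

```python
import math


def perplexity(cross_entropy_loss: float) -> float:
    """Convert a mean per-token cross-entropy loss (in nats) to perplexity."""
    return math.exp(cross_entropy_loss)


# Test losses reported in the table above.
for name, loss in [("LoRA", 1.351), ("QLoRA", 1.396)]:
    print(f"{name}: loss={loss:.3f} ppl={perplexity(loss):.3f}")
```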

Memory usage during evaluation

  • LoRA: 26G
  • QLoRA: 15G

Fusing

Model size

  • LoRA: 13G
  • QLoRA: 4G
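These sizes are roughly what a back-of-the-envelope estimate gives. The sketch below rests on assumptions that are not stated in this post: that the fused LoRA model stores 16-bit weights, and that the fused QLoRA model stores 4-bit weights plus about 0.5 bit per weight of quantization metadata (scales and biases). The parameter count is the unquantized total reported earlier.

```python
TOTAL_PARAMS = 7243.436e6  # unquantized parameter count reported above


def size_gib(num_params: float, bits_per_param: float) -> float:
    """Estimate the weight storage in GiB for a given bit width per parameter."""
    return num_params * bits_per_param / 8 / 2**30


# Assumption: 16-bit weights for the fused LoRA model.
fp16_gib = size_gib(TOTAL_PARAMS, 16)
# Assumption: 4-bit weights plus ~0.5 bit/weight of scales and biases.
q4_gib = size_gib(TOTAL_PARAMS, 4.5)

print(f"16-bit estimate: {fp16_gib:.1f} GiB")
print(f"4-bit estimate:  {q4_gib:.1f} GiB")
```

The estimate ignores tokenizer and config files, so it should only be read as an order-of-magnitude check.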

Generating SQL

What is Wang Junjian's name?

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: What is Wang Junjian's name?
A: 
  • LoRA
    SELECT Name FROM students WHERE School = 'Wang Junjian'
    
    Prompt: 88.790 tokens-per-sec
    Generation: 16.811 tokens-per-sec
    
  • QLoRA
    SELECT Name FROM students WHERE School = 'Wang Junjian'
    
    Prompt: 154.798 tokens-per-sec
    Generation: 107.496 tokens-per-sec
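All prompts in this section follow the same layout: a table name, a column list, an optional hint line, the question, and an empty answer slot. A small helper can build them consistently. This is only an illustrative sketch; `build_prompt` and its parameters are hypothetical names, not part of any library used in the experiment.

```python
from typing import Optional, Sequence

COLUMNS = ["Name", "Age", "School", "Grade", "Height", "Weight"]


def build_prompt(
    table: str,
    columns: Sequence[str],
    question: str,
    hint: Optional[str] = None,
) -> str:
    """Assemble a text-to-SQL prompt in the layout used in this post."""
    lines = [f"table: {table}", f"columns: {', '.join(columns)}"]
    if hint:
        # Optional extra context, placed before the question.
        lines.append(hint)
    lines.append(f"Q: {question}")
    lines.append("A: ")  # the model continues from here
    return "\n".join(lines)


print(build_prompt("students", COLUMNS, "What is Wang Junjian's name?"))
```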
    

How old is Wang Junjian?

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: How old is Wang Junjian?
A: 
  • LoRA
    SELECT Age FROM Students WHERE Name = 'Wang Junjian'
    
    Prompt: 84.460 tokens-per-sec
    Generation: 16.801 tokens-per-sec
    
  • QLoRA
    11 SELECT Age FROM students WHERE Name = 'Wang Junjian'
    
    Prompt: 156.206 tokens-per-sec
    Generation: 101.284 tokens-per-sec
    

Which school did Wang Junjian come from?

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: 
  • LoRA
    SELECT School FROM Students WHERE Name = 'Wang Junjian'
    
    Prompt: 89.124 tokens-per-sec
    Generation: 16.718 tokens-per-sec
    
  • QLoRA
    SELECT School FROM students WHERE Name = 'Wang Junjian'
    
    Prompt: 160.011 tokens-per-sec
    Generation: 121.895 tokens-per-sec
    

Query Wang Junjian's name, age, and school information.

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query Wang Junjian’s name, age, and school information.
A: 
  • LoRA
    SELECT Name, Age, School FROM Students WHERE Name = 'Wang Junjian'
    
    Prompt: 100.919 tokens-per-sec
    Generation: 17.139 tokens-per-sec
    
  • QLoRA
    SELECT Name FROM students WHERE Age = 13 AND School = 'Hangzhou Foreign Language School'
    
    Prompt: 183.456 tokens-per-sec
    Generation: 121.376 tokens-per-sec
    

Query all information about Wang Junjian.

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query all information about Wang Junjian.
A: 
  • LoRA
    SELECT Name FROM students WHERE Name = 'Wang Junjian'
    
    Prompt: 88.225 tokens-per-sec
    Generation: 16.781 tokens-per-sec
    
  • QLoRA
    SELECT Name FROM students WHERE School = 'Wang Junjian'
    
    Prompt: 151.962 tokens-per-sec
    Generation: 122.067 tokens-per-sec
    

Count how many students there are in ninth grade.

Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.
A: 
  • LoRA
    SELECT COUNT Name FROM Students WHERE Grade = '9th'
    
    Prompt: 93.829 tokens-per-sec
    Generation: 16.546 tokens-per-sec
    
  • QLoRA
    SELECT COUNT Name FROM students WHERE Grade = 9
    
    Prompt: 164.480 tokens-per-sec
    Generation: 115.851 tokens-per-sec
    

Count how many students there are in ninth grade (with the value for ninth grade stated in the prompt).

  • LoRA
    Prompt: table: students
    columns: Name, Age, School, Grade, Height, Weight
    The value for ninth grade is 9.
    Q: Count how many students there are in ninth grade.
    A: 
    SELECT COUNT Name FROM students WHERE Grade = 9
    
    Prompt: 117.893 tokens-per-sec
    Generation: 16.298 tokens-per-sec
    
  • QLoRA
    Prompt: table: students
    columns: Name, Age, School, Grade, Height, Weight
    The value for ninth grade is 9th.
    Q: Count how many students there are in ninth grade.
    A: 
    SELECT COUNT Name FROM students WHERE Grade = '9th'
    
    Prompt: 216.152 tokens-per-sec
    Generation: 114.300 tokens-per-sec

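With the same number of fine-tuning iterations (600), QLoRA gives worse results than LoRA in these examples. For instance, when asked for the name, age, and school, LoRA selects the three requested columns, while QLoRA returns an unrelated filter on Age and School.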
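One way to make such comparisons less subjective is to check whether each generated statement at least runs against the schema. The sketch below uses Python's built-in `sqlite3` with an empty in-memory `students` table; the column types are assumptions, and `runs_in_sqlite` is a hypothetical helper. Passing this check only shows that a statement is executable in SQLite, not that it answers the question.

```python
import sqlite3

# Assumed column types; the post only lists the column names.
SCHEMA = (
    "CREATE TABLE students ("
    "Name TEXT, Age INTEGER, School TEXT, Grade TEXT, Height REAL, Weight REAL)"
)


def runs_in_sqlite(sql: str) -> bool:
    """Return True if the statement executes against an empty students table."""
    conn = sqlite3.connect(":memory:")
    try:
        conn.execute(SCHEMA)
        conn.execute(sql)
        return True
    except sqlite3.Error:
        return False
    finally:
        conn.close()


generated = [
    "SELECT Age FROM Students WHERE Name = 'Wang Junjian'",
    "SELECT Name FROM students WHERE School = 'Wang Junjian'",
    "SELECT COUNT Name FROM students WHERE Grade = 9",
]
for sql in generated:
    print(runs_in_sqlite(sql), sql)
```

The second statement runs even though it filters on the wrong column, so a check like this complements manual review rather than replacing it. The `COUNT Name` form does not run in SQLite, which expects `COUNT(Name)`.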

Generation speed

| LoRA Prompt tokens/sec | LoRA Generation tokens/sec | QLoRA Prompt tokens/sec | QLoRA Generation tokens/sec |
|------------------------|----------------------------|-------------------------|-----------------------------|
| 88.790                 | 16.811                     | 154.798                 | 107.496                     |
| 84.460                 | 16.801                     | 156.206                 | 101.284                     |
| 89.124                 | 16.718                     | 160.011                 | 121.895                     |
| 100.919                | 17.139                     | 183.456                 | 121.376                     |
| 88.225                 | 16.781                     | 151.962                 | 122.067                     |
| 93.829                 | 16.546                     | 164.480                 | 115.851                     |
| 117.893                | 16.298                     | 216.152                 | 114.300                     |

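  • Prompt tokens/sec: QLoRA is 1.79 times as fast as LoRA (averaged over the seven examples).
  • Generation tokens/sec: QLoRA is 6.87 times as fast as LoRA (averaged over the seven examples).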
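The speed-ups can be recomputed from the seven measurements in the table. This sketch assumes each ratio is the mean QLoRA throughput divided by the mean LoRA throughput; the post does not say how the ratios were originally derived.

```python
from statistics import mean

# (prompt tokens/sec, generation tokens/sec) for each of the seven examples.
LORA = [
    (88.790, 16.811), (84.460, 16.801), (89.124, 16.718), (100.919, 17.139),
    (88.225, 16.781), (93.829, 16.546), (117.893, 16.298),
]
QLORA = [
    (154.798, 107.496), (156.206, 101.284), (160.011, 121.895), (183.456, 121.376),
    (151.962, 122.067), (164.480, 115.851), (216.152, 114.300),
]


def speedup(column: int) -> float:
    """Ratio of mean QLoRA throughput to mean LoRA throughput for one column."""
    return mean(row[column] for row in QLORA) / mean(row[column] for row in LORA)


prompt_ratio = speedup(0)
generation_ratio = speedup(1)

print(f"Prompt speed-up:     {prompt_ratio:.2f}x")      # about 1.79x
print(f"Generation speed-up: {generation_ratio:.2f}x")  # about 6.87x
```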

References