mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL

I have uploaded the model fine-tuned in this post to the HuggingFace Hub, so you can try it out.

Install mlx-lm

pip install mlx-lm

Generate SQL

python -m mlx_lm.generate --model mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
SELECT School FROM Students WHERE Name = 'Wang Junjian'
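
The same SQL can be generated from Python instead of the CLI. A minimal sketch, assuming the mlx_lm load/generate API (the exact generate signature may vary between mlx-lm versions):

from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-v0.1-LoRA-Text2SQL")
prompt = (
    "table: students\n"
    "columns: Name, Age, School, Grade, Height, Weight\n"
    "Q: Which school did Wang Junjian come from?\n"
    "A: "
)
print(generate(model, tokenizer, prompt=prompt, max_tokens=50))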

Fine-tuning Text2SQL with LoRA / QLoRA on MLX (Part 1): Fine-tuning Mistral-7B with LoRA

📌 In Part 1, the dataset was generated without the model's special-token format, so generation could not terminate and always ran until the maximum number of tokens was reached.

This time we solve that problem.

WikiSQL Dataset

Modify the script mlx-examples/lora/data/wikisql.py:

if __name__ == "__main__":
    # ......
    for dataset, name, size in datasets:
        with open(f"data/{name}.jsonl", "w") as fid:
            for e, t in zip(range(size), dataset):
                """
                The text in variable t looks like this:
                ------------------------
                <s>table: 1-1058787-1
                columns: Approximate Age, Virtues, Psycho Social Crisis, Significant Relationship, Existential Question [ not in citation given ], Examples
                Q: How many significant relationships list Will as a virtue?
                A: SELECT COUNT Significant Relationship FROM 1-1058787-1 WHERE Virtues = 'Will'</s>
                """
                t = t[3:]  # strip the leading <s>; the tokenizer adds <s> automatically
                json.dump({"text": t}, fid)
                fid.write("\n")
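
Stripping the leading <s> matters because the tokenizer prepends a BOS token on its own; leaving the literal <s> in the text would duplicate it. A quick way to confirm this, assuming the Hugging Face tokenizer for Mistral-7B:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
ids = tok.encode("table: students")
print(ids[0] == tok.bos_token_id)  # True: BOS is prepended automatically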

Run data/wikisql.py to generate the dataset.
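
Before training, it is worth a quick sanity check on the generated files. A sketch, assuming the script was run from the lora directory and wrote data/train.jsonl:

import json

with open("data/train.jsonl") as fid:
    sample = json.loads(fid.readline())["text"]

assert not sample.startswith("<s>")  # the tokenizer adds BOS itself
assert sample.endswith("</s>")       # keep EOS so the model learns to stop
print(sample)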

Sample

table: 1-10753917-1
columns: Season, Driver, Team, Engine, Poles, Wins, Podiums, Points, Margin of defeat
Q: Which podiums did the alfa romeo team have?
A: SELECT Podiums FROM 1-10753917-1 WHERE Team = 'Alfa Romeo'</s>

微调

LoRA 微调

python lora.py --model mistralai/Mistral-7B-v0.1 \
               --train \
               --iters 600
Loading pretrained model
Total parameters 7243.436M
Trainable parameters 1.704M
Loading datasets
Training
Iter 1: Val loss 2.343, Val took 24.272s
Iter 10: Train loss 2.237, It/sec 0.412, Tokens/sec 165.740
Iter 20: Train loss 1.688, It/sec 0.510, Tokens/sec 206.577
Iter 30: Train loss 1.475, It/sec 0.526, Tokens/sec 216.519
Iter 40: Train loss 1.359, It/sec 0.539, Tokens/sec 208.807
Iter 50: Train loss 1.243, It/sec 0.567, Tokens/sec 225.619
Iter 60: Train loss 1.125, It/sec 0.567, Tokens/sec 224.679
Iter 70: Train loss 1.177, It/sec 0.485, Tokens/sec 196.413
Iter 80: Train loss 1.180, It/sec 0.512, Tokens/sec 205.216
Iter 90: Train loss 1.152, It/sec 0.593, Tokens/sec 224.874
Iter 100: Train loss 1.204, It/sec 0.581, Tokens/sec 221.348
Iter 100: Saved adapter weights to adapters.npz.
Iter 110: Train loss 1.080, It/sec 0.567, Tokens/sec 219.234
Iter 120: Train loss 1.065, It/sec 0.563, Tokens/sec 219.935
Iter 130: Train loss 1.083, It/sec 0.536, Tokens/sec 211.902
Iter 140: Train loss 1.072, It/sec 0.546, Tokens/sec 212.716
Iter 150: Train loss 1.061, It/sec 0.472, Tokens/sec 192.188
Iter 160: Train loss 0.991, It/sec 0.512, Tokens/sec 201.292
Iter 170: Train loss 1.028, It/sec 0.535, Tokens/sec 220.537
Iter 180: Train loss 0.978, It/sec 0.594, Tokens/sec 215.790
Iter 190: Train loss 1.033, It/sec 0.537, Tokens/sec 214.972
Iter 200: Train loss 1.091, It/sec 0.545, Tokens/sec 207.353
Iter 200: Val loss 1.111, Val took 30.101s
Iter 200: Saved adapter weights to adapters.npz.
Iter 210: Train loss 1.056, It/sec 0.573, Tokens/sec 217.968
Iter 220: Train loss 0.987, It/sec 0.552, Tokens/sec 220.129
Iter 230: Train loss 0.984, It/sec 0.578, Tokens/sec 225.119
Iter 240: Train loss 0.929, It/sec 0.593, Tokens/sec 227.224
Iter 250: Train loss 0.984, It/sec 0.504, Tokens/sec 209.164
Iter 260: Train loss 0.871, It/sec 0.529, Tokens/sec 213.830
Iter 270: Train loss 0.843, It/sec 0.549, Tokens/sec 214.504
Iter 280: Train loss 0.866, It/sec 0.606, Tokens/sec 233.129
Iter 290: Train loss 0.946, It/sec 0.564, Tokens/sec 216.089
Iter 300: Train loss 0.818, It/sec 0.574, Tokens/sec 234.182
Iter 300: Saved adapter weights to adapters.npz.
Iter 310: Train loss 0.939, It/sec 0.610, Tokens/sec 228.415
Iter 320: Train loss 0.811, It/sec 0.536, Tokens/sec 208.765
Iter 330: Train loss 0.890, It/sec 0.514, Tokens/sec 207.142
Iter 340: Train loss 0.825, It/sec 0.494, Tokens/sec 190.312
Iter 350: Train loss 0.845, It/sec 0.552, Tokens/sec 211.589
Iter 360: Train loss 0.872, It/sec 0.553, Tokens/sec 221.311
Iter 370: Train loss 0.832, It/sec 0.502, Tokens/sec 205.400
Iter 380: Train loss 0.855, It/sec 0.565, Tokens/sec 217.207
Iter 390: Train loss 0.873, It/sec 0.593, Tokens/sec 229.769
Iter 400: Train loss 0.837, It/sec 0.491, Tokens/sec 207.763
Iter 400: Val loss 1.076, Val took 31.449s
Iter 400: Saved adapter weights to adapters.npz.
Iter 410: Train loss 0.821, It/sec 0.556, Tokens/sec 223.608
Iter 420: Train loss 0.828, It/sec 0.593, Tokens/sec 219.316
Iter 430: Train loss 0.787, It/sec 0.573, Tokens/sec 214.802
Iter 440: Train loss 0.842, It/sec 0.529, Tokens/sec 208.544
Iter 450: Train loss 0.794, It/sec 0.531, Tokens/sec 215.918
Iter 460: Train loss 0.832, It/sec 0.520, Tokens/sec 212.107
Iter 470: Train loss 0.767, It/sec 0.578, Tokens/sec 228.089
Iter 480: Train loss 0.794, It/sec 0.548, Tokens/sec 215.279
Iter 490: Train loss 0.737, It/sec 0.612, Tokens/sec 236.395
Iter 500: Train loss 0.774, It/sec 0.542, Tokens/sec 223.036
Iter 500: Saved adapter weights to adapters.npz.
Iter 510: Train loss 0.750, It/sec 0.524, Tokens/sec 212.472
Iter 520: Train loss 0.636, It/sec 0.562, Tokens/sec 221.322
Iter 530: Train loss 0.587, It/sec 0.541, Tokens/sec 218.441
Iter 540: Train loss 0.631, It/sec 0.589, Tokens/sec 225.624
Iter 550: Train loss 0.661, It/sec 0.580, Tokens/sec 228.000
Iter 560: Train loss 0.686, It/sec 0.537, Tokens/sec 213.582
Iter 570: Train loss 0.630, It/sec 0.543, Tokens/sec 210.104
Iter 580: Train loss 0.632, It/sec 0.588, Tokens/sec 228.862
Iter 590: Train loss 0.632, It/sec 0.517, Tokens/sec 203.740
Iter 600: Train loss 0.609, It/sec 0.531, Tokens/sec 218.118
Iter 600: Val loss 1.001, Val took 30.002s
Iter 600: Saved adapter weights to adapters.npz.
python lora.py --model mistralai/Mistral-7B-v0.1 --train --iters 600  50.58s user 214.71s system 21% cpu 20:26.04 total

Only 2.35 out of every 10,000 model parameters are trained (1.704M / 7243.436M × 10,000 ≈ 2.35, about 0.024%).

LoRA fine-tuning for 600 iterations took 20 minutes 26 seconds and used 46 GB of memory.
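
The trainable fraction follows directly from the two parameter counts in the log:

trainable, total = 1.704e6, 7243.436e6
print(f"{trainable / total:.4%}")           # 0.0235% of all parameters
print(f"{trainable / total * 10_000:.2f}")  # ≈ 2.35 per ten thousand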

Iteration  Train Loss  Val Loss  Tokens/sec
1          -           2.343     -
100        1.204       -         221.348
200        1.091       1.111     207.353
300        0.818       -         234.182
400        0.837       1.076     207.763
500        0.774       -         223.036
600        0.609       1.001     218.118
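
The table can be scraped straight from the console log with a small script. A sketch, assuming the training output was saved to a file named train.log:

import re

train_loss, val_loss = [], []
with open("train.log") as fid:
    for line in fid:
        if m := re.match(r"Iter (\d+): Train loss ([\d.]+)", line):
            train_loss.append((int(m[1]), float(m[2])))
        elif m := re.match(r"Iter (\d+): Val loss ([\d.]+)", line):
            val_loss.append((int(m[1]), float(m[2])))

print(train_loss[-1], val_loss[-1])  # (600, 0.609) (600, 1.001)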

Evaluation

Compute perplexity (PPL) and cross-entropy loss on the test set.
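
Perplexity is just the exponential of the cross-entropy loss, which the numbers below confirm:

import math

print(math.exp(1.294))  # ≈ 3.647, matching the Iter 500 Test PPL up to rounding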

python lora.py --model mistralai/Mistral-7B-v0.1 \
               --adapter-file adapters.npz \
               --test
Iter 100: Test loss 1.351, Test ppl 3.862.
Iter 200: Test loss 1.327, Test ppl 3.770.
Iter 300: Test loss 1.353, Test ppl 3.869.
Iter 400: Test loss 1.355, Test ppl 3.875.
Iter 500: Test loss 1.294, Test ppl 3.646.
Iter 600: Test loss 1.351, Test ppl 3.863.
Iteration  Test Loss  Test PPL
100        1.351      3.862
200        1.327      3.770
300        1.353      3.869
400        1.355      3.875
500        1.294      3.646  👍
600        1.351      3.863

Evaluation used 26 GB of memory.

Fuse

python fuse.py --model mistralai/Mistral-7B-v0.1 \
               --adapter-file adapters.npz \
               --save-path lora_fused_model

The adapter weights from Iter 500 are used here because they have the lowest test-set perplexity.
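
Conceptually, fusing folds the low-rank LoRA update into the base weights so that no adapter is needed at inference time. A sketch with hypothetical shapes and scale (fuse.py applies this to every linear layer the adapter touched):

import numpy as np

d, r = 4096, 8                    # hypothetical layer width and LoRA rank
W = np.random.randn(d, d)         # frozen base weight
A = np.random.randn(r, d) * 0.01  # LoRA down-projection
B = np.random.randn(d, r) * 0.01  # LoRA up-projection
scale = 2.0                       # illustrative scaling factor (lora_alpha / r)

W_fused = W + scale * (B @ A)     # the fused weight replaces W in the saved model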

Generate SQL

What is Wang Junjian's name?

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: What is Wang Junjian's name?
A: "
SELECT Name FROM students WHERE Name = 'Wang Junjian'
==========
Prompt: 88.790 tokens-per-sec
Generation: 16.811 tokens-per-sec

How old is Wang Junjian?

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: How old is Wang Junjian?
A: "
SELECT Age FROM Students WHERE Name = 'Wang Junjian'
==========
Prompt: 84.460 tokens-per-sec
Generation: 16.801 tokens-per-sec

Which school did Wang Junjian come from?

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
SELECT School FROM Students WHERE Name = 'Wang Junjian'
==========
Prompt: 89.124 tokens-per-sec
Generation: 16.718 tokens-per-sec

Query Wang Junjian's name, age, and school information.

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query Wang Junjian’s name, age, and school information.
A: "
SELECT Name, Age, School FROM Students WHERE Name = 'Wang Junjian'
==========
Prompt: 100.919 tokens-per-sec
Generation: 17.139 tokens-per-sec

Query all information about Wang Junjian.

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query all information about Wang Junjian.
A: "
SELECT Name FROM students WHERE Name = 'Wang Junjian'
==========
Prompt: 88.225 tokens-per-sec
Generation: 16.781 tokens-per-sec

The model returns only the Name column rather than all of them; probably the training data lacks enough examples of this kind of query.

Count how many students there are in ninth grade.

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.
A: "
SELECT COUNT Name FROM Students WHERE Grade = '9th'
==========
Prompt: 93.829 tokens-per-sec
Generation: 16.546 tokens-per-sec

Count how many students there are in ninth grade (the value of Grade for ninth grade is 9).

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
The value for ninth grade is 9.
Q: Count how many students there are in ninth grade.
A: "
Or equivalently, with the hint embedded directly in the question:

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.(The value for ninth grade is 9.)
A: "
SELECT COUNT Name FROM students WHERE Grade = 9
==========
Prompt: 117.893 tokens-per-sec
Generation: 16.298 tokens-per-sec

Extra hint information like this is easy to add, and its exact placement in the prompt does not matter much.

References