目录

预训练模型 mistralai/Mistral-7B-v0.1

量化

QLoRA 微调需要量化,生成 4 位量化的 Mistral 7B 并默认将其存储在 mlx_model 目录中

python convert.py --hf-path mistralai/Mistral-7B-v0.1 -q

mlx_model 目录结构如下:

mlx_model
├── config.json
├── model.safetensors
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer.model
├── tokenizer_config.json
└── weights.00.safetensors

量化后的模型 8.0G

微调

QLoRA 微调

python lora.py --model mlx_model \
               --train \
               --iters 600
Loading pretrained model
Total parameters 1244.041M
Trainable parameters 1.704M
Loading datasets
Training
Iter 1: Val loss 2.420, Val took 32.229s
Iter 10: Train loss 2.308, It/sec 0.388, Tokens/sec 155.982
Iter 20: Train loss 1.748, It/sec 0.459, Tokens/sec 186.088
Iter 30: Train loss 1.487, It/sec 0.451, Tokens/sec 185.720
Iter 40: Train loss 1.373, It/sec 0.466, Tokens/sec 180.710
Iter 50: Train loss 1.263, It/sec 0.462, Tokens/sec 183.856
Iter 60: Train loss 1.128, It/sec 0.483, Tokens/sec 191.427
Iter 70: Train loss 1.193, It/sec 0.415, Tokens/sec 168.129
Iter 80: Train loss 1.187, It/sec 0.438, Tokens/sec 175.615
Iter 90: Train loss 1.165, It/sec 0.505, Tokens/sec 191.464
Iter 100: Train loss 1.216, It/sec 0.437, Tokens/sec 166.377
Iter 100: Saved adapter weights to adapters.npz.
Iter 110: Train loss 1.106, It/sec 0.504, Tokens/sec 194.826
Iter 120: Train loss 1.086, It/sec 0.443, Tokens/sec 172.983
Iter 130: Train loss 1.096, It/sec 0.458, Tokens/sec 181.211
Iter 140: Train loss 1.087, It/sec 0.488, Tokens/sec 190.167
Iter 150: Train loss 1.089, It/sec 0.429, Tokens/sec 174.838
Iter 160: Train loss 1.017, It/sec 0.459, Tokens/sec 180.264
Iter 170: Train loss 1.041, It/sec 0.475, Tokens/sec 195.793
Iter 180: Train loss 0.998, It/sec 0.525, Tokens/sec 190.633
Iter 190: Train loss 1.054, It/sec 0.477, Tokens/sec 191.109
Iter 200: Train loss 1.095, It/sec 0.493, Tokens/sec 187.795
Iter 200: Val loss 1.130, Val took 32.258s
Iter 200: Saved adapter weights to adapters.npz.
Iter 210: Train loss 1.065, It/sec 0.455, Tokens/sec 173.140
Iter 220: Train loss 0.999, It/sec 0.488, Tokens/sec 194.324
Iter 230: Train loss 0.993, It/sec 0.473, Tokens/sec 183.961
Iter 240: Train loss 0.937, It/sec 0.502, Tokens/sec 192.688
Iter 250: Train loss 0.982, It/sec 0.448, Tokens/sec 185.853
Iter 260: Train loss 0.875, It/sec 0.462, Tokens/sec 186.802
Iter 270: Train loss 0.845, It/sec 0.470, Tokens/sec 183.813
Iter 280: Train loss 0.857, It/sec 0.504, Tokens/sec 193.769
Iter 290: Train loss 0.924, It/sec 0.479, Tokens/sec 183.488
Iter 300: Train loss 0.751, It/sec 0.473, Tokens/sec 192.737
Iter 300: Saved adapter weights to adapters.npz.
Iter 310: Train loss 0.871, It/sec 0.520, Tokens/sec 194.461
Iter 320: Train loss 0.726, It/sec 0.465, Tokens/sec 181.200
Iter 330: Train loss 0.817, It/sec 0.458, Tokens/sec 184.790
Iter 340: Train loss 0.746, It/sec 0.479, Tokens/sec 184.412
Iter 350: Train loss 0.768, It/sec 0.496, Tokens/sec 190.082
Iter 360: Train loss 0.791, It/sec 0.467, Tokens/sec 186.836
Iter 370: Train loss 0.751, It/sec 0.444, Tokens/sec 181.680
Iter 380: Train loss 0.775, It/sec 0.469, Tokens/sec 180.115
Iter 390: Train loss 0.798, It/sec 0.484, Tokens/sec 187.477
Iter 400: Train loss 0.769, It/sec 0.402, Tokens/sec 170.072
Iter 400: Val loss 1.006, Val took 33.268s
Iter 400: Saved adapter weights to adapters.npz.
Iter 410: Train loss 0.741, It/sec 0.444, Tokens/sec 178.822
Iter 420: Train loss 0.744, It/sec 0.478, Tokens/sec 176.976
Iter 430: Train loss 0.714, It/sec 0.497, Tokens/sec 186.067
Iter 440: Train loss 0.780, It/sec 0.472, Tokens/sec 186.158
Iter 450: Train loss 0.727, It/sec 0.442, Tokens/sec 179.839
Iter 460: Train loss 0.762, It/sec 0.452, Tokens/sec 184.472
Iter 470: Train loss 0.698, It/sec 0.491, Tokens/sec 193.616
Iter 480: Train loss 0.740, It/sec 0.453, Tokens/sec 177.975
Iter 490: Train loss 0.681, It/sec 0.516, Tokens/sec 199.231
Iter 500: Train loss 0.726, It/sec 0.460, Tokens/sec 189.288
Iter 500: Saved adapter weights to adapters.npz.
Iter 510: Train loss 0.748, It/sec 0.440, Tokens/sec 178.288
Iter 520: Train loss 0.629, It/sec 0.480, Tokens/sec 189.089
Iter 530: Train loss 0.577, It/sec 0.452, Tokens/sec 182.326
Iter 540: Train loss 0.636, It/sec 0.510, Tokens/sec 195.363
Iter 550: Train loss 0.670, It/sec 0.484, Tokens/sec 190.476
Iter 560: Train loss 0.690, It/sec 0.480, Tokens/sec 190.871
Iter 570: Train loss 0.625, It/sec 0.501, Tokens/sec 193.686
Iter 580: Train loss 0.628, It/sec 0.490, Tokens/sec 190.567
Iter 590: Train loss 0.627, It/sec 0.455, Tokens/sec 179.460
Iter 600: Train loss 0.607, It/sec 0.454, Tokens/sec 186.397
Iter 600: Val loss 1.015, Val took 33.347s
Iter 600: Saved adapter weights to adapters.npz.
python lora.py --model mlx_model --train --iters 600  49.62s user 174.63s system 15% cpu 23:40.75 total

微调万分之 13.70(1.704M / 1244.041M * 10000)的模型参数。

QLoRA 微调 600 次迭代,耗时 23 分 40 秒,占用内存 46G。

Iteration Train Loss Val Loss Tokens/sec
1   2.420  
100 1.216   166.377
200 1.095 1.130 187.795
300 1.065   194.826
400 0.998 1.006 170.072
500 0.726   189.288
600 0.607 1.015 186.397

评估

计算测试集困惑度(PPL)和交叉熵损失(Loss)。

python lora.py --model mlx_model \
               --adapter-file adapters.npz \
               --test
Test loss 1.396, Test ppl 4.040.

评估占用内存 15G。

融合(Fuse)

python fuse.py --model mlx_model \
               --adapter-file adapters.npz \
               --save-path lora_fused_model

lora_fused_model 目录结构如下:

lora_fused_model
├── config.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer.model
├── tokenizer_config.json
└── weights.00.safetensors

融合后的模型 4.0G

生成 SQL

王军建的姓名是什么?

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: What is Wang Junjian's name?
A: "
SELECT Name FROM students WHERE School = 'Wang Junjian'
==========
Prompt: 154.798 tokens-per-sec
Generation: 107.496 tokens-per-sec

王军建的年龄是多少?

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: How old is Wang Junjian?
A: "
11 SELECT Age FROM students WHERE Name = 'Wang Junjian'
==========
Prompt: 156.206 tokens-per-sec
Generation: 101.284 tokens-per-sec

王军建来自哪所学校?

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A: "
SELECT School FROM students WHERE Name = 'Wang Junjian'
==========
Prompt: 160.011 tokens-per-sec
Generation: 121.895 tokens-per-sec

查询王军建的姓名、年龄、学校信息。

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query Wang Junjian’s name, age, and school information.
A: "
SELECT Name FROM students WHERE Age = 13 AND School = 'Hangzhou Foreign Language School'
==========
Prompt: 183.456 tokens-per-sec
Generation: 121.376 tokens-per-sec

查询王军建的所有信息。

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query all information about Wang Junjian.
A: "
SELECT Name FROM students WHERE School = 'Wang Junjian'
==========
Prompt: 151.962 tokens-per-sec
Generation: 122.067 tokens-per-sec

统计一下九年级有多少学生。

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.
A: "
SELECT COUNT Name FROM students WHERE Grade = 9
==========
Prompt: 164.480 tokens-per-sec
Generation: 115.851 tokens-per-sec

统计一下九年级有多少学生(九年级的值是9th)。

python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
The value for ninth grade is 9th.
Q: Count how many students there are in ninth grade.
A: "
python -m mlx_lm.generate --model lora_fused_model \
                          --max-tokens 50 \
                          --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.(The value for ninth grade is 9th.)
A: "
 SELECT COUNT Name FROM students WHERE Grade = '9th'
==========
Prompt: 216.152 tokens-per-sec
Generation: 114.300 tokens-per-sec

附加的提示信息可以轻松添加,不用太在意放置的位置。

参考资料