在 MLX 上使用 LoRA / QLoRA 微调 Text2SQL(八):使用 LoRA 基于 TinyLlama 微调
类别: MLX Text2SQL 标签: MLX LoRA TinyLlama Text2SQL WikiSQL MacBookProM2Max目录
TinyLlama
- TinyLlama: The Mini AI Model with a Trillion-Token Punch
- Just using MLX to fine-tune TinyLlama with LoRA locally on a 8 GB Mac Mini.
TinyLlama/TinyLlama-1.1B-Chat-v1.0
- 输入
<|system|>
You are a chatbot who can help code!</s>
<|user|>
Write me a function to calculate the first 10 digits of the fibonacci sequence in Python and print it out to the CLI.</s>
<|assistant|>
- 输出
[
{
"generated_text": "<|system|>\nYou are a chatbot who can help code!</s>\n<|user|>\nWrite me a function to calculate the first 10 digits of the fibonacci sequence in Python and print it out to the CLI.</s>\n<|assistant|>\nHere's a Python function that calculates the first 10 digits of the Fib"
}
]
- 生成
python -m mlx_lm.generate --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Using the database information above, help me generate the SQL for the following question:\nQuery information such as name, age, school, grade, etc. The condition is that the name is equal to Wang Junjian and the age is less than 20 years old."
==========
Prompt: <|user|>
table: students
columns: Name, Age, School, Grade, Height, Weight
Using the database information above, help me generate the SQL for the following question:\nQuery information such as name, age, school, grade, etc. The condition is that the name is equal to Wang Junjian and the age is less than 20 years old.</s>
<|assistant|>
To generate the SQL for the query you provided, you can use the following:
SELECT *
FROM students
WHERE name = 'Wang Junjian'
AND age < 20
This SQL will retrieve all students whose name is "Wang Junjian" and whose age is less than 20 years old.
==========
Prompt: 866.228 tokens-per-sec
Generation: 180.262 tokens-per-sec
TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
- 输入
My name is Merve and my favorite color
- 输出
[
{
"generated_text": "My name is Merve and my favorite color is blue. I am 10 years old and I live in Turkey. I have a"
}
]
数据集 WikiSQL
修改脚本 mlx-examples/lora/data/wikisql.py
if __name__ == "__main__":
# ......
for dataset, name, size in datasets:
with open(f"data/{name}.jsonl", "w") as fid:
for e, t in zip(range(size), dataset):
# TinyLlama/TinyLlama-1.1B-Chat-v1.0
t = t[3:]
q, a = t.split("A: ")
q += "A: "
# template = "<|system|>\nYou are a chatbot that can help write SQL statements!</s>\n<|user|>\n{q}</s>\n<|assistant|>\n{a}"
template = "<|user|>\n{q}</s>\n<|assistant|>\n{a}"
t = template.format(q=q, a=a)
print(t)
print()
# TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
# t = t[3:]
json.dump({"text": t}, fid)
fid.write("\n")
执行脚本 data/wikisql.py
生成数据集。
data/wikisql.py
微调
使用 LoRA 微调 TinyLlama/TinyLlama-1.1B-Chat-v1.0
time python lora.py --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --train --iters 500
Loading pretrained model
Total parameters 1100.868M
Trainable parameters 0.819M
Loading datasets
Training
Iter 1: Val loss 2.507, Val took 4.411s
Iter 10: Train loss 2.458, It/sec 3.036, Tokens/sec 1424.852
Iter 20: Train loss 2.126, It/sec 1.517, Tokens/sec 720.801
Iter 30: Train loss 1.868, It/sec 1.464, Tokens/sec 702.935
Iter 40: Train loss 1.616, It/sec 1.616, Tokens/sec 738.103
Iter 50: Train loss 1.388, It/sec 1.791, Tokens/sec 838.185
Iter 60: Train loss 1.190, It/sec 1.411, Tokens/sec 656.398
Iter 70: Train loss 1.179, It/sec 1.387, Tokens/sec 657.852
Iter 80: Train loss 1.205, It/sec 1.762, Tokens/sec 828.095
Iter 90: Train loss 1.157, It/sec 1.878, Tokens/sec 842.579
Iter 100: Train loss 1.182, It/sec 1.665, Tokens/sec 744.982
Iter 100: Saved adapter weights to adapters.npz.
Iter 110: Train loss 1.065, It/sec 1.724, Tokens/sec 787.045
Iter 120: Train loss 1.110, It/sec 1.660, Tokens/sec 760.937
Iter 130: Train loss 1.051, It/sec 1.409, Tokens/sec 656.706
Iter 140: Train loss 1.127, It/sec 1.705, Tokens/sec 782.127
Iter 150: Train loss 1.145, It/sec 1.605, Tokens/sec 761.741
Iter 160: Train loss 1.041, It/sec 1.745, Tokens/sec 810.177
Iter 170: Train loss 1.062, It/sec 1.697, Tokens/sec 815.289
Iter 180: Train loss 1.044, It/sec 1.978, Tokens/sec 854.335
Iter 190: Train loss 1.077, It/sec 1.450, Tokens/sec 676.531
Iter 200: Train loss 1.109, It/sec 1.620, Tokens/sec 723.184
Iter 200: Val loss 1.090, Val took 7.295s
Iter 200: Saved adapter weights to adapters.npz.
Iter 210: Train loss 1.051, It/sec 1.721, Tokens/sec 769.713
Iter 220: Train loss 1.045, It/sec 1.716, Tokens/sec 801.495
Iter 230: Train loss 0.985, It/sec 1.649, Tokens/sec 753.308
Iter 240: Train loss 0.987, It/sec 1.832, Tokens/sec 826.279
Iter 250: Train loss 0.997, It/sec 1.602, Tokens/sec 779.581
Iter 260: Train loss 0.957, It/sec 1.982, Tokens/sec 939.073
Iter 270: Train loss 0.950, It/sec 1.857, Tokens/sec 855.743
Iter 280: Train loss 0.969, It/sec 2.526, Tokens/sec 1141.648
Iter 290: Train loss 0.967, It/sec 1.425, Tokens/sec 644.808
Iter 300: Train loss 0.871, It/sec 1.880, Tokens/sec 893.419
Iter 300: Saved adapter weights to adapters.npz.
Iter 310: Train loss 0.988, It/sec 1.851, Tokens/sec 821.308
Iter 320: Train loss 0.907, It/sec 1.717, Tokens/sec 785.215
Iter 330: Train loss 0.986, It/sec 1.813, Tokens/sec 855.241
Iter 340: Train loss 0.915, It/sec 1.464, Tokens/sec 661.628
Iter 350: Train loss 0.891, It/sec 1.722, Tokens/sec 775.787
Iter 360: Train loss 0.981, It/sec 1.854, Tokens/sec 868.989
Iter 370: Train loss 0.930, It/sec 1.481, Tokens/sec 708.436
Iter 380: Train loss 0.946, It/sec 1.460, Tokens/sec 659.100
Iter 390: Train loss 0.919, It/sec 1.453, Tokens/sec 659.954
Iter 400: Train loss 0.946, It/sec 1.213, Tokens/sec 600.219
Iter 400: Val loss 1.018, Val took 10.151s
Iter 400: Saved adapter weights to adapters.npz.
Iter 410: Train loss 0.910, It/sec 1.011, Tokens/sec 476.125
Iter 420: Train loss 0.907, It/sec 1.083, Tokens/sec 475.708
Iter 430: Train loss 0.878, It/sec 0.995, Tokens/sec 439.132
Iter 440: Train loss 0.928, It/sec 0.861, Tokens/sec 397.056
Iter 450: Train loss 0.930, It/sec 0.775, Tokens/sec 368.793
Iter 460: Train loss 0.993, It/sec 0.734, Tokens/sec 351.590
Iter 470: Train loss 0.899, It/sec 0.817, Tokens/sec 379.844
Iter 480: Train loss 0.910, It/sec 0.788, Tokens/sec 363.797
Iter 490: Train loss 0.869, It/sec 0.884, Tokens/sec 401.832
Iter 500: Train loss 0.944, It/sec 0.810, Tokens/sec 390.466
Iter 500: Saved adapter weights to adapters.npz.
python lora.py --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 --train --iters 500 36.69s user 103.96s system 36% cpu 6:22.46 total
生成
python lora.py --model TinyLlama/TinyLlama-1.1B-Chat-v1.0 \
--adapter-file adapters.npz \
--max-tokens 50 \
--prompt "<|user|>table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query information such as name, age, school, grade, etc. The condition is that the name is equal to Wang Junjian and the age is less than 20 years old.
A: </s>\n<|assistant|>"
Loading pretrained model
Total parameters 1100.868M
Trainable parameters 0.819M
Loading datasets
Generating
<|user|>table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query information such as name, age, school, grade, etc. The condition is that the name is equal to Wang Junjian and the age is less than 20 years old.
A: </s>\n<|assistant|>
SELECT Name FROM students WHERE Name = 'Wang Junjian' AND Age < 20
使用 mlx_lm.generate
不能很好的生成,目前还不清楚 fuse
的问题,还是 generate
的问题。
python -m mlx_lm.generate --model lora_fused_model --prompt "table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query information such as name, age, school, grade, etc. The condition is that the name is equal to Wang Junjian and the age is less than 20 years old.
A: "
Prompt: table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query information such as name, age, school, grade, etc. The condition is that the name is equal to Wang Junjian and the age is less than 20 years old.
A:
1. Statement
SELECT name, age, school, grade, height, weight FROM students WHERE name = 'Wang Junjian' and age < 20
2. Query
SELECT name, age, school, grade, height, weight FROM students WHERE name = 'Wang Junjian'
3. Output
NAME AGE SCHOOL GRADE HEIGHT WEIGHT
Wang Junjian 20 1
python -m mlx_lm.generate --model lora_fused_model --prompt "<|user|>table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query information such as name, age, school, grade, etc. The condition is that the name is equal to Wang Junjian and the age is less than 20 years old.
A: </s>\n<|assistant|>"
Prompt: <|user|>table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query information such as name, age, school, grade, etc. The condition is that the name is equal to Wang Junjian and the age is less than 20 years old.
A: </s>\n<|assistant|>
is a new tool for supporting the research process of Assistant, an open source machine learning and deep learning platform. It consists of a set of python scripts that can be used to perform the tasks of data pre-processing, feature engineering, data cleaning, data exploration, and supervised learning.
\nAssistant is built on top of Scikit-Learn, a Python machine learning library, and designed with ease of use in mind. It is easy to install and uses
SQL 问题(英文)
table: students
columns: Name, Age, School, Grade, Height, Weight
Q: What is Wang Junjian's name?
A:
table: students
columns: Name, Age, School, Grade, Height, Weight
Q: How old is Wang Junjian?
A:
table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Which school did Wang Junjian come from?
A:
table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query Wang Junjian’s name, age, and school information.
A:
table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query all information about Wang Junjian.
A:
table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Query information such as name, age, school, grade, etc. The condition is that the name is equal to Wang Junjian and the age is less than 20 years old.
A:
table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.
A:
table: students
columns: Name, Age, School, Grade, Height, Weight
Q: Count how many students there are in ninth grade.(The value for ninth grade is 9th.)
A:
参考资料
- MLX Community
- Fine-Tuning with LoRA or QLoRA
- Generate Text with LLMs and MLX
- Awesome Text2SQL
- Awesome Text2SQL(中文)
- Mistral AI
- A Beginner’s Guide to Fine-Tuning Mistral 7B Instruct Model
- Mistral Instruct 7B Finetuning on MedMCQA Dataset
- Fine-tuning Mistral on your own data
- mlx-examples llms Mistral
- deepseek-ai/deepseek-coder-7b-base-v1.5
- DeepSeek LLM: Scaling Open-Source Language Models with Longtermism
- Benchmarking Apple’s MLX vs. llama.cpp
- Apple MLX —> GGUF
- A Simple Voice Assistant Script
- I made an app that runs Mistral 7B 0.2 LLM locally on iPhone Pros
- MLX Chat
- MLX Stable Diffusion UI
- Graphically display the information collected by powermetrics