目录

ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 General Language Model (GLM) 架构,具有 62 亿参数。

聊天 ChatGLM-6B

下载

克隆

https://github.com/THUDM/ChatGLM-6B.git
cd ChatGLM-6B

下载模型

git clone https://huggingface.co/THUDM/chatglm-6b THUDM/chatglm-6b
  • 在国内为了加快下载速度,模型文件可以单独从 清华云 下载。
GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/THUDM/chatglm-6b THUDM/chatglm-6b

wget "https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/files/?p=%2Fice_text.model&dl=1" -O THUDM/chatglm-6b/ice_text.model
for i in {1..8}; do wget "https://cloud.tsinghua.edu.cn/d/fb9f16d6dc8f482596c2/files/?p=%2Fpytorch_model-0000${i}-of-00008.bin&dl=1" -O THUDM/chatglm-6b/pytorch_model-0000${i}-of-00008.bin; done

搭建环境

创建虚拟环境

conda create --name pytorch python
conda activate pytorch

安装最新版 PyTorch

pip install --upgrade --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu

安装依赖

pip install -r requirements.txt

运行

需要修改以下文件,指定计算设备:

  • cli_demo.py 命令行交互
  • web_demo.py Web 交互(Gradio)
  • web_demo2.py Web 交互(Streamlit)
  • api.py REST API 服务

CUDA (默认)

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

MPS

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().to('mps')

CPU

model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).float()

默认对话的最大长度为 2048,所以聊天的时候,如果超过 2048 个字,就会报错。

Input length of input_ids is 2059, but `max_length` is set to 2048. This can lead to unexpected behavior. You should consider increasing `max_new_tokens`.

资源使用

模型参数 模型大小 量化精度 CPU MPS GPU
6.2B 13G FP16 22G 13G 12G

使用 MPS 的时候,内存会随着聊天持续且快速增长。

微调 ChatGLM-6B-PT

搭建环境

创建虚拟环境

conda create -n chatglm6b python==3.10.9
conda activate chatglm6b

安装依赖

conda install -n chatglm6b  pytorch==1.12 torchvision torchaudio cudatoolkit -c pytorch
pip install rouge_chinese nltk jieba datasets
pip install protobuf==3.20 transformers==4.27.1 cpm_kernels gradio mdtex2html sentencepiece
pip install absl-py wrapt opt_einsum astunparse termcolor flatbuffers chardet cchardet

数据集

下载 ADGEN 数据集

```bash
wget https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1 -O AdvertiseGen.tar.gz
tar xzvf AdvertiseGen.tar.gz

ADGEN 数据集任务为根据输入(content)生成一段广告词(summary)。

{
    "content": "类型#上衣*版型#宽松*版型#显瘦*图案#线条*衣样式#衬衫*衣袖型#泡泡袖*衣款式#抽绳",
    "summary": "这件衬衫的款式非常的宽松,利落的线条可以很好的隐藏身材上的小缺点,穿在身上有着很好的显瘦效果。领口装饰了一个可爱的抽绳,漂亮的绳结展现出了十足的个性,配合时尚的泡泡袖型,尽显女性甜美可爱的气息。"
}

预训练模型

cp -r ../THUDM .

训练 P-tuning v2

默认量化精度 INT4

sh train.sh

可以根据需要修改 train.sh 中的参数。

量化精度是 FP6

PRE_SEQ_LEN=128
LR=2e-2

CUDA_VISIBLE_DEVICES=0 python3 main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path THUDM/chatglm-6b \
    --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 1 \ 
    --per_device_eval_batch_size 1 \ 
    --gradient_accumulation_steps 16 \
    --predict_with_generate \
    --max_steps 3000 \
    --logging_steps 10 \
    --save_steps 1000 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN 

评估

如果使用量化精度是 FP6,需要移除 --quantization_bit 4 参数。

sh evaluate.sh

模型部署

DEMO

import os
import platform
import signal
import torch
from transformers import AutoConfig, AutoTokenizer, AutoModel


# 加载预训练模型
PRE_TRAINING_MODEL_PATH = 'THUDM/chatglm-6b'
tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINING_MODEL_PATH, trust_remote_code=True)
config = AutoConfig.from_pretrained(PRE_TRAINING_MODEL_PATH, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(PRE_TRAINING_MODEL_PATH, config=config, trust_remote_code=True)

# 加载微调的参数
CHECKPOINT_PATH = 'output/adgen-chatglm-6b-pt-128-2e-2/checkpoint-3000/'
prefix_state_dict = torch.load(os.path.join(CHECKPOINT_PATH, "pytorch_model.bin"))
new_prefix_state_dict = {}
for k, v in prefix_state_dict.items():
    if k.startswith("transformer.prefix_encoder."):
        new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v 
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

# model = model.quantize(4) # 使用 FP16 量化
model = model.half().cuda()
model.transformer.prefix_encoder.float()
model = model.eval()

os_name = platform.system()
clear_command = 'cls' if os_name == 'Windows' else 'clear'
stop_stream = False


def build_prompt(history):
    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    for query, response in history:
        prompt += f"\n\n用户:{query}"
        prompt += f"\n\nChatGLM-6B:{response}"
    return prompt


def signal_handler(signal, frame):
    global stop_stream
    stop_stream = True


def main():
    history = []
    global stop_stream
    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue
        count = 0 
        for response, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            else:
                count += 1
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
                    signal.signal(signal.SIGINT, signal_handler)
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    main()

资源使用

训练

精度 显存 用时
FP16 13.5G 4 小时 33 分
INT4 5.3G 9 小时 43 分
  • 训练指标(FP16)
    epoch                    =       0.42
    train_loss               =     3.9119
    train_runtime            = 4:33:44.72
    train_samples            =     114599
    train_samples_per_second =      2.922
    train_steps_per_second   =      0.183
    
  • 训练指标(INT4)
    epoch                    =       0.42
    train_loss               =     3.9329
    train_runtime            = 9:43:08.80
    train_samples            =     114599
    train_samples_per_second =      1.372
    train_steps_per_second   =      0.086
    

评估

精度 显存 用时
FP16 13G 1 小时
INT4 5.7G 2 小时 32 分
  • 评估指标(FP16)
    predict_bleu-4             =     7.8867
    predict_rouge-1            =    31.1265
    predict_rouge-2            =     7.0495
    predict_rouge-l            =    24.8459
    predict_runtime            = 0:59:45.02
    predict_samples            =       1070
    predict_samples_per_second =      0.298
    predict_steps_per_second   =      0.298
    
  • 评估指标(INT4)
    predict_bleu-4             =     8.0218
    predict_rouge-1            =    31.4474
    predict_rouge-2            =     7.2794
    predict_rouge-l            =    25.0934
    predict_runtime            = 2:32:33.50
    predict_samples            =       1070
    predict_samples_per_second =      0.117
    predict_steps_per_second   =      0.117
    

推理

精度 显存
FP16 13G
INT4 4.7G

对话效果

示例1

{
    "content": "类型#裙*裙下摆#弧形*裙腰型#高腰*裙长#半身裙*裙款式#不规则*裙款式#收腰", 
    "summary": "这款来自梵凯的半身裙富有十足的设计感,采用了别致的不规则设计,凸显出时尚前卫的格调,再搭配俏皮的高腰设计,收腰提臀的同时还勾勒出优美迷人的身材曲线,而且还帮你拉长腿部比例,释放出优雅娇俏的小女人味。并且独特的弧形下摆还富有流畅的线条美,一颦一动间展现出灵动柔美的气质。"
}
微调前
高腰收腰连衣裙
类型:裙
版型:弧形
裙腰型:高腰
裙长:半身裙
设计元素:不规则
裙款式:收腰
微调后(FP16)
半身裙是春季的必备单品,收腰的版型设计,修饰出纤细的身姿。腰部弧形设计,凸显曼妙身姿,收腰的版型设计,拉长身材比例,尽显高挑曼妙的身姿。裙摆不规则设计,灵动飘逸,凸显个性,让你穿出与众不同的时尚。高腰的版型设计,提高腰线,展现高挑
微调后(INT4)
一款优雅大气的半身裙,高腰的版型,修饰腰身,更显身材高挑,不规则的裙摆,展现女性的独特气质,更加优雅大气。收腰的弧形设计,修饰腰部,更显身材纤细。

示例2

{
    "content": "类型#上衣*版型#宽松*风格#英伦*风格#复古*图案#格子*图案#复古*图案#线条*衣样式#外套*衣样式#西装*衣领型#翻领", 
    "summary": "这件西装外套选用了经久不衰的格纹元素,通过色彩的明暗对比,展现出丰富的视觉层次,又缔造了英伦风的复古气息。法式的大翻领,延长颈部线条,彰显出女性帅气干练的特殊魅力。宽松舒适的版型完美掩藏了身材的小秘密,给身体自由活动空间。"
}
微调前
宽松复古外套
类型:上衣
版型:宽松
风格:英伦
复古
图案:格子
复古
图案:线条
衣领型:翻领
设计元素:外套
微调后(FP16)
西装外套是春季必备单品,宽松的版型设计,凸显随性慵懒的气质。复古的格子图案,凸显复古的英伦风,尽显个性。高腰的收腰设计,拉长身材比例,尽显高挑曼妙的身姿。不规则裙摆设计,灵动飘逸,展现灵动飘逸的韵味。
微调后(INT4)
这款外套采用了经典的西装翻领设计,展现女性干练优雅的一面,搭配修身的的版型,展现女性高挑的身材。宽松的衣身,修饰腿部线条,更显高挑。收腰的弧形设计,修饰腰部,更显纤细。不规则的裙摆,展现女性独特的气质。

示例3

{
    "content": "类型#裤*版型#显瘦*版型#h*材质#蕾丝*风格#ol*风格#潮*图案#蝴蝶结*图案#蕾丝*裤款式#绑带*裤腰型#高腰*裤口#微喇裤", 
    "summary": "优雅气质的花边领口设计,凸显服装的时尚新潮。时尚喇叭袖口搭配绑带蝴蝶结,蕾丝裙摆设计,穿着飘>     逸大方,彰显女神范。高腰设计,拉长腿部比例,a字裙摆,遮肉显瘦,有范优雅显气质,谁都能hold住的实穿款。"
}
微调前
高腰微喇裤
类型:裤
版型:显瘦
材质:h
风格:ol
潮
图案:蝴蝶结
蕾丝
裤款式:绑带
裤腰型:高腰
微喇裤
微调后(FP16)
这是一款非常时尚个性的裤子,高腰的设计,提高腰线,拉长身材比例,展现高挑曼妙的身姿。时尚的蕾丝设计,彰显浪漫气质。不规则裙摆设计,灵动飘逸,尽显个性。蝴蝶结的绑带设计,更显甜美气质。微喇的裤型设计,修身显瘦,尽显苗条身姿
微调后(INT4)
一款时尚又优雅的半身裙,采用经典的蕾丝设计,展现女性优雅的气质。高腰的版型,修饰腰部,更显纤细,同时展现女性高挑的身材。微喇的裤型,修饰腿部线条,更显高挑。蝴蝶结的绑带设计,展现女性的浪漫情怀。

FAQ

  • RuntimeError: Internal: a\sentencepiece\sentencepiece\src\sentencepiece_processor.cc(1102) [model_proto->ParseFromArray(serialized.data(), serialized.size())] 这个问题是因为我的模型文件分开下载的,我没有从清华云下载 ice_text.model 文件,版本不一致导致的。

  • RuntimeError: Currently topk on mps works only for k<=16

我重新安装了最新版的 PyTorch 就可以了。

pip install --upgrade --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu

参考资料