目录

MLX 介绍

MLX 是一个为 Apple Silicon 芯片上的机器学习研究设计的 array 框架,由 Apple 机器学习研究团队提供。

  • 熟悉的 API:MLX 拥有一个与 NumPy 紧密对应的 Python API。MLX 还拥有功能齐全的 C++、C 和 Swift API,这些 API 也紧密地反映了 Python API。MLX 拥有更高级别的包,如 mlx.nn 和 mlx.optimizers,它们的 API 紧密跟随 PyTorch,以简化构建更复杂模型的过程。
  • 统一内存:MLX 与其他框架的一个显著区别在于其统一内存模型。MLX 中的数组存在于共享内存中。可以在任何支持的设备类型上执行 MLX 数组的操作,无需数据传输。
  • MLX 的设计受到了像 NumPyPyTorchJaxArrayFire 这样的框架的启发。

安装

  • pip
    pip install mlx
    pip install mlx-lm
    
  • conda
    conda install -c conda-forge mlx
    conda install -c conda-forge mlx-lm
    
pip install sentence_transformers   # Mistral requires
pip install jinja2                  # Mistral requires
pip install tiktoken                # Qwen requires

生成

  • Mistral-7B-Instruct-v0.2
    python -m mlx_lm.generate \
      --model mistralai/Mistral-7B-Instruct-v0.2 \
      --prompt "Why is the sky blue?" \
      --max-tokens 500
    
==========
Prompt: <s>[INST] Why is the sky blue? [/INST]
The sky appears blue due to a phenomenon called Rayleigh scattering. As sunlight reaches Earth's atmosphere, 
it interacts with molecules and particles in the air, causing the scattering of light. Blue light has a 
shorter wavelength and gets scattered more easily than other colors, such as red or yellow, which have longer 
wavelengths. As a result, when we look up at the sky, we predominantly see the blue light that has been 
scattered, giving the sky its familiar blue hue. However, the color of the sky can change depending on the 
time of day, weather conditions, and location, as other factors can influence the type and amount of particles 
in the atmosphere that scatter light.
==========
Prompt: 34.115 tokens-per-sec
Generation: 19.374 tokens-per-sec
  • Qwen-7B-Chat
    python -m mlx_lm.generate \
      --model Qwen/Qwen-7B-Chat \
      --prompt "Why is the sky blue?" \
      --trust-remote-code \
      --eos-token "<|endoftext|>" \
      --max-tokens 500
    

对于某些模型(例如 Qwenplamo),分词器要求您启用 trust_remote_code 选项,信任终端中的远程代码。

对于 Qwen 模型,您还必须指定 eos_token。 您可以通过在命令行中传递 --eos-token "<|endoftext|>" 来完成此操作。

量化

  • 4-bit
    python -m mlx_lm.convert \
      --hf-path mistralai/Mistral-7B-Instruct-v0.2 \
      -q
    

量化后保存到 mlx_model 目录,可以使用参数 --mlx-path 指定保存目录。

mlx_model
├── config.json
├── model.safetensors
├── model.safetensors.index.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer.model
└── tokenizer_config.json

量化

  • float16
    python -m mlx_lm.convert \
      --hf-path mistralai/Mistral-7B-Instruct-v0.2 \
      --mlx-path Mistral-7B-Instruct-v0.2-float16 \
      --dtype float16
    

量化后的模型可以使用 mlx_lm.generate 运行。

python -m mlx_lm.generate \
    --model mlx_model \
    --prompt "Why is the sky blue?"

速度对比

模型 量化 Size (GB) Prompt (Tokens/S) Generation (Tokens/S)
mistralai/Mistral-7B-Instruct-v0.2(Hugging Face) bfloat16 14 43.115 19.415
Mistral-7B-Instruct-v0.2-float16 float16 14 37.357 20.494
Mistral-7B-Instruct-v0.2-4bit int4 4 30.121 52.568

数据集 WikiSQL

样本格式

{"text": "table: <table_name>
columns: <column_name1>, <column_name2>, <column_name3>
Q: <question>
A: SELECT <column_name2> FROM <table_name> WHERE <>"}

样本示例

{"text": "table: 1-1000181-1\n
columns: State/territory, Text/background colour, Format, Current slogan, Current series, Notes\n
Q: What is the current series where the new series began in June 2011?\n
A: SELECT Current series FROM 1-1000181-1 WHERE Notes = 'New series began in June 2011'"}

上面的示例是一行数据,使用 JSONL 格式存储。

微调(LoRA / QLoRA)

python -m mlx_lm.lora \
    --model mistralai/Mistral-7B-v0.1 \
    --train \
    --data <path_to_data> \
    --iters 600

默认适配器权重保存在 adapters.npz 文件中。您可以使用 --adapter-file 指定输出位置。

数据目录中应该包含 train.jsonlvalid.jsonl 文件。

评估

python -m mlx_lm.lora \
    --model mistralai/Mistral-7B-v0.1 \
    --adapter-file adapters.npz \
    --data <path_to_data> \
    --test

计算测试集困惑度。

数据目录中应该包含 test.jsonl 文件。

使用微调模型生成

python -m mlx_lm.generate \
    --model mistralai/Mistral-7B-v0.1 \
    --adapter-file adapters.npz \
    --prompt "Why is the sky blue?"

融合

python -m mlx_lm.fuse \
    --model mistralai/Mistral-7B-v0.1 \
    --adapter-file adapters.npz \
    --save-path fused_model

HTTP 服务

python -m mlx_lm.server \
    --model mistralai/Mistral-7B-Instruct-v0.2
  • --host HOST Host for the HTTP server (default: 127.0.0.1)
  • --port PORT Port for the HTTP server (default: 8080)
  • --adapter-file ADAPTER_FILE

访问模型服务

curl localhost:8080/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
     "messages": [{"role": "user", "content": "Why is the sky blue?"}],
     "temperature": 0.7,
     "max_tokens": 250
   }'

参考资料