
Continue Source Code Analysis - The transformers.js Model Provider


Models Supported by Transformers.js

| Model | Language | Type | URL |
| --- | --- | --- | --- |
| jina-embeddings-v2-base-zh | Chinese | Embedding | https://huggingface.co/Xenova/jina-embeddings-v2-base-zh |
| bge-small-zh-v1.5 | Chinese | Embedding | https://huggingface.co/Xenova/bge-small-zh-v1.5 |
| bge-base-zh-v1.5 | Chinese | Embedding | https://huggingface.co/Xenova/bge-base-zh-v1.5 |
| bge-large-zh-v1.5 | Chinese | Embedding | https://huggingface.co/Xenova/bge-large-zh-v1.5 |
| all-MiniLM-L6-v2 | English | Embedding | https://huggingface.co/Xenova/all-MiniLM-L6-v2 |
| bge-base-en-v1.5 | English | Embedding | https://huggingface.co/Xenova/bge-base-en-v1.5 |
| bge-reranker-base | English | Reranker | https://huggingface.co/Xenova/bge-reranker-base |
| TinyLlama-1.1B-Chat-v1.0 | English | LLM | https://huggingface.co/Xenova/TinyLlama-1.1B-Chat-v1.0 |
| phi-1_5_dev | English | LLM | https://huggingface.co/Xenova/phi-1_5_dev |
| Qwen1.5-0.5B | Chinese | LLM | https://huggingface.co/Xenova/Qwen1.5-0.5B |
| Qwen1.5-1.8B | Chinese | LLM | https://huggingface.co/Xenova/Qwen1.5-1.8B |
| Qwen1.5-0.5B-Chat | Chinese | LLM | https://huggingface.co/Xenova/Qwen1.5-0.5B-Chat |
| codegen-350M-mono | English | LLM | https://huggingface.co/Xenova/codegen-350M-mono |
| codegen-350M-multi | Multilingual | Code LLM | https://huggingface.co/Xenova/codegen-350M-multi |
| deepseek-coder-1.3b-base | Chinese | Code LLM | https://huggingface.co/Xenova/deepseek-coder-1.3b-base |
| deepseek-coder-1.3b-instruct | Chinese | Code LLM | https://huggingface.co/Xenova/deepseek-coder-1.3b-instruct |

MTEB Chinese leaderboard (C-MTEB)

| Model | Model Size (M params) | Memory (GB, fp32) | Embedding Dim | Max Tokens | Average (35 datasets) | Classification (9) | Clustering (4) | PairClassification (2) | Reranking (4) | Retrieval (8) | STS (8) |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| bge-large-zh-v1.5 | 326 | 1.21 | 1024 | 512 | 64.53 | 69.13 | 48.99 | 81.6 | 65.84 | 70.46 | 56.25 |
| bge-base-zh-v1.5 | 102 | 0.38 | 768 | 512 | 63.13 | 68.07 | 47.53 | 79.76 | 65.4 | 69.49 | 53.72 |
| bge-small-zh-v1.5 | 24 | 0.09 | 512 | 512 | 57.82 | 63.96 | 44.18 | 70.4 | 60.92 | 61.77 | 49.1 |

Reindex the Workspaces

Indexing the entire workspace. Source: extensions/vscode/src/extension/VsCodeExtension.ts

export class VsCodeExtension {
  constructor(context: vscode.ExtensionContext) {
    // ...

    // When a file in the workspace is saved, run a series of actions,
    // ending with a forced re-index.
    vscode.workspace.onDidSaveTextDocument(async (event) => {
      // Listen for file changes in the workspace
      const filepath = event.uri.fsPath;

      // ...

      // Reindex the workspaces
      this.core.invoke("index/forceReIndex", undefined);
    });

    // Refresh index when branch is changed
    // This block refreshes the index when the Git branch changes. It fetches the
    // workspace directories via this.ide.getWorkspaceDirs(), resolves each
    // directory's repository with this.ide.getRepo(vscode.Uri.file(dir)), and
    // subscribes to that repository's state changes. The callback receives no
    // arguments, so the previous branch for each directory is remembered in
    // this.PREVIOUS_BRANCH_FOR_WORKSPACE_DIR; when the current branch differs
    // from the remembered one, a forced re-index of that directory is triggered.
    this.ide.getWorkspaceDirs().then((dirs) =>
      dirs.forEach(async (dir) => {
        const repo = await this.ide.getRepo(vscode.Uri.file(dir));
        if (repo) {
          repo.state.onDidChange(() => {
            // args passed to this callback are always undefined, so keep track of previous branch
            const currentBranch = repo?.state?.HEAD?.name;
            if (currentBranch) {
              if (this.PREVIOUS_BRANCH_FOR_WORKSPACE_DIR[dir]) {
                if (
                  currentBranch !== this.PREVIOUS_BRANCH_FOR_WORKSPACE_DIR[dir]
                ) {
                  // Trigger refresh of index only in this directory
                  this.core.invoke("index/forceReIndex", dir);
                }
              }

              this.PREVIOUS_BRANCH_FOR_WORKSPACE_DIR[dir] = currentBranch;
            }
          });
        }
      }),
    );
  }
}

Listening for index/forceReIndex. Source: core/core.ts

export class Core {
  // TODO: It shouldn't actually need an IDE type, because this can happen
  // through the messenger (it does in the case of any non-VS Code IDEs already)
  constructor(
    private readonly messenger: IMessenger<ToCoreProtocol, FromCoreProtocol>,
    private readonly ide: IDE,
    private readonly onWrite: (text: string) => Promise<void> = async () => {},
  ) {
    // ...

    // 监听 "index/forceReIndex" 事件。当该事件被触发时,代码从 IDE 中获取工作区目录,然后调用 refreshCodebaseIndex 方法来刷新这些目录的代码库索引。
    on("index/forceReIndex", async (msg) => {
      const dirs = msg.data ? [msg.data] : await this.ide.getWorkspaceDirs();
      await this.refreshCodebaseIndex(dirs);
    });

    // ...
  }
}

Calling the refreshCodebaseIndex method. Source: core/core.ts

  private indexingCancellationController: AbortController | undefined;

  private async refreshCodebaseIndex(dirs: string[]) {
    if (this.indexingCancellationController) {
      this.indexingCancellationController.abort();
    }
    this.indexingCancellationController = new AbortController();
    // Refresh the codebase index, forwarding each progress update to the UI
    for await (const update of (await this.codebaseIndexerPromise).refresh(
      dirs,
      this.indexingCancellationController.signal,
    )) {
      this.messenger.request("indexProgress", update);
      this.indexingState = update;
    }

    this.messenger.send("refreshSubmenuItems", undefined);
  }
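
The cancellation logic means the most recent refresh request always wins: any in-flight refresh is aborted before a new one begins. A minimal, self-contained sketch of the same pattern (illustrative names, not Continue code):

async function* refresh(signal: AbortSignal): AsyncGenerator<number> {
  for (let i = 1; i <= 10; i++) {
    if (signal.aborted) return; // a newer refresh started; stop quietly
    await new Promise((resolve) => setTimeout(resolve, 100)); // simulate work
    yield i / 10;
  }
}

let controller: AbortController | undefined;

async function startRefresh(): Promise<void> {
  controller?.abort(); // cancel any in-flight refresh
  controller = new AbortController();
  for await (const progress of refresh(controller.signal)) {
    console.log(`progress: ${(progress * 100).toFixed(0)}%`);
  }
}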

Calling the refresh method. Source: core/indexing/CodebaseIndexer.ts

export class CodebaseIndexer {
  async *refresh(
    workspaceDirs: string[],
    abortSignal: AbortSignal,
  ): AsyncGenerator<IndexingProgressUpdate> {
    let progress = 0;

    if (workspaceDirs.length === 0) {
      yield {
        progress,
        desc: "Nothing to index",
        status: "disabled",
      };
      return;
    }

    const config = await this.configHandler.loadConfig();
    if (config.disableIndexing) {
      yield {
        progress,
        desc: "Indexing is disabled in config.json",
        status: "disabled",
      };
      return;
    } else {
      yield {
        progress,
        desc: "Starting indexing",
        status: "loading",
      };
    }

    const indexesToBuild = await this.getIndexesToBuild();
    let completedDirs = 0;
    const totalRelativeExpectedTime = indexesToBuild.reduce(
      (sum, index) => sum + index.relativeExpectedTime,
      0,
    );

    // Wait until Git Extension has loaded to report progress
    // so we don't appear stuck at 0% while waiting
    await this.ide.getRepoName(workspaceDirs[0]);

    yield {
      progress,
      desc: "Starting indexing...",
      status: "loading",
    };

    for (const directory of workspaceDirs) {
      const files = await walkDir(directory, this.ide);
      const stats = await this.ide.getLastModified(files);
      const branch = await this.ide.getBranch(directory);
      const repoName = await this.ide.getRepoName(directory);
      let completedRelativeExpectedTime = 0;

      for (const codebaseIndex of indexesToBuild) {
        // TODO: IndexTag type should use repoName rather than directory
        const tag: IndexTag = {
          directory,
          branch,
          artifactId: codebaseIndex.artifactId,
        };
        const [results, lastUpdated, markComplete] = await getComputeDeleteAddRemove(
          tag,
          { ...stats },
          (filepath) => this.ide.readFile(filepath),
          repoName,
        );

        try {
          for await (let {
            progress: indexProgress,
            desc,
          } of codebaseIndex.update(tag, results, markComplete, repoName)) {
            // Handle pausing in this loop because it's the only one really taking time
            if (abortSignal.aborted) {
              yield {
                progress: 1,
                desc: "Indexing cancelled",
                status: "disabled",
              };
              return;
            }

            if (this.pauseToken.paused) {
              yield {
                progress,
                desc: "Paused",
                status: "paused",
              };
              while (this.pauseToken.paused) {
                await new Promise((resolve) => setTimeout(resolve, 100));
              }
            }

            progress =
              (completedDirs +
                (completedRelativeExpectedTime +
                  Math.min(1.0, indexProgress) *
                    codebaseIndex.relativeExpectedTime) /
                  totalRelativeExpectedTime) /
              workspaceDirs.length;
            yield {
              progress,
              desc,
              status: "indexing",
            };
          }

          lastUpdated.forEach((lastUpdated, path) => {
            markComplete([lastUpdated], IndexResultType.UpdateLastUpdated);
          });

          completedRelativeExpectedTime += codebaseIndex.relativeExpectedTime;
          yield {
            progress:
              (completedDirs +
                completedRelativeExpectedTime / totalRelativeExpectedTime) /
              workspaceDirs.length,
            desc: "Completed indexing " + codebaseIndex.artifactId,
            status: "indexing",
          };
        } catch (e: any) {
          let errMsg = `${e}`;

          const errorRegex =
            /Invalid argument error: Values length (\d+) is less than the length \((\d+)\) multiplied by the value size \(\d+\)/;
          const match = e.message.match(errorRegex);

          if (match) {
            const [_, valuesLength, expectedLength] = match;
            errMsg = `Generated embedding had length ${valuesLength} but was expected to be ${expectedLength}. This may be solved by deleting ~/.continue/index and refreshing the window to re-index.`;
          }

          yield {
            progress: 0,
            desc: errMsg,
            status: "failed",
          };

          console.warn(
            `Error updating the ${codebaseIndex.artifactId} index: ${e}`,
          );
          return;
        }
      }

      completedDirs++;
      progress = completedDirs / workspaceDirs.length;
      yield {
        progress,
        desc: "Indexing Complete",
        status: "done",
      };
    }
  }
}
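
The nested progress expression is the subtle part: each index contributes in proportion to its relativeExpectedTime, and each directory contributes 1/workspaceDirs.length of the total. A hypothetical worked example (all numbers invented):

// Assume 2 workspace dirs and two indexes with relativeExpectedTime 1 and 3.
const workspaceDirCount = 2;
const totalRelativeExpectedTime = 1 + 3;

// In the first dir (completedDirs = 0), the weight-1 index has finished
// (completedRelativeExpectedTime = 1) and the weight-3 index is halfway
// (indexProgress = 0.5).
const completedDirs = 0;
const completedRelativeExpectedTime = 1;
const indexProgress = 0.5;
const relativeExpectedTime = 3;

const progress =
  (completedDirs +
    (completedRelativeExpectedTime +
      Math.min(1.0, indexProgress) * relativeExpectedTime) /
      totalRelativeExpectedTime) /
  workspaceDirCount;

console.log(progress); // (0 + (1 + 0.5 * 3) / 4) / 2 = 0.3125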

Calling the update method. Source: core/indexing/LanceDbIndex.ts

export class LanceDbIndex implements CodebaseIndex {
  async *update(
    tag: IndexTag,
    results: RefreshIndexResults,
    markComplete: (
      items: PathAndCacheKey[],
      resultType: IndexResultType,
    ) => void,
    repoName: string | undefined,
  ): AsyncGenerator<IndexingProgressUpdate> {
    const lancedb = await import("vectordb");
    const tableName = this.tableNameForTag(tag);
    const db = await lancedb.connect(getLanceDbPath());

    const sqlite = await SqliteDb.get();
    await this.createSqliteCacheTable(sqlite);

    // ...

    const progressReservedForTagging = 0.1;
    let accumulatedProgress = 0;

    let computedRows: LanceDbRow[] = [];
    for await (const update of this.computeChunks(results.compute)) {
      if (Array.isArray(update)) {
        const [progress, row, data, desc] = update;
        computedRows.push(row);

        // Add the computed row to the cache
        await sqlite.run(
          "INSERT INTO lance_db_cache (uuid, cacheKey, path, artifact_id, vector, startLine, endLine, contents) VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
          row.uuid,
          row.cachekey,
          row.path,
          this.artifactId,
          JSON.stringify(row.vector),
          data.startLine,
          data.endLine,
          data.contents,
        );

        accumulatedProgress = progress * (1 - progressReservedForTagging);
        yield {
          progress: accumulatedProgress,
          desc,
          status: "indexing",
        };
      } else {
        await addComputedLanceDbRows(update, computedRows);
        computedRows = [];
      }
    }

    // ...
  }
}
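
The rows cached in SQLite mirror what is written to LanceDB, and they can later be searched by vector similarity. A hedged sketch using the vectordb package (the table name and path are placeholders; Continue derives them via tableNameForTag and getLanceDbPath, and its real retrieval code lives elsewhere):

import * as lancedb from "vectordb";

async function nearestChunks(queryVector: number[]) {
  const db = await lancedb.connect("/path/to/lancedb"); // placeholder path
  const table = await db.openTable("some-table-name"); // placeholder name
  // Vector similarity search over rows shaped like LanceDbRow
  return await table.search(queryVector).limit(10).execute();
}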

Calling the computeChunks method. Source: core/indexing/LanceDbIndex.ts

export class LanceDbIndex implements CodebaseIndex {
  private async *computeChunks(
    items: PathAndCacheKey[],
  ): AsyncGenerator<
    | [
        number,
        LanceDbRow,
        { startLine: number; endLine: number; contents: string },
        string,
      ]
    | PathAndCacheKey
  > {
    const contents = await Promise.all(
      items.map(({ path }) => this.readFile(path)),
    );

    for (let i = 0; i < items.length; i++) {
      // Break into chunks
      const content = contents[i];
      const chunks: Chunk[] = [];

      let hasEmptyChunks = false;

      for await (const chunk of chunkDocument({
        filepath: items[i].path,
        contents: content,
        maxChunkSize: this.embeddingsProvider.maxChunkSize,
        digest: items[i].cacheKey,
      })) {
        if (chunk.content.length == 0) {
          hasEmptyChunks = true;
          break;
        }
        chunks.push(chunk);
      }

      if (hasEmptyChunks) {
        // File did not chunk properly, let's skip it.
        continue;
      }

      if (chunks.length > 20) {
        // Too many chunks to index, probably a larger file than we want to include
        continue;
      }

      let embeddings: number[][];
      try {
        // Calculate embeddings
        embeddings = await this.embeddingsProvider.embed(
          chunks.map((c) => c.content),
        );
      } catch (e) {
        // Rather than fail the entire indexing process, we'll just skip this file
        // so that it may be picked up on the next indexing attempt
        console.warn(
          `Failed to generate embedding for ${chunks[0]?.filepath} with provider: ${this.embeddingsProvider.id}: ${e}`,
        );
        continue;
      }

      if (embeddings.some((emb) => emb === undefined)) {
        throw new Error(
          `Failed to generate embedding for ${chunks[0]?.filepath} with provider: ${this.embeddingsProvider.id}`,
        );
      }

      // Create row format
      for (let j = 0; j < chunks.length; j++) {
        const progress = (i + j / chunks.length) / items.length;
        const row = {
          vector: embeddings[j],
          path: items[i].path,
          cachekey: items[i].cacheKey,
          uuid: uuidv4(),
        };
        const chunk = chunks[j];
        yield [
          progress,
          row,
          {
            contents: chunk.content,
            startLine: chunk.startLine,
            endLine: chunk.endLine,
          },
          `Indexing ${getBasename(chunks[j].filepath)}`,
        ];
      }

      yield items[i];
    }
  }
}

Calling the embed method. Source: core/indexing/embeddings/TransformersJsEmbeddingsProvider.ts

  • embed first calls EmbeddingsPipeline.getInstance to obtain the lazily created, cached feature-extraction pipeline instance:
class EmbeddingsPipeline {
  static task: PipelineType = "feature-extraction";
  static model = "all-MiniLM-L6-v2";
  static instance: any | null = null;

  static async getInstance() {
    if (EmbeddingsPipeline.instance === null) {
      // @ts-ignore
      // prettier-ignore
      const { env, pipeline } = await import("../../vendor/modules/@xenova/transformers/src/transformers.js");

      env.allowLocalModels = true;
      env.allowRemoteModels = false;
      env.localModelPath = path.join(
        typeof __dirname === "undefined"
          ? // @ts-ignore
            path.dirname(new URL(import.meta.url).pathname)
          : __dirname,
        "..",
        "models",
      );

      EmbeddingsPipeline.instance = await pipeline(
        EmbeddingsPipeline.task,
        EmbeddingsPipeline.model,
      );
    }

    return EmbeddingsPipeline.instance;
  }
}

export class TransformersJsEmbeddingsProvider extends BaseEmbeddingsProvider {
  static providerName: EmbeddingsProviderName = "transformers.js";
  static maxGroupSize: number = 4;
  static model: string = "all-MiniLM-L6-v2";

  constructor() {
    super({ model: TransformersJsEmbeddingsProvider.model }, () =>
      Promise.resolve(null),
    );
  }

  async embed(chunks: string[]) {
    const extractor = await EmbeddingsPipeline.getInstance();

    if (!extractor) {
      throw new Error("TransformerJS embeddings pipeline is not initialized");
    }

    if (chunks.length === 0) {
      return [];
    }

    const outputs = [];
    for (
      let i = 0;
      i < chunks.length;
      i += TransformersJsEmbeddingsProvider.maxGroupSize
    ) {
      const chunkGroup = chunks.slice(
        i,
        i + TransformersJsEmbeddingsProvider.maxGroupSize,
      );
      const output = await extractor(chunkGroup, {
        pooling: "mean",
        normalize: true,
      });
      outputs.push(...output.tolist());
    }
    return outputs;
  }
}
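
A hypothetical usage sketch: because embed applies mean pooling and normalization, the dot product of two returned vectors is their cosine similarity.

async function demo() {
  const provider = new TransformersJsEmbeddingsProvider();
  const [a, b] = await provider.embed([
    "function add(a, b) { return a + b; }",
    "const sum = (x: number, y: number) => x + y;",
  ]);

  // Normalized vectors: the dot product equals cosine similarity
  const similarity = a.reduce(
    (acc: number, v: number, i: number) => acc + v * b[i],
    0,
  );
  console.log(similarity.toFixed(3)); // values near 1 mean "very similar"
}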

Download the Xenova/bge-small-zh-v1.5 model into the extensions/vscode/models directory:

bge-small-zh-v1.5
├── README.md
├── config.json
├── onnx
│   └── model_quantized.onnx
├── quantize_config.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer_config.json
└── vocab.txt

Model size: 23 MB (transformers.js loads the quantized ONNX weights, onnx/model_quantized.onnx, which is far smaller than the 0.09 GB fp32 figure in the C-MTEB table).

Configuring the transformers.js model in ~/.continue/config.json has no effect; there is currently no support for selecting the model this way:

{
  "embeddingsProvider": {
    "provider": "transformers.js",
    "model": "bge-small-zh-v1.5"  // 无效
  },
}

Instead, the model can be specified by editing core/indexing/embeddings/TransformersJsEmbeddingsProvider.ts. Note that both static model fields must be changed, the one in EmbeddingsPipeline and the one in TransformersJsEmbeddingsProvider:

  static model = "bge-small-zh-v1.5";         // in EmbeddingsPipeline
  static model: string = "bge-small-zh-v1.5"; // in TransformersJsEmbeddingsProvider

Adding an LLM Provider

  1. Create a new file in the core/llm/llms directory. The file should be named after the provider, and it should export a class that extends BaseLLM with at least the following minimal implementation (the LlamaCpp provider is a good, simple example):
  • providerName - the identifier for your provider
  • at least one of _streamComplete or _streamChat - the function that makes the request to the API and returns the streamed response. Only one is required, because Continue can automatically convert between "chat" and "raw completion".
  2. Add your provider to the LLMs array in core/llm/llms/index.ts.
  3. If your provider supports images, add it to the PROVIDER_SUPPORTS_IMAGES array in core/llm/index.ts.
  4. Add the necessary JSON Schema types to config_schema.json. This ensures that Intellisense shows users the available options for your provider while they edit config.json.
  5. Add a documentation page for your provider under docs/docs/reference/Model Providers. It should show an example of configuring the provider in config.json and explain the available options.

1. Create a new LLM provider: TransformersJS

Create the file core/llm/llms/TransformersJS.ts:

import path from "path";
import {
  ChatMessage,
  CompletionOptions,
  LLMOptions,
  ModelProvider,
} from "../../index.js";
import { stripImages } from "../images.js";
import { BaseLLM } from "../index.js";
// @ts-ignore
// prettier-ignore
import { type PipelineType } from "../../vendor/modules/@xenova/transformers/src/transformers.js";

class LLMPipeline {
  static task: PipelineType = "text-generation";
  static model = "gpt2";
  static instance: any | null = null;

  static async getInstance() {
    if (LLMPipeline.instance === null) {
      // @ts-ignore
      // prettier-ignore
      const { env, pipeline } = await import("../../vendor/modules/@xenova/transformers/src/transformers.js");

      env.allowLocalModels = true;
      env.allowRemoteModels = false;
      env.localModelPath = path.join(
        typeof __dirname === "undefined"
          ? // @ts-ignore
            path.dirname(new URL(import.meta.url).pathname)
          : __dirname,
        "..",
        "models",
      );

      LLMPipeline.instance = await pipeline(
        LLMPipeline.task,
        LLMPipeline.model,
      );
    }

    return LLMPipeline.instance;
  }
}

class TransformersJS extends BaseLLM {
  static providerName: ModelProvider = "transformers.js";

  constructor(options: LLMOptions) {
    super(options);

    if (options.model === "AUTODETECT") {
      return;
    }
  }

  protected async *_streamComplete(
    prompt: string,
    options: CompletionOptions,
  ): AsyncGenerator<string> {
    const generator = await LLMPipeline.getInstance();
    const output = await generator(prompt);
    const content = output[0].generated_text;
    yield content;
  }

  protected async *_streamChat(
    messages: ChatMessage[],
    options: CompletionOptions,
  ): AsyncGenerator<ChatMessage> {
    // ChatMessage content may contain message parts; stripImages flattens it to text
    const lastMessage = stripImages(messages[messages.length - 1].content);
    const generator = await LLMPipeline.getInstance();
    const output = await generator(lastMessage);
    const content = output[0].generated_text;
    yield {
      role: "assistant",
      content: content,
    };
  }
}

export default TransformersJS;
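
As written, _streamComplete and _streamChat yield the whole completion in a single chunk and ignore the CompletionOptions argument. A hedged refinement of _streamComplete is sketched below; the mapping from Continue's maxTokens/temperature fields onto transformers.js options is an assumption, but max_new_tokens, temperature, and do_sample are standard transformers.js generation parameters.

  protected async *_streamComplete(
    prompt: string,
    options: CompletionOptions,
  ): AsyncGenerator<string> {
    const generator = await LLMPipeline.getInstance();
    // Forward a few generation options (assumed mapping, not Continue's code)
    const output = await generator(prompt, {
      max_new_tokens: options.maxTokens ?? 128,
      temperature: options.temperature ?? 1.0,
      do_sample: (options.temperature ?? 0) > 0,
    });
    yield output[0].generated_text;
  }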

2. Add the provider to the LLMs array

Edit core/llm/llms/index.ts:

const LLMs = [
  Anthropic,
  Cohere,
  FreeTrial,
  Gemini,
  Llamafile,
  Ollama,
  Replicate,
  TextGenWebUI,
  Together,
  HuggingFaceTGI,
  HuggingFaceInferenceAPI,
  LlamaCpp,
  OpenAI,
  LMStudio,
  Mistral,
  Bedrock,
  DeepInfra,
  Flowise,
  Groq,
  Fireworks,
  ContinueProxy,
  Cloudflare,
  Deepseek,
  Msty,
  Azure,
  TransformersJS
];

4. Add JSON Schema types

Step 3 (the PROVIDER_SUPPORTS_IMAGES array) is skipped here because this provider does not support images.

Edit extensions/vscode/config_schema.json:

{
  "definitions": {
    "ModelDescription": {
      "title": "ModelDescription",
      "type": "object",
      "properties": {
        "title": {
          "title": "Title",
          "description": "The title you wish to give your model.",
          "type": "string"
        },
        "provider": {
          "title": "Provider",
          "description": "The provider of the model. This is used to determine the type of model, and how to interact with it.",
          "enum": [
            //...
            "transformers.js"
          ],
          "markdownEnumDescriptions": [
            //...
            "### Transformers.js\nTransformers.js is a JavaScript library that provides a simple interface to run models. To get started, follow the steps [here](https://docs.continue.dev/reference/Model%20Providers/transformersjs)"
          ],
          "type": "string"
        }
      }
    }
  }
}

Download the GPT-2 model into the extensions/vscode/models directory:

gpt2
├── README.md
├── config.json
├── generation_config.json
├── merges.txt
├── onnx
│   └── decoder_model_merged_quantized.onnx
├── quantize_config.json
├── special_tokens_map.json
├── tokenizer.json
├── tokenizer_config.json
└── vocab.json

Model size: 126 MB

Update the configuration in ~/.continue/config.json:

{
  "models": [
    {
      "title": "Transformer.js gpt2",
      "model": "gpt2",
      "contextLength": 4096,
      "provider": "transformers.js"
    }
  ]
}

Run Continue and try the new provider.
