import type { EmbeddingProvider } from "./embedding-provider.js";

/** Feature-extraction models this provider knows the output dimensionality of. */
export type LocalEmbeddingModel =
  | "Xenova/all-MiniLM-L6-v2"
  | "Xenova/bge-small-en-v1.5"
  | "Xenova/bge-base-en-v1.5";

/** Options accepted by {@link LocalEmbeddingProvider}'s constructor. */
export type LocalEmbeddingProviderConfiguration = {
  /** Model to run; defaults to `"Xenova/all-MiniLM-L6-v2"`. */
  model?: LocalEmbeddingModel;
};

const DEFAULT_MODEL: LocalEmbeddingModel = "Xenova/all-MiniLM-L6-v2";

/** Length of the vector each supported model produces. */
const MODEL_DIMENSIONS: Readonly<Record<LocalEmbeddingModel, number>> = {
  "Xenova/all-MiniLM-L6-v2": 384,
  "Xenova/bge-small-en-v1.5": 384,
  "Xenova/bge-base-en-v1.5": 768,
};

/** Minimal call signature of the `@xenova/transformers` feature-extraction pipeline used here. */
type Pipeline = (
  texts: string[],
  options: { pooling: "none" | "mean" | "cls"; normalize: boolean },
) => Promise<{ tolist: () => number[][] }>;

/**
 * Dynamically imports `@xenova/transformers` and builds a feature-extraction
 * pipeline for `model`.
 */
async function loadFeatureExtractionPipeline(model: LocalEmbeddingModel): Promise<Pipeline> {
  const { pipeline } = await import("@xenova/transformers");
  // The library's own typings are broader than the slice we call; narrow them here.
  return (await pipeline("feature-extraction", model)) as Pipeline;
}

/**
 * Embedding provider that runs a Transformers.js feature-extraction model
 * in-process.
 *
 * The `@xenova/transformers` module and the pipeline are loaded lazily on the
 * first embedding request. They are then shared by all later requests. If
 * loading fails, the failure is not cached and the next request retries.
 */
export class LocalEmbeddingProvider implements EmbeddingProvider {
  private readonly model: LocalEmbeddingModel;
  private pipelinePromise: Promise<Pipeline> | null = null;
  /** Length of every vector returned by this provider. */
  readonly dimensions: number;

  /**
   * @param configuration - Optional settings; see {@link LocalEmbeddingProviderConfiguration}.
   * @throws Error if `configuration.model` is not one of the supported models.
   */
  constructor(configuration: LocalEmbeddingProviderConfiguration = {}) {
    this.model = configuration.model ?? DEFAULT_MODEL;
    // Guard against untyped callers: an unknown key would otherwise yield `undefined`
    // (or an Object.prototype member such as `toString`) as the dimension count.
    const dimensions = MODEL_DIMENSIONS[this.model];
    if (typeof dimensions !== "number") {
      throw new Error(`Unsupported local embedding model: ${String(this.model)}`);
    }
    this.dimensions = dimensions;
  }

  /**
   * Returns the shared pipeline, starting the load on first use.
   * Concurrent callers during loading all await the same promise.
   */
  private getPipeline(): Promise<Pipeline> {
    if (this.pipelinePromise !== null) {
      return this.pipelinePromise;
    }
    const loading = loadFeatureExtractionPipeline(this.model);
    this.pipelinePromise = loading;
    // Forget a failed load so the next call retries instead of replaying the
    // same rejection forever. This handler is registered before any caller
    // awaits `loading`, so the reset happens before callers see the error.
    loading.catch(() => {
      if (this.pipelinePromise === loading) {
        this.pipelinePromise = null;
      }
    });
    return loading;
  }

  /**
   * Embeds a single text.
   *
   * @returns A mean-pooled, normalized vector, or `[]` if the pipeline produced no row.
   */
  async generate(text: string): Promise<number[]> {
    const [vector] = await this.generateBatch([text]);
    return vector ?? [];
  }

  /**
   * Embeds several texts in one pipeline call.
   *
   * @returns One mean-pooled, normalized vector per input, in input order.
   *   An empty input returns `[]` without loading the model.
   */
  async generateBatch(texts: string[]): Promise<number[][]> {
    if (texts.length === 0) {
      return [];
    }
    const pipeline = await this.getPipeline();
    const output = await pipeline(texts, { pooling: "mean", normalize: true });
    return output.tolist();
  }
}