diff options
| -rw-r--r-- | packages/sdk/src/local-embedding-provider.test.ts | 105 |
1 files changed, 105 insertions, 0 deletions
// packages/sdk/src/local-embedding-provider.test.ts
import { describe, it, expect, beforeAll } from "vitest";
import { LocalEmbeddingProvider } from "./local-embedding-provider.js";
import type { EmbeddingProvider } from "./embedding-provider.js";

/**
 * Test suite for LocalEmbeddingProvider.
 *
 * Covers four areas: assignability to the EmbeddingProvider interface, the
 * `dimensions` value reported for each model id, the output of `generate`,
 * and the output and empty-input rejection of `generateBatch`.
 */
describe("LocalEmbeddingProvider", () => {
  // Model id shared by the MiniLM dimension check and both embedding suites.
  const MINILM_MODEL = "Xenova/all-MiniLM-L6-v2";

  describe("interface compliance", () => {
    it("implements EmbeddingProvider interface", () => {
      // The explicit annotation makes the compiler verify assignability.
      const candidate: EmbeddingProvider = new LocalEmbeddingProvider();
      const requiredMembers = ["generate", "generateBatch", "dimensions"] as const;

      for (const member of requiredMembers) {
        expect(candidate[member]).toBeDefined();
      }
    });
  });

  describe("dimensions", () => {
    it("returns 384 for MiniLM model", () => {
      const { dimensions } = new LocalEmbeddingProvider({ model: MINILM_MODEL });

      expect(dimensions).toBe(384);
    });

    it("returns 384 for bge-small model", () => {
      const { dimensions } = new LocalEmbeddingProvider({
        model: "Xenova/bge-small-en-v1.5",
      });

      expect(dimensions).toBe(384);
    });

    it("returns 768 for bge-base model", () => {
      const { dimensions } = new LocalEmbeddingProvider({
        model: "Xenova/bge-base-en-v1.5",
      });

      expect(dimensions).toBe(768);
    });

    it("defaults to MiniLM with 384 dimensions", () => {
      const { dimensions } = new LocalEmbeddingProvider();

      expect(dimensions).toBe(384);
    });
  });

  describe("generate", () => {
    let embedder: LocalEmbeddingProvider;

    // One provider instance is reused by every test in this group.
    beforeAll(() => {
      embedder = new LocalEmbeddingProvider({ model: MINILM_MODEL });
    });

    it("returns an array of numbers", async () => {
      const vector = await embedder.generate("Hello world");

      expect(Array.isArray(vector)).toBe(true);

      const allNumeric = vector.every((value) => typeof value === "number");
      expect(allNumeric).toBe(true);
    });

    it("returns embedding with correct dimensions", async () => {
      const vector = await embedder.generate("Hello world");

      expect(vector).toHaveLength(embedder.dimensions);
    });

    it("returns normalized embedding vectors", async () => {
      const vector = await embedder.generate("Test normalization");

      // Euclidean norm: accumulate squares left to right, then take the root.
      let sumOfSquares = 0;
      for (const component of vector) {
        sumOfSquares += component * component;
      }

      expect(Math.sqrt(sumOfSquares)).toBeCloseTo(1, 4);
    });
  });

  describe("generateBatch", () => {
    let embedder: LocalEmbeddingProvider;

    // One provider instance is reused by every test in this group.
    beforeAll(() => {
      embedder = new LocalEmbeddingProvider({ model: MINILM_MODEL });
    });

    it("returns correct number of embeddings", async () => {
      const inputs = ["Hello", "World", "Test"];
      const vectors = await embedder.generateBatch(inputs);

      expect(vectors).toHaveLength(inputs.length);
    });

    it("returns embeddings with correct dimensions for each text", async () => {
      const vectors = await embedder.generateBatch(["Hello", "World"]);

      vectors.forEach((vector) => {
        expect(vector).toHaveLength(embedder.dimensions);
      });
    });

    it("throws error for empty input", async () => {
      const pending = embedder.generateBatch([]);

      await expect(pending).rejects.toThrow("text array must be non-empty");
    });
  });
});