about summary refs log tree commit diff
path: root/packages/sdk/src/local-embedding-provider.test.ts
blob: 2674158a98554daf65b7ef5c9a2f781e8dd2a5d9 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import { beforeAll, describe, expect, it } from "vitest";
import type { EmbeddingProvider } from "./embedding-provider.js";
import { LocalEmbeddingProvider } from "./local-embedding-provider.js";

/**
 * Tests for {@link LocalEmbeddingProvider}.
 *
 * The `dimensions` suite only constructs providers. The `generate` and
 * `generateBatch` suites call the provider for real embeddings.
 * NOTE(review): if the provider loads or downloads the model on first use,
 * those suites may need a longer per-test timeout than vitest's default —
 * confirm against the provider implementation.
 */
describe("LocalEmbeddingProvider", () => {
	describe("interface compliance", () => {
		it("implements EmbeddingProvider interface", () => {
			// The explicit `EmbeddingProvider` annotation is the compile-time
			// compliance check; the runtime assertions confirm each member has
			// the expected kind rather than merely being non-undefined.
			const provider: EmbeddingProvider = new LocalEmbeddingProvider();

			expect(typeof provider.generate).toBe("function");
			expect(typeof provider.generateBatch).toBe("function");
			expect(typeof provider.dimensions).toBe("number");
		});
	});
	describe("dimensions", () => {
		it("returns 384 for MiniLM model", () => {
			const provider = new LocalEmbeddingProvider({
				model: "Xenova/all-MiniLM-L6-v2",
			});

			expect(provider.dimensions).toBe(384);
		});
		it("returns 384 for bge-small model", () => {
			const provider = new LocalEmbeddingProvider({
				model: "Xenova/bge-small-en-v1.5",
			});

			expect(provider.dimensions).toBe(384);
		});
		it("returns 768 for bge-base model", () => {
			const provider = new LocalEmbeddingProvider({
				model: "Xenova/bge-base-en-v1.5",
			});

			expect(provider.dimensions).toBe(768);
		});
		it("defaults to MiniLM with 384 dimensions", () => {
			const provider = new LocalEmbeddingProvider();

			expect(provider.dimensions).toBe(384);
		});
	});
	describe("generate", () => {
		let provider: LocalEmbeddingProvider;

		// One provider instance is shared by every test in this suite.
		beforeAll(() => {
			provider = new LocalEmbeddingProvider({
				model: "Xenova/all-MiniLM-L6-v2",
			});
		});
		it("returns a non-empty array of finite numbers", async () => {
			const embedding = await provider.generate("Hello world");

			expect(Array.isArray(embedding)).toBe(true);
			// `every` is vacuously true on an empty array, so rule that out first.
			expect(embedding.length).toBeGreaterThan(0);
			// `typeof NaN === "number"`, so check finiteness instead of type alone.
			expect(embedding.every(Number.isFinite)).toBe(true);
		});
		it("returns embedding with correct dimensions", async () => {
			const embedding = await provider.generate("Hello world");

			expect(embedding).toHaveLength(provider.dimensions);
		});
		it("returns normalized embedding vectors", async () => {
			const embedding = await provider.generate("Test normalization");
			// Euclidean (L2) norm of the vector; a unit vector has norm 1.
			const magnitude = Math.sqrt(
				embedding.reduce(
					(accumulatedSquareSum, component) =>
						accumulatedSquareSum + component * component,
					0,
				),
			);

			expect(magnitude).toBeCloseTo(1, 4);
		});
	});
	describe("generateBatch", () => {
		let provider: LocalEmbeddingProvider;

		// One provider instance is shared by every test in this suite.
		beforeAll(() => {
			provider = new LocalEmbeddingProvider({
				model: "Xenova/all-MiniLM-L6-v2",
			});
		});
		it("returns correct number of embeddings", async () => {
			const texts = ["Hello", "World", "Test"];
			const embeddings = await provider.generateBatch(texts);

			expect(embeddings).toHaveLength(texts.length);
		});
		it("returns embeddings with correct dimensions for each text", async () => {
			const texts = ["Hello", "World"];
			const embeddings = await provider.generateBatch(texts);

			// Without this guard the loop below would pass vacuously if the
			// provider returned an empty array.
			expect(embeddings).toHaveLength(texts.length);
			for (const embedding of embeddings) {
				expect(embedding).toHaveLength(provider.dimensions);
			}
		});
		it("throws error for empty input", async () => {
			await expect(provider.generateBatch([])).rejects.toThrow(
				"text array must be non-empty",
			);
		});
	});
});