Skip to content

Commit a3709ba

Browse files
Adding unit tests for dspy.retrievers.Embeddings (#8129)
* Adding unit tests for dspy.retrievers.Embeddings

  Added three unit tests for the embeddings retriever:
  1) test_embeddings_basic_search: Verifies that the retriever returns the correct top-k relevant passages and their indices for a single query.
  2) test_embeddings_forward_batch: Ensures the retriever handles batch queries correctly, returning the top-k relevant passages and indices for each query.
  3) test_normalization: Confirms that the embeddings are correctly normalized to have a norm close to 1 after processing.

* Updating test file

  1) Removed tests that were calling private functions.
  2) Added a new test to check robustness under high concurrency.
  3) Updated the dummy embedder so that similar texts map to nearby vectors and dissimilar texts map to distant ones.

* Style fix

---------

Co-authored-by: chenmoneygithub <[email protected]>
1 parent 9272605 commit a3709ba

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

tests/retrievers/test_embeddings.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
from concurrent.futures import ThreadPoolExecutor, as_completed
2+
import numpy as np
3+
import pytest
4+
from dspy.retrievers.embeddings import Embeddings
5+
6+
7+
def dummy_corpus():
    """Return a small, fixed corpus used by the retriever tests.

    The three sentences mention a cat, a dog and birds respectively, so each
    one falls into a different branch of ``dummy_embedder``. A new list is
    built on every call, so callers may mutate the result freely.
    """
    return [
        "The cat sat on the mat.",
        "The dog barked at the mailman.",
        "Birds fly in the sky.",
    ]
13+
14+
15+
def dummy_embedder(texts):
    """Map each text to a one-hot 3-d vector based on a keyword it contains.

    Texts containing ``"cat"`` map to ``[1, 0, 0]``, texts containing
    ``"dog"`` (and not ``"cat"``) map to ``[0, 1, 0]``, and everything else
    maps to ``[0, 0, 1]``. Texts sharing a keyword therefore get identical
    vectors, while texts with different keywords get orthogonal ones.

    Args:
        texts: Iterable of strings to embed.

    Returns:
        A ``float32`` array of shape ``(len(texts), 3)``. An empty input
        yields an empty ``(0, 3)`` array instead of raising.
    """
    # Materialize once so that generators are supported and len() is available.
    texts = list(texts)
    # Preallocating (rather than stacking per-text rows) keeps the empty-batch
    # case well defined: np.stack([]) would raise ValueError.
    embeddings = np.zeros((len(texts), 3), dtype=np.float32)
    for row, text in enumerate(texts):
        if "cat" in text:
            axis = 0
        elif "dog" in text:
            axis = 1
        else:
            axis = 2
        embeddings[row, axis] = 1.0
    return embeddings
25+
26+
27+
def test_embeddings_basic_search():
    """A single query returns exactly one passage: the one sharing its keyword."""
    corpus = dummy_corpus()
    embedder = dummy_embedder

    retriever = Embeddings(corpus=corpus, embedder=embedder, k=1)

    query = "I saw a dog running."
    result = retriever(query)

    # The result must expose both the retrieved texts and their indices.
    assert hasattr(result, "passages")
    assert hasattr(result, "indices")

    assert isinstance(result.passages, list)
    assert isinstance(result.indices, list)

    # k=1, so exactly one hit is expected in each list.
    assert len(result.passages) == 1
    assert len(result.indices) == 1

    # The query contains "dog", so dummy_embedder gives it the same vector as
    # the dog sentence in the corpus.
    assert result.passages[0] == "The dog barked at the mailman."
46+
47+
48+
def test_embeddings_multithreaded_search():
    """Queries issued from ten worker threads against one retriever all return the expected passage."""
    corpus = dummy_corpus()
    embedder = dummy_embedder

    retriever = Embeddings(corpus=corpus, embedder=embedder, k=1)

    # (query, expected top passage) pairs, repeated to give 30 tasks for the
    # 10-worker pool.
    queries = [
        ("A cat is sitting on the mat.", "The cat sat on the mat."),
        ("My dog is awesome!", "The dog barked at the mailman."),
        ("Birds flying high.", "Birds fly in the sky."),
    ] * 10

    def worker(query_text, expected_passage):
        # An assertion failure here is re-raised by Future.result() below.
        result = retriever(query_text)
        assert result.passages[0] == expected_passage
        return result.passages[0]

    with ThreadPoolExecutor(max_workers=10) as executor:
        futures = [executor.submit(worker, q, expected) for q, expected in queries]
        # Iterating the futures in submission order keeps results aligned
        # with `queries`, regardless of completion order.
        results = [f.result() for f in futures]

    # Check every result, not only the first three.
    assert results == [expected for _, expected in queries]

0 commit comments

Comments
 (0)