Skip to content

Commit 6ddaf6a

Browse files
committed
test: improve VectorSearch tests to validate SentenceTransformer parameters
- Add tests that verify only valid parameters are passed to SentenceTransformer constructor
- Add test for custom settings scenario with invalid parameters
- Add test for normalize_embeddings parameter behavior
- Ensure tests would catch future compatibility issues with the SentenceTransformer library
1 parent b61cfff commit 6ddaf6a

File tree

1 file changed

+65
-4
lines changed

1 file changed

+65
-4
lines changed

tests/unit/test_vector_search.py

Lines changed: 65 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,61 @@ def test_vector_search_initialization(mock_sentence_transformer, mock_qdrant_cli
7474
model_config={"device": "cpu"},
7575
)
7676

77-
# Check that the model was loaded with some parameters
78-
mock_sentence_transformer.assert_called_once()
77+
# Check that the model was loaded with the correct parameters
78+
# SentenceTransformer should only be called with model name and device
79+
mock_sentence_transformer.assert_called_once_with("test_model", device="cpu")
80+
81+
# Check that normalize_embeddings was properly extracted from model_config
82+
assert hasattr(vs, "normalize_embeddings")
83+
assert vs.normalize_embeddings is True # Default value
7984

8085
# Check that the client was created
8186
mock_qdrant_client.assert_called_once_with(host="localhost", port=6333)
8287

8388
# Check that the collection was initialized
8489
vs.client.get_collections.assert_called_once()
8590
vs.client.create_collection.assert_called_once()
91+
92+
def test_vector_search_initialization_with_custom_settings(mock_sentence_transformer, mock_qdrant_client):
93+
"""Test VectorSearch initialization with custom settings"""
94+
# Create a VectorSearch instance with custom model_config
95+
vs = VectorSearch(
96+
host="localhost",
97+
port=6333,
98+
embedding_model="test_model",
99+
model_config={
100+
"device": "cuda:0",
101+
"cache_folder": "/tmp/cache",
102+
"normalize_embeddings": False,
103+
"prompt_template": "Code: {text}",
104+
"invalid_param": "should_be_ignored"
105+
},
106+
)
107+
108+
# Check that the model was loaded with ONLY the valid parameters
109+
# Only model_name, device, and cache_folder should be passed to the constructor
110+
mock_sentence_transformer.assert_called_once_with(
111+
"test_model",
112+
device="cuda:0",
113+
cache_folder="/tmp/cache"
114+
)
115+
116+
# Check that normalize_embeddings was properly extracted from model_config
117+
assert vs.normalize_embeddings is False
118+
119+
# Test with normalize_embeddings explicitly included in model_config
120+
mock_sentence_transformer.reset_mock()
121+
122+
vs2 = VectorSearch(
123+
host="localhost",
124+
port=6333,
125+
embedding_model="test_model",
126+
model_config={"normalize_embeddings": False},
127+
)
128+
129+
# normalize_embeddings should NOT be passed to the constructor
130+
mock_sentence_transformer.assert_called_once_with("test_model", device=None)
131+
assert vs2.normalize_embeddings is False
86132

87133

88134
def test_generate_embedding(mock_sentence_transformer, mock_qdrant_client):
@@ -98,17 +144,32 @@ def test_generate_embedding(mock_sentence_transformer, mock_qdrant_client):
98144
# Generate an embedding
99145
embedding = vs._generate_embedding("test text")
100146

101-
# Check that the prompt template was applied
147+
# Check that the prompt template was applied and normalize_embeddings was correctly passed
102148
vs.model.encode.assert_called_once_with(
103149
"query: test text",
104150
batch_size=32,
105-
normalize_embeddings=True,
151+
normalize_embeddings=True, # This should match the value in model_config
106152
convert_to_tensor=False,
107153
show_progress_bar=False,
108154
)
109155

110156
# Check that the embedding was converted to a list
111157
assert isinstance(embedding, list)
158+
159+
# Test with normalize_embeddings set to False
160+
vs.model.encode.reset_mock()
161+
vs.normalize_embeddings = False
162+
163+
embedding = vs._generate_embedding("test text")
164+
165+
# Check that normalize_embeddings=False was passed to encode
166+
vs.model.encode.assert_called_once_with(
167+
"query: test text",
168+
batch_size=32,
169+
normalize_embeddings=False, # Should use instance variable
170+
convert_to_tensor=False,
171+
show_progress_bar=False,
172+
)
112173

113174

114175
def test_index_file(mock_sentence_transformer, mock_qdrant_client):

0 commit comments

Comments
 (0)