@@ -74,15 +74,61 @@ def test_vector_search_initialization(mock_sentence_transformer, mock_qdrant_cli
74
74
model_config = {"device" : "cpu" },
75
75
)
76
76
77
- # Check that the model was loaded with some parameters
78
- mock_sentence_transformer .assert_called_once ()
77
+ # Check that the model was loaded with the correct parameters
78
+ # SentenceTransformer should only be called with model name and device
79
+ mock_sentence_transformer .assert_called_once_with ("test_model" , device = "cpu" )
80
+
81
+ # Check that normalize_embeddings was properly extracted from model_config
82
+ assert hasattr (vs , "normalize_embeddings" )
83
+ assert vs .normalize_embeddings is True # Default value
79
84
80
85
# Check that the client was created
81
86
mock_qdrant_client .assert_called_once_with (host = "localhost" , port = 6333 )
82
87
83
88
# Check that the collection was initialized
84
89
vs .client .get_collections .assert_called_once ()
85
90
vs .client .create_collection .assert_called_once ()
91
+
92
+ def test_vector_search_initialization_with_custom_settings (mock_sentence_transformer , mock_qdrant_client ):
93
+ """Test VectorSearch initialization with custom settings"""
94
+ # Create a VectorSearch instance with custom model_config
95
+ vs = VectorSearch (
96
+ host = "localhost" ,
97
+ port = 6333 ,
98
+ embedding_model = "test_model" ,
99
+ model_config = {
100
+ "device" : "cuda:0" ,
101
+ "cache_folder" : "/tmp/cache" ,
102
+ "normalize_embeddings" : False ,
103
+ "prompt_template" : "Code: {text}" ,
104
+ "invalid_param" : "should_be_ignored"
105
+ },
106
+ )
107
+
108
+ # Check that the model was loaded with ONLY the valid parameters
109
+ # Only model_name, device, and cache_folder should be passed to the constructor
110
+ mock_sentence_transformer .assert_called_once_with (
111
+ "test_model" ,
112
+ device = "cuda:0" ,
113
+ cache_folder = "/tmp/cache"
114
+ )
115
+
116
+ # Check that normalize_embeddings was properly extracted from model_config
117
+ assert vs .normalize_embeddings is False
118
+
119
+ # Test with normalize_embeddings explicitly included in model_config
120
+ mock_sentence_transformer .reset_mock ()
121
+
122
+ vs2 = VectorSearch (
123
+ host = "localhost" ,
124
+ port = 6333 ,
125
+ embedding_model = "test_model" ,
126
+ model_config = {"normalize_embeddings" : False },
127
+ )
128
+
129
+ # normalize_embeddings should NOT be passed to the constructor
130
+ mock_sentence_transformer .assert_called_once_with ("test_model" , device = None )
131
+ assert vs2 .normalize_embeddings is False
86
132
87
133
88
134
def test_generate_embedding (mock_sentence_transformer , mock_qdrant_client ):
@@ -98,17 +144,32 @@ def test_generate_embedding(mock_sentence_transformer, mock_qdrant_client):
98
144
# Generate an embedding
99
145
embedding = vs ._generate_embedding ("test text" )
100
146
101
- # Check that the prompt template was applied
147
+ # Check that the prompt template was applied and normalize_embeddings was correctly passed
102
148
vs .model .encode .assert_called_once_with (
103
149
"query: test text" ,
104
150
batch_size = 32 ,
105
- normalize_embeddings = True ,
151
+ normalize_embeddings = True , # This should match the value in model_config
106
152
convert_to_tensor = False ,
107
153
show_progress_bar = False ,
108
154
)
109
155
110
156
# Check that the embedding was converted to a list
111
157
assert isinstance (embedding , list )
158
+
159
+ # Test with normalize_embeddings set to False
160
+ vs .model .encode .reset_mock ()
161
+ vs .normalize_embeddings = False
162
+
163
+ embedding = vs ._generate_embedding ("test text" )
164
+
165
+ # Check that normalize_embeddings=False was passed to encode
166
+ vs .model .encode .assert_called_once_with (
167
+ "query: test text" ,
168
+ batch_size = 32 ,
169
+ normalize_embeddings = False , # Should use instance variable
170
+ convert_to_tensor = False ,
171
+ show_progress_bar = False ,
172
+ )
112
173
113
174
114
175
def test_index_file (mock_sentence_transformer , mock_qdrant_client ):
0 commit comments