Skip to content

Commit b61cfff

Browse files
committed
fix: resolve normalize_embeddings parameter compatibility issue
- Fix compatibility with the SentenceTransformer library by passing the normalize_embeddings parameter correctly.
- Move normalize_embeddings from the constructor parameters to the encode() method call.
- Keep the normalize_embeddings setting as an instance variable for consistent encoding behavior.
1 parent 23fe202 commit b61cfff

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

src/vector_search.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ def __init__(
3636
self.binary_embeddings = binary_embeddings
3737
self.collection_name = collection_name
3838
self.model_config = model_config or {}
39+
40+
# Store normalize_embeddings setting to use during encoding
41+
# Default to True if not specified in model_config
42+
self.normalize_embeddings = self.model_config.get("normalize_embeddings", True)
3943

4044
# Connect to Qdrant
4145
self.client = QdrantClient(host=host, port=port)
@@ -61,10 +65,12 @@ def _load_embedding_model(self, model_name: str) -> SentenceTransformer:
6165
# Apply model configuration if provided
6266
device = self.model_config.get("device", None) # Get device from config or use default
6367

64-
# Filter out invalid parameters for SentenceTransformer
65-
valid_params = {k: v for k, v in self.model_config.items()
66-
if k not in ["device", "quantization", "binary_embeddings"]}
67-
68+
# Extract valid parameters for SentenceTransformer constructor
69+
# Only 'device' and 'cache_folder' are valid for the constructor
70+
valid_params = {}
71+
if 'cache_folder' in self.model_config:
72+
valid_params['cache_folder'] = self.model_config['cache_folder']
73+
6874
# Load model with appropriate configuration
6975
return SentenceTransformer(
7076
model_name,
@@ -130,12 +136,11 @@ def _generate_embedding(self, text: str, batch_size: int = 32) -> List[float]:
130136
if prompt_template:
131137
text = prompt_template.format(text=text)
132138

133-
# Use appropriate encoding parameters based on model configuration
134-
normalize = self.model_config.get("normalize_embeddings", True)
139+
# Use the normalize_embeddings setting we saved during model initialization
135140
embedding = self.model.encode(
136141
text,
137142
batch_size=batch_size,
138-
normalize_embeddings=normalize,
143+
normalize_embeddings=self.normalize_embeddings,
139144
convert_to_tensor=False,
140145
show_progress_bar=False,
141146
)

0 commit comments

Comments (0)