Skip to content

Commit c71b7fe

Browse files
committed
feat: add pooling to HuggingFaceEmbedder, add CompositeEmbedder
1 parent 1603f44 commit c71b7fe

File tree

2 files changed

+38
-1
lines changed

2 files changed

+38
-1
lines changed

meilisearch/index.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from meilisearch.errors import version_error_hint_message
2626
from meilisearch.models.document import Document, DocumentsResults
2727
from meilisearch.models.index import (
28+
CompositeEmbedder,
2829
Embedders,
2930
Faceting,
3031
HuggingFaceEmbedder,
@@ -961,6 +962,7 @@ def get_settings(self) -> Dict[str, Any]:
961962
| HuggingFaceEmbedder
962963
| OllamaEmbedder
963964
| RestEmbedder
965+
| CompositeEmbedder
964966
| UserProvidedEmbedder,
965967
] = {}
966968
for k, v in settings["embedders"].items():
@@ -972,6 +974,8 @@ def get_settings(self) -> Dict[str, Any]:
972974
embedders[k] = HuggingFaceEmbedder(**v)
973975
elif v.get("source") == "rest":
974976
embedders[k] = RestEmbedder(**v)
977+
elif v.get("source") == "composite":
978+
embedders[k] = CompositeEmbedder(**v)
975979
else:
976980
embedders[k] = UserProvidedEmbedder(**v)
977981

@@ -1900,6 +1904,7 @@ def get_embedders(self) -> Embedders | None:
19001904
| HuggingFaceEmbedder
19011905
| OllamaEmbedder
19021906
| RestEmbedder
1907+
| CompositeEmbedder
19031908
| UserProvidedEmbedder,
19041909
] = {}
19051910
for k, v in response.items():
@@ -1911,6 +1916,8 @@ def get_embedders(self) -> Embedders | None:
19111916
embedders[k] = HuggingFaceEmbedder(**v)
19121917
elif v.get("source") == "rest":
19131918
embedders[k] = RestEmbedder(**v)
1919+
elif v.get("source") == "composite":
1920+
embedders[k] = CompositeEmbedder(**v)
19141921
else:
19151922
embedders[k] = UserProvidedEmbedder(**v)
19161923

meilisearch/models/index.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ class LocalizedAttributes(CamelBase):
6464
locales: List[str]
6565

6666

67+
class PoolingOpt(str, Enum):
68+
USE_MODEL = "useModel"
69+
FORCE_MEAN = "forceMean"
70+
FORCE_CLS = "forceCls"
71+
72+
6773
class OpenAiEmbedder(CamelBase):
6874
source: str = "openAi"
6975
url: Optional[str] = None
@@ -84,6 +90,7 @@ class HuggingFaceEmbedder(CamelBase):
8490
document_template_max_bytes: Optional[int] = None # Default to 400
8591
distribution: Optional[EmbedderDistribution] = None
8692
binary_quantized: Optional[bool] = None
93+
pooling: Optional[PoolingOpt] = None
8794

8895

8996
class OllamaEmbedder(CamelBase):
@@ -117,10 +124,33 @@ class UserProvidedEmbedder(CamelBase):
117124
binary_quantized: Optional[bool] = None
118125

119126

127+
class CompositeEmbedder(CamelBase):
128+
source: str = "composite"
129+
search_embedder: Union[
130+
OpenAiEmbedder,
131+
HuggingFaceEmbedder,
132+
OllamaEmbedder,
133+
RestEmbedder,
134+
UserProvidedEmbedder
135+
]
136+
indexing_embedder: Union[
137+
OpenAiEmbedder,
138+
HuggingFaceEmbedder,
139+
OllamaEmbedder,
140+
RestEmbedder,
141+
UserProvidedEmbedder
142+
]
143+
144+
120145
class Embedders(CamelBase):
121146
embedders: Dict[
122147
str,
123148
Union[
124-
OpenAiEmbedder, HuggingFaceEmbedder, OllamaEmbedder, RestEmbedder, UserProvidedEmbedder
149+
OpenAiEmbedder,
150+
HuggingFaceEmbedder,
151+
OllamaEmbedder,
152+
RestEmbedder,
153+
UserProvidedEmbedder,
154+
CompositeEmbedder,
125155
],
126156
]

0 commit comments

Comments
 (0)