
Add composite embedders and pooling for hf models #1104

Merged
7 changes: 7 additions & 0 deletions meilisearch/index.py
@@ -25,6 +25,7 @@
from meilisearch.errors import version_error_hint_message
from meilisearch.models.document import Document, DocumentsResults
from meilisearch.models.embedders import (
CompositeEmbedder,
Embedders,
EmbedderType,
HuggingFaceEmbedder,
@@ -977,6 +978,8 @@ def get_settings(self) -> Dict[str, Any]:
embedders[k] = HuggingFaceEmbedder(**v)
elif v.get("source") == "rest":
embedders[k] = RestEmbedder(**v)
elif v.get("source") == "composite":
embedders[k] = CompositeEmbedder(**v)
else:
embedders[k] = UserProvidedEmbedder(**v)

@@ -1934,6 +1937,8 @@ def get_embedders(self) -> Embedders | None:
embedders[k] = OllamaEmbedder(**v)
elif source == "rest":
embedders[k] = RestEmbedder(**v)
elif source == "composite":
embedders[k] = CompositeEmbedder(**v)
elif source == "userProvided":
embedders[k] = UserProvidedEmbedder(**v)
else:
@@ -1977,6 +1982,8 @@ def update_embedders(self, body: Union[MutableMapping[str, Any], None]) -> TaskI
embedders[k] = OllamaEmbedder(**v)
elif source == "rest":
embedders[k] = RestEmbedder(**v)
elif source == "composite":
embedders[k] = CompositeEmbedder(**v)
elif source == "userProvided":
embedders[k] = UserProvidedEmbedder(**v)
else:
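
For reference, a minimal sketch of how the new dispatch surfaces to callers: once an index has an embedder whose source is "composite", get_embedders() returns it as a typed CompositeEmbedder rather than a raw dict. The instance URL, master key, and index name below are placeholder assumptions.

from meilisearch import Client
from meilisearch.models.embedders import CompositeEmbedder

client = Client("http://localhost:7700", "masterKey")  # assumed local instance and key
index = client.index("movies")  # assumed index name

embedders = index.get_embedders()
if embedders is not None:
    for name, config in embedders.embedders.items():
        if isinstance(config, CompositeEmbedder):
            # The nested embedders are typed models too (e.g. HuggingFaceEmbedder), not raw dicts.
            print(name, type(config.search_embedder).__name__)
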
62 changes: 62 additions & 0 deletions meilisearch/models/embedders.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from enum import Enum
from typing import Any, Dict, Optional, Union

from camel_converter.pydantic_base import CamelBase
@@ -20,6 +21,24 @@ class Distribution(CamelBase):
sigma: float


class PoolingType(str, Enum):
"""Pooling strategies for HuggingFaceEmbedder.
Attributes
----------
USE_MODEL : str
Use the model's default pooling strategy.
FORCE_MEAN : str
Force mean pooling over the token embeddings.
FORCE_CLS : str
Use the [CLS] token embedding as the sentence representation.
"""

USE_MODEL = "useModel"
FORCE_MEAN = "forceMean"
FORCE_CLS = "forceCls"


class OpenAiEmbedder(CamelBase):
"""OpenAI embedder configuration.
@@ -79,6 +98,8 @@ class HuggingFaceEmbedder(CamelBase):
Describes the natural distribution of search results
binary_quantized: Optional[bool]
Once set to true, irreversibly converts all vector dimensions to 1-bit values
pooling: Optional[PoolingType]
Configures how individual tokens are merged into a single embedding
"""

source: str = "huggingFace"
@@ -90,6 +111,7 @@
document_template_max_bytes: Optional[int] = None # Default to 400
distribution: Optional[Distribution] = None
binary_quantized: Optional[bool] = None
pooling: Optional[PoolingType] = PoolingType.USE_MODEL
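
As a minimal usage sketch of the new pooling option, assuming HuggingFaceEmbedder also exposes model and document_template fields as in the existing configuration (the instance URL, index name, and model name are placeholders):

from meilisearch import Client
from meilisearch.models.embedders import HuggingFaceEmbedder, PoolingType

client = Client("http://localhost:7700", "masterKey")  # assumed local instance and key
index = client.index("movies")  # assumed index name

hf = HuggingFaceEmbedder(
    model="BAAI/bge-base-en-v1.5",  # placeholder model name
    document_template="A movie titled {{doc.title}}",
    pooling=PoolingType.FORCE_MEAN,  # override the model's default pooling
)
task = index.update_embedders({"huggingface": hf.model_dump(by_alias=True, exclude_none=True)})
index.wait_for_task(task.task_uid)
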


class OllamaEmbedder(CamelBase):
@@ -191,13 +213,53 @@ class UserProvidedEmbedder(CamelBase):
binary_quantized: Optional[bool] = None


class CompositeEmbedder(CamelBase):
"""Composite embedder configuration.
Parameters
----------
source: str
The embedder source, must be "composite"
indexing_embedder: Union[
OpenAiEmbedder,
HuggingFaceEmbedder,
OllamaEmbedder,
RestEmbedder,
UserProvidedEmbedder,
]
search_embedder: Union[
OpenAiEmbedder,
HuggingFaceEmbedder,
OllamaEmbedder,
RestEmbedder,
UserProvidedEmbedder,
]"""

source: str = "composite"
search_embedder: Union[
OpenAiEmbedder,
HuggingFaceEmbedder,
OllamaEmbedder,
RestEmbedder,
UserProvidedEmbedder,
]
indexing_embedder: Union[
OpenAiEmbedder,
HuggingFaceEmbedder,
OllamaEmbedder,
RestEmbedder,
UserProvidedEmbedder,
]


# Type alias for the embedder union type
EmbedderType = Union[
OpenAiEmbedder,
HuggingFaceEmbedder,
OllamaEmbedder,
RestEmbedder,
UserProvidedEmbedder,
CompositeEmbedder,
]
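
A minimal sketch of configuring a composite embedder with these models, mirroring the new test below; composite embedders must first be enabled as an experimental feature (see the tests/conftest.py fixture in this PR), and the instance URL and index name are placeholders:

from meilisearch import Client
from meilisearch.models.embedders import CompositeEmbedder, HuggingFaceEmbedder

client = Client("http://localhost:7700", "masterKey")  # assumed local instance and key
index = client.index("movies")  # assumed index name

composite = CompositeEmbedder(
    indexing_embedder=HuggingFaceEmbedder(),  # embeds documents at indexing time
    search_embedder=HuggingFaceEmbedder(),  # embeds queries at search time
)
task = index.update_embedders(
    {"composite": composite.model_dump(by_alias=True, exclude_none=True)}
)
index.wait_for_task(task.task_uid)
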


17 changes: 17 additions & 0 deletions tests/conftest.py
@@ -274,3 +274,20 @@ def new_embedders():
"default": UserProvidedEmbedder(dimensions=1).model_dump(by_alias=True),
"open_ai": OpenAiEmbedder().model_dump(by_alias=True),
}


@fixture
def enable_composite_embedders():
requests.patch(
f"{common.BASE_URL}/experimental-features",
headers={"Authorization": f"Bearer {common.MASTER_KEY}"},
json={"compositeEmbedders": True},
timeout=10,
)
yield
requests.patch(
f"{common.BASE_URL}/experimental-features",
headers={"Authorization": f"Bearer {common.MASTER_KEY}"},
json={"compositeEmbedders": False},
timeout=10,
)
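
Outside of pytest, the same experimental feature flag can be inspected and toggled with plain HTTP calls; a minimal sketch assuming a local instance and master key:

import requests

MEILI_URL = "http://localhost:7700"  # assumed local instance
MASTER_KEY = "masterKey"  # assumed master key
headers = {"Authorization": f"Bearer {MASTER_KEY}"}

# Inspect the current experimental-feature flags, then enable composite embedders.
print(requests.get(f"{MEILI_URL}/experimental-features", headers=headers, timeout=10).json())
requests.patch(
    f"{MEILI_URL}/experimental-features",
    headers=headers,
    json={"compositeEmbedders": True},
    timeout=10,
)
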
51 changes: 50 additions & 1 deletion tests/settings/test_settings_embedders.py
@@ -1,6 +1,14 @@
# pylint: disable=redefined-outer-name

from meilisearch.models.embedders import OpenAiEmbedder, UserProvidedEmbedder
import pytest

from meilisearch.models.embedders import (
CompositeEmbedder,
HuggingFaceEmbedder,
OpenAiEmbedder,
PoolingType,
UserProvidedEmbedder,
)


def test_get_default_embedders(empty_index):
@@ -97,6 +105,7 @@ def test_huggingface_embedder_format(empty_index):
assert embedders.embedders["huggingface"].distribution.mean == 0.5
assert embedders.embedders["huggingface"].distribution.sigma == 0.1
assert embedders.embedders["huggingface"].binary_quantized is False
assert embedders.embedders["huggingface"].pooling is PoolingType.USE_MODEL


def test_ollama_embedder_format(empty_index):
@@ -183,3 +192,43 @@ def test_user_provided_embedder_format(empty_index):
assert embedders.embedders["user_provided"].distribution.mean == 0.5
assert embedders.embedders["user_provided"].distribution.sigma == 0.1
assert embedders.embedders["user_provided"].binary_quantized is False


@pytest.mark.usefixtures("enable_composite_embedders")
def test_composite_embedder_format(empty_index):
"""Tests that CompositeEmbedder embedder has the required fields and proper format."""
index = empty_index()

embedder = HuggingFaceEmbedder().model_dump(by_alias=True, exclude_none=True)

# create composite embedder
composite_embedder = {
"composite": {
"source": "composite",
"searchEmbedder": embedder,
"indexingEmbedder": embedder,
}
}

response = index.update_embedders(composite_embedder)
update = index.wait_for_task(response.task_uid)
embedders = index.get_embedders()
assert update.status == "succeeded"

assert embedders.embedders["composite"].source == "composite"

# ensure deserialization returns typed embedder models
assert isinstance(embedders.embedders["composite"], CompositeEmbedder)
assert isinstance(embedders.embedders["composite"].search_embedder, HuggingFaceEmbedder)
assert isinstance(embedders.embedders["composite"].indexing_embedder, HuggingFaceEmbedder)

# ensure search_embedder has no document_template
assert getattr(embedders.embedders["composite"].search_embedder, "document_template") is None
assert (
getattr(
embedders.embedders["composite"].search_embedder,
"document_template_max_bytes",
)
is None
)
assert getattr(embedders.embedders["composite"].indexing_embedder, "document_template")