Skip to content

Commit 71ca330

Browse files
committed
address PR feedback:
- Strategy suffix - Sphinx docstrings - add user agent to EmbeddingService - raise ConflictError - various cleanup
1 parent 6aa6d73 commit 71ca330

File tree

15 files changed

+322
-325
lines changed

15 files changed

+322
-325
lines changed

elasticsearch/helpers/vectorstore/__init__.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,43 +20,43 @@
2020
AsyncEmbeddingService,
2121
)
2222
from elasticsearch.helpers.vectorstore._async.strategies import (
23-
AsyncBM25,
24-
AsyncDenseVector,
25-
AsyncDenseVectorScriptScore,
23+
AsyncBM25Strategy,
24+
AsyncDenseVectorScriptScoreStrategy,
25+
AsyncDenseVectorStrategy,
2626
AsyncRetrievalStrategy,
27-
AsyncSparseVector,
27+
AsyncSparseVectorStrategy,
2828
)
2929
from elasticsearch.helpers.vectorstore._async.vectorstore import AsyncVectorStore
3030
from elasticsearch.helpers.vectorstore._sync.embedding_service import (
3131
ElasticsearchEmbeddings,
3232
EmbeddingService,
3333
)
3434
from elasticsearch.helpers.vectorstore._sync.strategies import (
35-
BM25,
36-
DenseVector,
37-
DenseVectorScriptScore,
35+
BM25Strategy,
36+
DenseVectorScriptScoreStrategy,
37+
DenseVectorStrategy,
3838
RetrievalStrategy,
39-
SparseVector,
39+
SparseVectorStrategy,
4040
)
4141
from elasticsearch.helpers.vectorstore._sync.vectorstore import VectorStore
4242
from elasticsearch.helpers.vectorstore._utils import DistanceMetric
4343

4444
__all__ = [
45-
"BM25",
46-
"DenseVector",
47-
"DenseVectorScriptScore",
45+
"BM25Strategy",
46+
"DenseVectorStrategy",
47+
"DenseVectorScriptScoreStrategy",
4848
"ElasticsearchEmbeddings",
4949
"EmbeddingService",
5050
"RetrievalStrategy",
51-
"SparseVector",
51+
"SparseVectorStrategy",
5252
"VectorStore",
53-
"AsyncBM25",
54-
"AsyncDenseVector",
55-
"AsyncDenseVectorScriptScore",
53+
"AsyncBM25Strategy",
54+
"AsyncDenseVectorStrategy",
55+
"AsyncDenseVectorScriptScoreStrategy",
5656
"AsyncElasticsearchEmbeddings",
5757
"AsyncEmbeddingService",
5858
"AsyncRetrievalStrategy",
59-
"AsyncSparseVector",
59+
"AsyncSparseVectorStrategy",
6060
"AsyncVectorStore",
6161
"DistanceMetric",
6262
]

elasticsearch/helpers/vectorstore/_async/_utils.py

Lines changed: 8 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,21 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18-
from elasticsearch import (
19-
AsyncElasticsearch,
20-
BadRequestError,
21-
ConflictError,
22-
NotFoundError,
23-
)
18+
from elasticsearch import AsyncElasticsearch, BadRequestError, NotFoundError
2419

2520

2621
async def model_must_be_deployed(client: AsyncElasticsearch, model_id: str) -> None:
22+
"""
23+
:raises [NotFoundError]: if the model is neither downloaded nor deployed.
24+
:raises [ConflictError]: if the model is downloaded but not yet deployed.
25+
"""
26+
doc = {"text_field": f"test if the model '{model_id}' is deployed"}
2727
try:
28-
dummy = {"x": "y"}
29-
await client.ml.infer_trained_model(model_id=model_id, docs=[dummy])
30-
except NotFoundError as err:
31-
raise err
32-
except ConflictError as err:
33-
raise NotFoundError(
34-
f"model '{model_id}' not found, please deploy it first",
35-
meta=err.meta,
36-
body=err.body,
37-
) from err
28+
await client.ml.infer_trained_model(model_id=model_id, docs=[doc])
3829
except BadRequestError:
39-
# This error is expected because we do not know the expected document
40-
# shape and just use a dummy doc above.
30+
# The model is deployed but expects a different input field name.
4131
pass
4232

43-
return None
44-
4533

4634
async def model_is_deployed(es_client: AsyncElasticsearch, model_id: str) -> bool:
4735
try:

elasticsearch/helpers/vectorstore/_async/embedding_service.py

Lines changed: 17 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,32 +16,29 @@
1616
# under the License.
1717

1818
from abc import ABC, abstractmethod
19-
from typing import List, Optional
19+
from typing import List
2020

2121
from elasticsearch import AsyncElasticsearch
22+
from elasticsearch._version import __versionstr__ as lib_version
2223

2324

2425
class AsyncEmbeddingService(ABC):
2526
@abstractmethod
2627
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
2728
"""Generate embeddings for a list of documents.
2829
29-
Args:
30-
texts: A list of document strings to generate embeddings for.
30+
:param texts: A list of document strings to generate embeddings for.
3131
32-
Returns:
33-
A list of embeddings, one for each document in the input.
32+
:return: A list of embeddings, one for each document in the input.
3433
"""
3534

3635
@abstractmethod
3736
async def embed_query(self, query: str) -> List[float]:
3837
"""Generate an embedding for a single query text.
3938
40-
Args:
41-
text: The query text to generate an embedding for.
39+
:param text: The query text to generate an embedding for.
4240
43-
Returns:
44-
The embedding for the input query text.
41+
:return: The embedding for the input query text.
4542
"""
4643

4744

@@ -56,31 +53,26 @@ class AsyncElasticsearchEmbeddings(AsyncEmbeddingService):
5653
def __init__(
5754
self,
5855
es_client: AsyncElasticsearch,
59-
user_agent: str,
6056
model_id: str,
6157
input_field: str = "text_field",
62-
num_dimensions: Optional[int] = None,
58+
user_agent: str = f"elasticsearch-py-es/{lib_version}",
6359
):
6460
"""
65-
Args:
66-
agent_header: user agent header specific to the 3rd party integration.
67-
Used for usage tracking in Elastic Cloud.
68-
model_id: The model_id of the model deployed in the Elasticsearch cluster.
69-
input_field: The name of the key for the input text field in the
70-
document. Defaults to 'text_field'.
71-
num_dimensions: The number of embedding dimensions. If None, then dimensions
72-
will be infer from an example inference call.
73-
es_client: Elasticsearch client connection. Alternatively specify the
74-
Elasticsearch connection with the other es_* parameters.
61+
:param agent_header: user agent header specific to the 3rd party integration.
62+
Used for usage tracking in Elastic Cloud.
63+
:param model_id: The model_id of the model deployed in the Elasticsearch cluster.
64+
:param input_field: The name of the key for the input text field in the
65+
document. Defaults to 'text_field'.
66+
:param es_client: Elasticsearch client connection. Alternatively specify the
67+
Elasticsearch connection with the other es_* parameters.
7568
"""
7669
# Add integration-specific usage header for tracking usage in Elastic Cloud.
77-
# client.options preserces existing (non-user-agent) headers.
70+
# client.options preserves existing (non-user-agent) headers.
7871
es_client = es_client.options(headers={"User-Agent": user_agent})
7972

80-
self.client = es_client.ml
73+
self.es_client = es_client
8174
self.model_id = model_id
8275
self.input_field = input_field
83-
self._num_dimensions = num_dimensions
8476

8577
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
8678
result = await self._embedding_func(texts)
@@ -91,7 +83,7 @@ async def embed_query(self, text: str) -> List[float]:
9183
return result[0]
9284

9385
async def _embedding_func(self, texts: List[str]) -> List[List[float]]:
94-
response = await self.client.infer_trained_model(
86+
response = await self.es_client.ml.infer_trained_model(
9587
model_id=self.model_id, docs=[{self.input_field: text} for text in texts]
9688
)
9789
return [doc["predicted_value"] for doc in response["inference_results"]]

elasticsearch/helpers/vectorstore/_async/strategies.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,15 +39,13 @@ def es_query(
3939
Returns the Elasticsearch query body for the given parameters.
4040
The store will execute the query.
4141
42-
Args:
43-
query: The text query. Can be None if query_vector is given.
44-
k: The total number of results to retrieve.
45-
num_candidates: The number of results to fetch initially in knn search.
46-
filter: List of filter clauses to apply to the query.
47-
query_vector: The query vector. Can be None if a query string is given.
48-
49-
Returns:
50-
Dict: The Elasticsearch query body.
42+
:param query: The text query. Can be None if query_vector is given.
43+
:param k: The total number of results to retrieve.
44+
:param num_candidates: The number of results to fetch initially in knn search.
45+
:param filter: List of filter clauses to apply to the query.
46+
:param query_vector: The query vector. Can be None if a query string is given.
47+
48+
:return: The Elasticsearch query body.
5149
"""
5250

5351
@abstractmethod
@@ -61,11 +59,10 @@ def es_mappings_settings(
6159
Create the required index and do necessary preliminary work, like
6260
creating inference pipelines or checking if a required model was deployed.
6361
64-
Args:
65-
client: Elasticsearch client connection.
66-
index_name: The name of the Elasticsearch index to create.
67-
metadata_mapping: Flat dictionary with field and field type pairs that
68-
describe the schema of the metadata.
62+
:param client: Elasticsearch client connection.
63+
:param index_name: The name of the Elasticsearch index to create.
64+
:param metadata_mapping: Flat dictionary with field and field type pairs that
65+
describe the schema of the metadata.
6966
"""
7067

7168
async def before_index_creation(
@@ -74,22 +71,27 @@ async def before_index_creation(
7471
"""
7572
Executes before the index is created. Used for setting up
7673
any required Elasticsearch resources like a pipeline.
74+
Defaults to a no-op.
7775
78-
Args:
79-
client: The Elasticsearch client.
80-
text_field: The field containing the text data in the index.
81-
vector_field: The field containing the vector representations in the index.
76+
:param client: The Elasticsearch client.
77+
:param text_field: The field containing the text data in the index.
78+
:param vector_field: The field containing the vector representations in the index.
8279
"""
8380
pass
8481

8582
def needs_inference(self) -> bool:
8683
"""
87-
TODO
84+
Some retrieval strategies index embedding vectors and allow search by embedding
85+
vector, for example the `DenseVectorStrategy` strategy. Mapping a user input query
86+
string to an embedding vector is called inference. Inference can be applied
87+
in Elasticsearch (using a `model_id`) or outside of Elasticsearch (using an
88+
`EmbeddingService` defined on the `VectorStore`). In the latter case,
89+
this method has to return True.
8890
"""
8991
return False
9092

9193

92-
class AsyncSparseVector(AsyncRetrievalStrategy):
94+
class AsyncSparseVectorStrategy(AsyncRetrievalStrategy):
9395
"""Sparse retrieval strategy using the `text_expansion` processor."""
9496

9597
def __init__(self, model_id: str = ".elser_model_2"):
@@ -176,7 +178,7 @@ async def before_index_creation(
176178
)
177179

178180

179-
class AsyncDenseVector(AsyncRetrievalStrategy):
181+
class AsyncDenseVectorStrategy(AsyncRetrievalStrategy):
180182
"""K-nearest-neighbors retrieval."""
181183

182184
def __init__(
@@ -189,7 +191,7 @@ def __init__(
189191
):
190192
if hybrid and not text_field:
191193
raise ValueError(
192-
"to enable hybrid you have to specify a text_field (for BM25 matching)"
194+
"to enable hybrid you have to specify a text_field (for BM25Strategy matching)"
193195
)
194196

195197
self.distance = distance
@@ -304,7 +306,7 @@ def needs_inference(self) -> bool:
304306
return not self.model_id
305307

306308

307-
class AsyncDenseVectorScriptScore(AsyncRetrievalStrategy):
309+
class AsyncDenseVectorScriptScoreStrategy(AsyncRetrievalStrategy):
308310
"""Exact nearest neighbors retrieval using the `script_score` query."""
309311

310312
def __init__(self, distance: DistanceMetric = DistanceMetric.COSINE) -> None:
@@ -383,7 +385,7 @@ def needs_inference(self) -> bool:
383385
return True
384386

385387

386-
class AsyncBM25(AsyncRetrievalStrategy):
388+
class AsyncBM25Strategy(AsyncRetrievalStrategy):
387389
def __init__(
388390
self,
389391
k1: Optional[float] = None,

0 commit comments

Comments
 (0)