Skip to content

Commit 5833e00

Browse files
Add a knn method to elasticsearch_dsl.search.Search
1 parent f0c5045 commit 5833e00

File tree

2 files changed

+119
-1
lines changed

2 files changed

+119
-1
lines changed

elasticsearch_dsl/search.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from .aggs import A, AggBase
2525
from .connections import get_connection
2626
from .exceptions import IllegalOperation
27-
from .query import Bool, Q
27+
from .query import Bool, Q, Query
2828
from .response import Hit, Response
2929
from .utils import AttrDict, DslBase, recursive_to_dict
3030

@@ -319,6 +319,7 @@ def __init__(self, **kwargs):
319319
self.aggs = AggsProxy(self)
320320
self._sort = []
321321
self._collapse = {}
322+
self._knn = []
322323
self._source = None
323324
self._highlight = {}
324325
self._highlight_opts = {}
@@ -407,6 +408,7 @@ def _clone(self):
407408

408409
s._response_class = self._response_class
409410
s._collapse = self._collapse.copy()
411+
s._knn = [knn.copy() for knn in self._knn]
410412
s._sort = self._sort[:]
411413
s._source = copy.copy(self._source) if self._source is not None else None
412414
s._highlight = self._highlight.copy()
@@ -445,6 +447,10 @@ def update_from_dict(self, d):
445447
self.aggs._params = {
446448
"aggs": {name: A(value) for (name, value) in aggs.items()}
447449
}
450+
if "knn" in d:
451+
self._knn = d.pop("knn")
452+
if isinstance(self._knn, dict):
453+
self._knn = [self._knn]
448454
if "collapse" in d:
449455
self._collapse = d.pop("collapse")
450456
if "sort" in d:
@@ -494,6 +500,60 @@ def script_fields(self, **kwargs):
494500
s._script_fields.update(kwargs)
495501
return s
496502

503+
def knn(
504+
self,
505+
field,
506+
k,
507+
num_candidates,
508+
query_vector=None,
509+
query_vector_builder=None,
510+
filter=None,
511+
similarity=None,
512+
):
513+
"""
514+
Add a k-nearest neighbor (kNN) search.
515+
516+
:arg field: the name of the vector field to search against
517+
:arg k: number of nearest neighbors to return as top hits
518+
:arg num_candidates: number of nearest neighbor candidates to consider per shard
519+
:arg query_vector: the vector to search for
520+
:arg query_vector_builder: A dictionary indicating how to build a query vector
521+
:arg filter: query to filter the documents that can match
522+
:arg similarity: the minimum similarity required for a document to be considered a match, as a float value
523+
524+
Example::
525+
526+
s = Search()
527+
s = s.knn(field='embedding', k=5, num_candidates=10, query_vector=vector,
528+
filter=Q('term', category='blog')))
529+
"""
530+
s = self._clone()
531+
s._knn.append(
532+
{
533+
"field": field,
534+
"k": k,
535+
"num_candidates": num_candidates,
536+
}
537+
)
538+
if query_vector is None and query_vector_builder is None:
539+
raise ValueError("one of query_vector and query_vector_builder is required")
540+
if query_vector is not None and query_vector_builder is not None:
541+
raise ValueError(
542+
"only one of query_vector and query_vector_builder must be given"
543+
)
544+
if query_vector is not None:
545+
s._knn[-1]["query_vector"] = query_vector
546+
if query_vector_builder is not None:
547+
s._knn[-1]["query_vector_builder"] = query_vector_builder
548+
if filter is not None:
549+
if isinstance(filter, Query):
550+
s._knn[-1]["filter"] = filter.to_dict()
551+
else:
552+
s._knn[-1]["filter"] = filter
553+
if similarity is not None:
554+
s._knn[-1]["similarity"] = similarity
555+
return s
556+
497557
def source(self, fields=None, **kwargs):
498558
"""
499559
Selectively control how the _source field is returned.
@@ -677,6 +737,12 @@ def to_dict(self, count=False, **kwargs):
677737
if self.query:
678738
d["query"] = self.query.to_dict()
679739

740+
if self._knn:
741+
if len(self._knn) == 1:
742+
d["knn"] = self._knn[0]
743+
else:
744+
d["knn"] = self._knn
745+
680746
# count request doesn't care for sorting and other things
681747
if not count:
682748
if self.post_filter:

tests/test_search.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,58 @@ class MyDocument(Document):
234234
assert s._doc_type_map == {}
235235

236236

237+
def test_knn():
238+
s = search.Search()
239+
240+
with raises(TypeError):
241+
s.knn()
242+
with raises(TypeError):
243+
s.knn("field")
244+
with raises(TypeError):
245+
s.knn("field", 5)
246+
with raises(ValueError):
247+
s.knn("field", 5, 100)
248+
with raises(ValueError):
249+
s.knn("field", 5, 100, query_vector=[1, 2, 3], query_vector_builder={})
250+
251+
s = s.knn("field", 5, 100, query_vector=[1, 2, 3])
252+
assert {
253+
"knn": {
254+
"field": "field",
255+
"k": 5,
256+
"num_candidates": 100,
257+
"query_vector": [1, 2, 3],
258+
}
259+
} == s.to_dict()
260+
261+
s = s.knn(
262+
k=4,
263+
num_candidates=40,
264+
field="name",
265+
query_vector_builder={
266+
"text_embedding": {"model_id": "foo", "model_text": "search text"}
267+
},
268+
)
269+
assert {
270+
"knn": [
271+
{
272+
"field": "field",
273+
"k": 5,
274+
"num_candidates": 100,
275+
"query_vector": [1, 2, 3],
276+
},
277+
{
278+
"field": "name",
279+
"k": 4,
280+
"num_candidates": 40,
281+
"query_vector_builder": {
282+
"text_embedding": {"model_id": "foo", "model_text": "search text"}
283+
},
284+
},
285+
]
286+
} == s.to_dict()
287+
288+
237289
def test_sort():
238290
s = search.Search()
239291
s = s.sort("fielda", "-fieldb")

0 commit comments

Comments
 (0)