Skip to content

Commit 1a247fa

Browse files
qcoumespquentin
andauthored
feat: Add a collapse method to elasticsearch_dsl.search.Search (#1649)
* feat: Add a `collapse` method to `elasticsearch_dsl.search.Search` * Fix collapse() in the middle of the chain --------- Co-authored-by: Quentin Pradet <[email protected]>
1 parent 0e94780 commit 1a247fa

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

elasticsearch_dsl/search.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ def __init__(self, using="default", index=None, doc_type=None, extra=None):
120120

121121
self._doc_type = []
122122
self._doc_type_map = {}
123+
self._collapse = {}
123124
if isinstance(doc_type, (tuple, list)):
124125
self._doc_type.extend(doc_type)
125126
elif isinstance(doc_type, collections.abc.Mapping):
@@ -293,6 +294,7 @@ def _clone(self):
293294
s = self.__class__(
294295
using=self._using, index=self._index, doc_type=self._doc_type
295296
)
297+
s._collapse = self._collapse.copy()
296298
s._doc_type_map = self._doc_type_map.copy()
297299
s._extra = self._extra.copy()
298300
s._params = self._params.copy()
@@ -318,6 +320,7 @@ def __init__(self, **kwargs):
318320

319321
self.aggs = AggsProxy(self)
320322
self._sort = []
323+
self._collapse = {}
321324
self._source = None
322325
self._highlight = {}
323326
self._highlight_opts = {}
@@ -568,6 +571,27 @@ def sort(self, *keys):
568571
s._sort.append(k)
569572
return s
570573

574+
def collapse(self, field=None, inner_hits=None, max_concurrent_group_searches=None):
575+
"""
576+
Add collapsing information to the search request.
577+
If called without providing ``field``, it will remove all collapse
578+
requirements, otherwise it will replace them with the provided
579+
arguments.
580+
The API returns a copy of the Search object and can thus be chained.
581+
"""
582+
s = self._clone()
583+
s._collapse = {}
584+
585+
if field is None:
586+
return s
587+
588+
s._collapse["field"] = field
589+
if inner_hits:
590+
s._collapse["inner_hits"] = inner_hits
591+
if max_concurrent_group_searches:
592+
s._collapse["max_concurrent_group_searches"] = max_concurrent_group_searches
593+
return s
594+
571595
def highlight_options(self, **kwargs):
572596
"""
573597
Update the global highlighting options used for this request. For
@@ -663,6 +687,9 @@ def to_dict(self, count=False, **kwargs):
663687
if self._sort:
664688
d["sort"] = self._sort
665689

690+
if self._collapse:
691+
d["collapse"] = self._collapse
692+
666693
d.update(recursive_to_dict(self._extra))
667694

668695
if self._source not in (None, {}):

tests/test_search.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,38 @@ def test_sort_by_score():
256256
s.sort("-_score")
257257

258258

259+
def test_collapse():
260+
s = search.Search()
261+
262+
inner_hits = {"name": "most_recent", "size": 5, "sort": [{"@timestamp": "desc"}]}
263+
s = s.collapse("user.id", inner_hits=inner_hits, max_concurrent_group_searches=4)
264+
265+
assert {
266+
"field": "user.id",
267+
"inner_hits": {
268+
"name": "most_recent",
269+
"size": 5,
270+
"sort": [{"@timestamp": "desc"}],
271+
},
272+
"max_concurrent_group_searches": 4,
273+
} == s._collapse
274+
assert {
275+
"collapse": {
276+
"field": "user.id",
277+
"inner_hits": {
278+
"name": "most_recent",
279+
"size": 5,
280+
"sort": [{"@timestamp": "desc"}],
281+
},
282+
"max_concurrent_group_searches": 4,
283+
}
284+
} == s.to_dict()
285+
286+
s = s.collapse()
287+
assert {} == s._collapse
288+
assert search.Search().to_dict() == s.to_dict()
289+
290+
259291
def test_slice():
260292
s = search.Search()
261293
assert {"from": 3, "size": 7} == s[3:10].to_dict()
@@ -305,6 +337,7 @@ def test_complex_example():
305337
s.query("match", title="python")
306338
.query(~Q("match", title="ruby"))
307339
.filter(Q("term", category="meetup") | Q("term", category="conference"))
340+
.collapse("user_id")
308341
.post_filter("terms", tags=["prague", "czech"])
309342
.script_fields(more_attendees="doc['attendees'].value + 42")
310343
)
@@ -342,6 +375,7 @@ def test_complex_example():
342375
"aggs": {"avg_attendees": {"avg": {"field": "attendees"}}},
343376
}
344377
},
378+
"collapse": {"field": "user_id"},
345379
"highlight": {
346380
"order": "score",
347381
"fields": {"title": {"fragment_size": 50}, "body": {"fragment_size": 50}},

0 commit comments

Comments
 (0)