Skip to content

Commit 65dbae9

Browse files
committed
Added support for ADDSCORES modifier (#3329)
* Added support for ADDSCORES modifier * Fixed codestyle issues * More codestyle fixes * Updated test cases and testing image to represent latest * Codestyle issues * Added handling for dict responses
1 parent 3d24a3d commit 65dbae9

File tree

4 files changed

+64
-1
lines changed

4 files changed

+64
-1
lines changed

.github/workflows/integration.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ env:
2828
# this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665
2929
COVERAGE_CORE: sysmon
3030
REDIS_IMAGE: redis:7.4-rc2
31-
REDIS_STACK_IMAGE: redis/redis-stack-server:7.4.0-rc2
31+
REDIS_STACK_IMAGE: redis/redis-stack-server:latest
3232

3333
jobs:
3434
dependency-audit:

redis/commands/search/aggregation.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ def __init__(self, query: str = "*") -> None:
111111
self._verbatim = False
112112
self._cursor = []
113113
self._dialect = None
114+
self._add_scores = False
114115

115116
def load(self, *fields: List[str]) -> "AggregateRequest":
116117
"""
@@ -292,6 +293,13 @@ def with_schema(self) -> "AggregateRequest":
292293
self._with_schema = True
293294
return self
294295

296+
def add_scores(self) -> "AggregateRequest":
297+
"""
298+
If set, includes the score as an ordinary field of the row.
299+
"""
300+
self._add_scores = True
301+
return self
302+
295303
def verbatim(self) -> "AggregateRequest":
296304
self._verbatim = True
297305
return self
@@ -315,6 +323,9 @@ def build_args(self) -> List[str]:
315323
if self._verbatim:
316324
ret.append("VERBATIM")
317325

326+
if self._add_scores:
327+
ret.append("ADDSCORES")
328+
318329
if self._cursor:
319330
ret += self._cursor
320331

tests/test_asyncio/test_search.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1531,6 +1531,32 @@ async def test_withsuffixtrie(decoded_r: redis.Redis):
15311531
assert "WITHSUFFIXTRIE" in info["attributes"][0]["flags"]
15321532

15331533

1534+
@pytest.mark.redismod
1535+
@skip_ifmodversion_lt("2.10.05", "search")
1536+
async def test_aggregations_add_scores(decoded_r: redis.Redis):
1537+
assert await decoded_r.ft().create_index(
1538+
(
1539+
TextField("name", sortable=True, weight=5.0),
1540+
NumericField("age", sortable=True),
1541+
)
1542+
)
1543+
1544+
assert await decoded_r.hset("doc1", mapping={"name": "bar", "age": "25"})
1545+
assert await decoded_r.hset("doc2", mapping={"name": "foo", "age": "19"})
1546+
1547+
req = aggregations.AggregateRequest("*").add_scores()
1548+
res = await decoded_r.ft().aggregate(req)
1549+
1550+
if isinstance(res, dict):
1551+
assert len(res["results"]) == 2
1552+
assert res["results"][0]["extra_attributes"] == {"__score": "0.2"}
1553+
assert res["results"][1]["extra_attributes"] == {"__score": "0.2"}
1554+
else:
1555+
assert len(res.rows) == 2
1556+
assert res.rows[0] == ["__score", "0.2"]
1557+
assert res.rows[1] == ["__score", "0.2"]
1558+
1559+
15341560
@pytest.mark.redismod
15351561
@skip_if_redis_enterprise()
15361562
async def test_search_commands_in_pipeline(decoded_r: redis.Redis):

tests/test_search.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1441,6 +1441,32 @@ def test_aggregations_filter(client):
14411441
assert res["results"][1]["extra_attributes"] == {"age": "25"}
14421442

14431443

1444+
@pytest.mark.redismod
1445+
@skip_ifmodversion_lt("2.10.05", "search")
1446+
def test_aggregations_add_scores(client):
1447+
client.ft().create_index(
1448+
(
1449+
TextField("name", sortable=True, weight=5.0),
1450+
NumericField("age", sortable=True),
1451+
)
1452+
)
1453+
1454+
client.hset("doc1", mapping={"name": "bar", "age": "25"})
1455+
client.hset("doc2", mapping={"name": "foo", "age": "19"})
1456+
1457+
req = aggregations.AggregateRequest("*").add_scores()
1458+
res = client.ft().aggregate(req)
1459+
1460+
if isinstance(res, dict):
1461+
assert len(res["results"]) == 2
1462+
assert res["results"][0]["extra_attributes"] == {"__score": "0.2"}
1463+
assert res["results"][1]["extra_attributes"] == {"__score": "0.2"}
1464+
else:
1465+
assert len(res.rows) == 2
1466+
assert res.rows[0] == ["__score", "0.2"]
1467+
assert res.rows[1] == ["__score", "0.2"]
1468+
1469+
14441470
@pytest.mark.redismod
14451471
@skip_ifmodversion_lt("2.0.0", "search")
14461472
def test_index_definition(client):

0 commit comments

Comments
 (0)