Skip to content

[Backport 8.x] Correct typing hints for the FunctionScore query #1961

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions elasticsearch_dsl/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,9 +612,9 @@ class FunctionScore(Query):

name = "function_score"
_param_defs = {
"functions": {"type": "score_function", "multi": True},
"query": {"type": "query"},
"filter": {"type": "query"},
"functions": {"type": "score_function", "multi": True},
}

def __init__(
Expand All @@ -623,11 +623,7 @@ def __init__(
boost_mode: Union[
Literal["multiply", "replace", "sum", "avg", "max", "min"], "DefaultType"
] = DEFAULT,
functions: Union[
Sequence["types.FunctionScoreContainer"],
Sequence[Dict[str, Any]],
"DefaultType",
] = DEFAULT,
functions: Union[Sequence[ScoreFunction], "DefaultType"] = DEFAULT,
max_boost: Union[float, "DefaultType"] = DEFAULT,
min_score: Union[float, "DefaultType"] = DEFAULT,
query: Union[Query, "DefaultType"] = DEFAULT,
Expand Down
70 changes: 1 addition & 69 deletions elasticsearch_dsl/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from elastic_transport.client_utils import DEFAULT, DefaultType

from elasticsearch_dsl import Query, function
from elasticsearch_dsl import Query
from elasticsearch_dsl.document_base import InstrumentedField
from elasticsearch_dsl.utils import AttrDict

Expand Down Expand Up @@ -688,74 +688,6 @@ def __init__(
super().__init__(kwargs)


class FunctionScoreContainer(AttrDict[Any]):
"""
:arg exp: Function that scores a document with a exponential decay,
depending on the distance of a numeric field value of the document
from an origin.
:arg gauss: Function that scores a document with a normal decay,
depending on the distance of a numeric field value of the document
from an origin.
:arg linear: Function that scores a document with a linear decay,
depending on the distance of a numeric field value of the document
from an origin.
:arg field_value_factor: Function allows you to use a field from a
document to influence the score. It’s similar to using the
script_score function, however, it avoids the overhead of
scripting.
:arg random_score: Generates scores that are uniformly distributed
from 0 up to but not including 1. In case you want scores to be
reproducible, it is possible to provide a `seed` and `field`.
:arg script_score: Enables you to wrap another query and customize the
scoring of it optionally with a computation derived from other
numeric field values in the doc using a script expression.
:arg filter:
:arg weight:
"""

exp: Union[function.DecayFunction, DefaultType]
gauss: Union[function.DecayFunction, DefaultType]
linear: Union[function.DecayFunction, DefaultType]
field_value_factor: Union[function.FieldValueFactorScore, DefaultType]
random_score: Union[function.RandomScore, DefaultType]
script_score: Union[function.ScriptScore, DefaultType]
filter: Union[Query, DefaultType]
weight: Union[float, DefaultType]

def __init__(
self,
*,
exp: Union[function.DecayFunction, DefaultType] = DEFAULT,
gauss: Union[function.DecayFunction, DefaultType] = DEFAULT,
linear: Union[function.DecayFunction, DefaultType] = DEFAULT,
field_value_factor: Union[
function.FieldValueFactorScore, DefaultType
] = DEFAULT,
random_score: Union[function.RandomScore, DefaultType] = DEFAULT,
script_score: Union[function.ScriptScore, DefaultType] = DEFAULT,
filter: Union[Query, DefaultType] = DEFAULT,
weight: Union[float, DefaultType] = DEFAULT,
**kwargs: Any,
):
if exp is not DEFAULT:
kwargs["exp"] = exp
if gauss is not DEFAULT:
kwargs["gauss"] = gauss
if linear is not DEFAULT:
kwargs["linear"] = linear
if field_value_factor is not DEFAULT:
kwargs["field_value_factor"] = field_value_factor
if random_score is not DEFAULT:
kwargs["random_score"] = random_score
if script_score is not DEFAULT:
kwargs["script_score"] = script_score
if filter is not DEFAULT:
kwargs["filter"] = filter
if weight is not DEFAULT:
kwargs["weight"] = weight
super().__init__(kwargs)


class FuzzyQuery(AttrDict[Any]):
"""
:arg value: (required) Term you wish to find in the provided field.
Expand Down
27 changes: 27 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,33 @@ def test_function_score_to_dict() -> None:
assert d == q.to_dict()


def test_function_score_class_based_to_dict() -> None:
q = query.FunctionScore(
query=query.Match(title="python"),
functions=[
function.RandomScore(),
function.FieldValueFactor(
field="comment_count",
filter=query.Term(tags="python"),
),
],
)

d = {
"function_score": {
"query": {"match": {"title": "python"}},
"functions": [
{"random_score": {}},
{
"filter": {"term": {"tags": "python"}},
"field_value_factor": {"field": "comment_count"},
},
],
}
}
assert d == q.to_dict()


def test_function_score_with_single_function() -> None:
d = {
"function_score": {
Expand Down
6 changes: 6 additions & 0 deletions utils/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,12 @@ def get_python_type(self, schema_type, for_response=False):
):
# QueryContainer maps to the DSL's Query class
return "Query", {"type": "query"}
elif (
type_name["namespace"] == "_types.query_dsl"
and type_name["name"] == "FunctionScoreContainer"
):
# FunctionScoreContainer maps to the DSL's ScoreFunction class
return "ScoreFunction", {"type": "score_function"}
elif (
type_name["namespace"] == "_types.aggregations"
and type_name["name"] == "Buckets"
Expand Down
1 change: 0 additions & 1 deletion utils/templates/query.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ class {{ k.name }}({{ parent }}):
shortcut property. Until the code generator can support shortcut
properties directly that solution is added here #}
"filter": {"type": "query"},
"functions": {"type": "score_function", "multi": True},
{% endif %}
}
{% endif %}
Expand Down
2 changes: 1 addition & 1 deletion utils/templates/types.py.tpl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ from typing import Any, Dict, Literal, Mapping, Sequence, Union
from elastic_transport.client_utils import DEFAULT, DefaultType

from elasticsearch_dsl.document_base import InstrumentedField
from elasticsearch_dsl import function, Query
from elasticsearch_dsl import Query
from elasticsearch_dsl.utils import AttrDict

PipeSeparatedFlags = str
Expand Down
Loading