Skip to content

Commit 62db7bf

Browse files
committed
Fixing mypy errors in redis/commands/search/query.py
1 parent 6246cba commit 62db7bf

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

redis/commands/search/query.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional, Union
1+
from typing import List, Optional, Tuple, Union
22

33
from redis.commands.search.dialect import DEFAULT_DIALECT
44

@@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None:
3131
self._with_scores: bool = False
3232
self._scorer: Optional[str] = None
3333
self._filters: List = list()
34-
self._ids: Optional[List[str]] = None
34+
self._ids: Optional[Tuple[str]] = None
3535
self._slop: int = -1
3636
self._timeout: Optional[float] = None
3737
self._in_order: bool = False
@@ -81,7 +81,7 @@ def return_field(
8181
self._return_fields += ("AS", as_field)
8282
return self
8383

84-
def _mk_field_list(self, fields: List[str]) -> List:
84+
def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List:
8585
if not fields:
8686
return []
8787
return [fields] if isinstance(fields, str) else list(fields)
@@ -126,7 +126,7 @@ def summarize(
126126

127127
def highlight(
128128
self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
129-
) -> None:
129+
) -> "Query":
130130
"""
131131
Apply specified markup to matched term(s) within the returned field(s).
132132
@@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query":
187187
self._scorer = scorer
188188
return self
189189

190-
def get_args(self) -> List[str]:
190+
def get_args(self) -> List[Union[str, int, float]]:
191191
"""Format the redis arguments for this query and return them."""
192-
args = [self._query_string]
192+
args: List[Union[str, int, float]] = [self._query_string]
193193
args += self._get_args_tags()
194194
args += self._summarize_fields + self._highlight_fields
195195
args += ["LIMIT", self._offset, self._num]
196196
return args
197197

198-
def _get_args_tags(self) -> List[str]:
199-
args = []
198+
def _get_args_tags(self) -> List[Union[str, int, float]]:
199+
args: List[Union[str, int, float]] = []
200200
if self._no_content:
201201
args.append("NOCONTENT")
202202
if self._fields:
@@ -288,14 +288,14 @@ def with_scores(self) -> "Query":
288288
self._with_scores = True
289289
return self
290290

291-
def limit_fields(self, *fields: List[str]) -> "Query":
291+
def limit_fields(self, *fields: str) -> "Query":
292292
"""
293293
Limit the search to specific TEXT fields only.
294294
295-
- **fields**: A list of strings, case sensitive field names
295+
- **fields**: Each element should be a string, case sensitive field name
296296
from the defined schema.
297297
"""
298-
self._fields = fields
298+
self._fields = list(fields)
299299
return self
300300

301301
def add_filter(self, flt: "Filter") -> "Query":
@@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query":
340340

341341

342342
class Filter:
343-
def __init__(self, keyword: str, field: str, *args: List[str]) -> None:
343+
def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None:
344344
self.args = [keyword, field] + list(args)
345345

346346

0 commit comments

Comments
 (0)