|
1 |
| -from typing import List, Optional, Union |
| 1 | +from typing import List, Optional, Tuple, Union |
2 | 2 |
|
3 | 3 | from redis.commands.search.dialect import DEFAULT_DIALECT
|
4 | 4 |
|
@@ -31,7 +31,7 @@ def __init__(self, query_string: str) -> None:
|
31 | 31 | self._with_scores: bool = False
|
32 | 32 | self._scorer: Optional[str] = None
|
33 | 33 | self._filters: List = list()
|
34 |
| - self._ids: Optional[List[str]] = None |
| 34 | + self._ids: Optional[Tuple[str]] = None |
35 | 35 | self._slop: int = -1
|
36 | 36 | self._timeout: Optional[float] = None
|
37 | 37 | self._in_order: bool = False
|
@@ -81,7 +81,7 @@ def return_field(
|
81 | 81 | self._return_fields += ("AS", as_field)
|
82 | 82 | return self
|
83 | 83 |
|
84 |
| - def _mk_field_list(self, fields: List[str]) -> List: |
| 84 | + def _mk_field_list(self, fields: Optional[Union[List[str], str]]) -> List: |
85 | 85 | if not fields:
|
86 | 86 | return []
|
87 | 87 | return [fields] if isinstance(fields, str) else list(fields)
|
@@ -126,7 +126,7 @@ def summarize(
|
126 | 126 |
|
127 | 127 | def highlight(
|
128 | 128 | self, fields: Optional[List[str]] = None, tags: Optional[List[str]] = None
|
129 |
| - ) -> None: |
| 129 | + ) -> "Query": |
130 | 130 | """
|
131 | 131 | Apply specified markup to matched term(s) within the returned field(s).
|
132 | 132 |
|
@@ -187,16 +187,16 @@ def scorer(self, scorer: str) -> "Query":
|
187 | 187 | self._scorer = scorer
|
188 | 188 | return self
|
189 | 189 |
|
190 |
| - def get_args(self) -> List[str]: |
| 190 | + def get_args(self) -> List[Union[str, int, float]]: |
191 | 191 | """Format the redis arguments for this query and return them."""
|
192 |
| - args = [self._query_string] |
| 192 | + args: List[Union[str, int, float]] = [self._query_string] |
193 | 193 | args += self._get_args_tags()
|
194 | 194 | args += self._summarize_fields + self._highlight_fields
|
195 | 195 | args += ["LIMIT", self._offset, self._num]
|
196 | 196 | return args
|
197 | 197 |
|
198 |
| - def _get_args_tags(self) -> List[str]: |
199 |
| - args = [] |
| 198 | + def _get_args_tags(self) -> List[Union[str, int, float]]: |
| 199 | + args: List[Union[str, int, float]] = [] |
200 | 200 | if self._no_content:
|
201 | 201 | args.append("NOCONTENT")
|
202 | 202 | if self._fields:
|
@@ -288,14 +288,14 @@ def with_scores(self) -> "Query":
|
288 | 288 | self._with_scores = True
|
289 | 289 | return self
|
290 | 290 |
|
291 |
| - def limit_fields(self, *fields: List[str]) -> "Query": |
| 291 | + def limit_fields(self, *fields: str) -> "Query": |
292 | 292 | """
|
293 | 293 | Limit the search to specific TEXT fields only.
|
294 | 294 |
|
295 |
| - - **fields**: A list of strings, case sensitive field names |
| 295 | + - **fields**: Each element should be a string, case sensitive field name |
296 | 296 | from the defined schema.
|
297 | 297 | """
|
298 |
| - self._fields = fields |
| 298 | + self._fields = list(fields) |
299 | 299 | return self
|
300 | 300 |
|
301 | 301 | def add_filter(self, flt: "Filter") -> "Query":
|
@@ -340,7 +340,7 @@ def dialect(self, dialect: int) -> "Query":
|
340 | 340 |
|
341 | 341 |
|
342 | 342 | class Filter:
|
343 |
| - def __init__(self, keyword: str, field: str, *args: List[str]) -> None: |
| 343 | + def __init__(self, keyword: str, field: str, *args: Union[str, float]) -> None: |
344 | 344 | self.args = [keyword, field] + list(args)
|
345 | 345 |
|
346 | 346 |
|
|
0 commit comments