Skip to content

Commit 535edad

Browse files
committed
update numeric checks to be utility functions
1 parent e4a8aca commit 535edad

File tree

3 files changed

+35
-14
lines changed

3 files changed

+35
-14
lines changed

.vscode/settings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
{
2+
"python.testing.unittestEnabled": false,
3+
"python.testing.pytestEnabled": true,
4+
}

aredis_om/model/model.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
from .. import redis
4242
from ..checks import has_redis_json, has_redisearch
4343
from ..connections import get_redis_connection
44-
from ..util import ASYNC_MODE
44+
from ..util import ASYNC_MODE, has_numeric_inner_type, is_numeric_type
4545
from .encoders import jsonable_encoder
4646
from .render_tree import render_tree
4747
from .token_escaper import TokenEscaper
@@ -406,7 +406,6 @@ class RediSearchFieldTypes(Enum):
406406

407407

408408
# TODO: How to handle Geo fields?
409-
NUMERIC_TYPES = (float, int, decimal.Decimal)
410409
DEFAULT_PAGE_SIZE = 1000
411410

412411

@@ -578,7 +577,7 @@ def resolve_field_type(field: "FieldInfo", op: Operators) -> RediSearchFieldType
578577
)
579578
elif field_type is bool:
580579
return RediSearchFieldTypes.TAG
581-
elif any(issubclass(field_type, t) for t in NUMERIC_TYPES):
580+
elif is_numeric_type(field_type):
582581
# Index numeric Python types as NUMERIC fields, so we can support
583582
# range queries.
584583
return RediSearchFieldTypes.NUMERIC
@@ -1375,14 +1374,6 @@ def outer_type_or_annotation(field: FieldInfo):
13751374
return field.annotation.__args__[0] # type: ignore
13761375

13771376

1378-
def _is_numeric_type(type_: Type[Any]) -> bool:
1379-
args = get_args(type_)
1380-
try:
1381-
return any(issubclass(args[0], t) for t in NUMERIC_TYPES)
1382-
except TypeError:
1383-
return False
1384-
1385-
13861377
def should_index_field(field_info: Union[FieldInfo, PydanticFieldInfo]) -> bool:
13871378
# for vector, full text search, and sortable fields, we always have to index
13881379
# We could require the user to set index=True, but that would be a breaking change
@@ -1813,7 +1804,7 @@ def schema_for_type(cls, name, typ: Any, field_info: PydanticFieldInfo):
18131804
schema = cls.schema_for_type(name, embedded_cls, field_info)
18141805
elif typ is bool:
18151806
schema = f"{name} TAG"
1816-
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
1807+
elif is_numeric_type(typ):
18171808
vector_options: Optional[VectorFieldOptions] = getattr(
18181809
field_info, "vector_options", None
18191810
)
@@ -2012,7 +2003,7 @@ def schema_for_type(
20122003
field_info, "vector_options", None
20132004
)
20142005
try:
2015-
is_vector = vector_options and _is_numeric_type(typ)
2006+
is_vector = vector_options and has_numeric_inner_type(typ)
20162007
except IndexError:
20172008
raise RedisModelError(
20182009
f"Vector field '{name}' must be annotated as a container type"
@@ -2137,7 +2128,7 @@ def schema_for_type(
21372128
schema += " CASESENSITIVE"
21382129
elif typ is bool:
21392130
schema = f"{path} AS {index_field_name} TAG"
2140-
elif any(issubclass(typ, t) for t in NUMERIC_TYPES):
2131+
elif is_numeric_type(typ):
21412132
schema = f"{path} AS {index_field_name} NUMERIC"
21422133
elif issubclass(typ, str):
21432134
if full_text_search is True:

aredis_om/util.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1+
import decimal
12
import inspect
3+
from typing import Any, Type, get_args
24

35

46
def is_async_mode() -> bool:
@@ -10,3 +12,27 @@ async def f() -> None:
1012

1113

1214
ASYNC_MODE = is_async_mode()
15+
16+
NUMERIC_TYPES = (float, int, decimal.Decimal)
17+
18+
19+
def is_numeric_type(type_: Type[Any]) -> bool:
20+
try:
21+
return issubclass(type_, NUMERIC_TYPES)
22+
except TypeError:
23+
return False
24+
25+
26+
def has_numeric_inner_type(type_: Type[Any]) -> bool:
27+
"""
28+
Check if the type has a numeric inner type.
29+
"""
30+
args = get_args(type_)
31+
32+
if not args:
33+
return False
34+
35+
try:
36+
return issubclass(args[0], NUMERIC_TYPES)
37+
except TypeError:
38+
return False

0 commit comments

Comments
 (0)