Skip to content

Commit b45ffd5

Browse files
committed
dynamically set index based on other fields that require indexes
1 parent 23213af commit b45ffd5

File tree

3 files changed

+32
-14
lines changed

3 files changed

+32
-14
lines changed

aredis_om/model/model.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,19 +1290,20 @@ def __new__(cls, name, bases, attrs, **kwargs): # noqa C901
12901290
)
12911291
new_class.Meta = new_class._meta
12921292

1293-
if new_class.model_config.get("index", None) is True:
1293+
is_indexed = kwargs.get("index", None) is True
1294+
1295+
if is_indexed and new_class.model_config.get("index", None) is True:
12941296
raise RedisModelError(
12951297
f"{new_class.__name__} cannot be indexed, only one model can be indexed in an inheritance tree"
12961298
)
12971299

1300+
new_class.model_config["index"] = is_indexed
1301+
12981302
# Create proxies for each model field so that we can use the field
12991303
# in queries, like Model.get(Model.field_name == 1)
13001304
# Only set if the model is has index=True
1301-
is_indexed = kwargs.get("index", None) is True
1302-
new_class.model_config["index"] = is_indexed
1303-
13041305
for field_name, field in new_class.model_fields.items():
1305-
if field.__class__ is PydanticFieldInfo:
1306+
if type(field) is PydanticFieldInfo:
13061307
field = FieldInfo(**field._attributes_set)
13071308
setattr(new_class, field_name, field)
13081309

@@ -1370,6 +1371,15 @@ def outer_type_or_annotation(field: FieldInfo):
13701371
else:
13711372
return field.annotation.__args__[0] # type: ignore
13721373

1374+
def should_index_field(field_info: FieldInfo) -> bool:
1375+
# for vector, full text search, and sortable fields, we always have to index
1376+
# We could require the user to set index=True, but that would be a breaking change
1377+
return (
1378+
getattr(field_info, "index", False) is True
1379+
or getattr(field_info, "vector_options", None) is not None
1380+
or getattr(field_info, "full_text_search", False) is True
1381+
or getattr(field_info, "sortable", False) is True
1382+
)
13731383

13741384
class RedisModel(BaseModel, abc.ABC, metaclass=ModelMeta):
13751385
pk: Optional[str] = Field(
@@ -1736,7 +1746,7 @@ def schema_for_fields(cls):
17361746
else:
17371747
redisearch_field = cls.schema_for_type(name, _type, field_info)
17381748
schema_parts.append(redisearch_field)
1739-
elif getattr(field_info, "index", None) is True:
1749+
elif should_index_field(field_info):
17401750
schema_parts.append(cls.schema_for_type(name, _type, field_info))
17411751
elif is_subscripted_type:
17421752
# Ignore subscripted types (usually containers!) that we don't
@@ -1945,10 +1955,7 @@ def schema_for_type(
19451955
field_info: PydanticFieldInfo,
19461956
parent_type: Optional[Any] = None,
19471957
) -> str:
1948-
should_index = (
1949-
getattr(field_info, "index", False) is True
1950-
or getattr(field_info, "vector_options", None) is not None
1951-
)
1958+
should_index = should_index_field(field_info)
19521959
is_container_type = is_supported_container_type(typ)
19531960
parent_is_container_type = is_supported_container_type(parent_type)
19541961
parent_is_model = False

tests/test_hash_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,8 @@ class Member(BaseHashModel, index=True):
5353
last_name: str = Field(index=True)
5454
email: str = Field(index=True)
5555
join_date: datetime.date
56-
age: int = Field(index=True, sortable=True)
57-
bio: str = Field(index=True, full_text_search=True)
56+
age: int = Field(sortable=True)
57+
bio: str = Field(full_text_search=True)
5858

5959
class Meta:
6060
model_key_prefix = "member"

tests/test_json_model.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ class Member(BaseJsonModel, index=True):
7474
email: Optional[EmailStr] = Field(index=True, default=None)
7575
join_date: datetime.date
7676
age: Optional[PositiveInt] = Field(index=True, default=None)
77-
bio: Optional[str] = Field(index=True, full_text_search=True, default="")
77+
bio: Optional[str] = Field(full_text_search=True, default="")
7878

7979
# Creates an embedded model.
8080
address: Address
@@ -1316,10 +1316,21 @@ class Model(JsonModel, index=True):
13161316

13171317
with pytest.raises(RedisModelError):
13181318

1319-
class Child(Model):
1319+
class Child(Model, index=True):
13201320
pass
13211321

13221322

1323+
@py_test_mark_asyncio
1324+
async def test_model_inherited_from_indexed_model():
1325+
class Model(JsonModel, index=True):
1326+
name: str = "Steve"
1327+
1328+
class Child(Model):
1329+
pass
1330+
1331+
assert issubclass(Child, Model)
1332+
1333+
13231334
@py_test_mark_asyncio
13241335
async def test_non_indexed_model_raises_error_on_save():
13251336
class Model(JsonModel):

0 commit comments

Comments
 (0)