Skip to content

Commit 7736f40

Browse files
committed
Align SearchFilter behaviour to django.contrib.admin
1 parent 4b747c6 commit 7736f40

File tree

3 files changed

+92
-17
lines changed

3 files changed

+92
-17
lines changed

docs/api-guide/filtering.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,12 @@ This will allow the client to filter the items in the list by making queries suc
213213
You can also perform a related lookup on a ForeignKey or ManyToManyField with the lookup API double-underscore notation:
214214

215215
search_fields = ['username', 'email', 'profile__profession']
216-
216+
217217
For [JSONField][JSONField] and [HStoreField][HStoreField] fields you can filter based on nested values within the data structure using the same double-underscore notation:
218218

219219
search_fields = ['data__breed', 'data__owner__other_pets__0__name']
220220

221-
By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched.
221+
By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched.
222222

223223
The search behavior may be restricted by prepending various characters to the `search_fields`.
224224

rest_framework/filters.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@
66
import warnings
77
from functools import reduce
88

9-
from django.core.exceptions import ImproperlyConfigured
9+
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
1010
from django.db import models
1111
from django.db.models.constants import LOOKUP_SEP
1212
from django.template import loader
1313
from django.utils.encoding import force_str
14+
from django.utils.text import smart_split, unescape_string_literal
1415
from django.utils.translation import gettext_lazy as _
1516

1617
from rest_framework import RemovedInDRF317Warning
1718
from rest_framework.compat import coreapi, coreschema
19+
from rest_framework.fields import CharField
1820
from rest_framework.settings import api_settings
1921

2022

@@ -64,18 +66,37 @@ def get_search_fields(self, view, request):
6466
def get_search_terms(self, request):
6567
"""
6668
Search terms are set by a ?search=... query parameter,
67-
and may be comma and/or whitespace delimited.
69+
and may be whitespace delimited.
6870
"""
69-
params = request.query_params.get(self.search_param, '')
70-
params = params.replace('\x00', '') # strip null characters
71-
params = params.replace(',', ' ')
72-
return params.split()
71+
value = request.query_params.get(self.search_param, '')
72+
field = CharField(trim_whitespace=False, allow_blank=True)
73+
return field.run_validation(value)
7374

74-
def construct_search(self, field_name):
75+
def construct_search(self, field_name, queryset):
7576
lookup = self.lookup_prefixes.get(field_name[0])
7677
if lookup:
7778
field_name = field_name[1:]
7879
else:
80+
# Use field_name if it includes a lookup.
81+
opts = queryset.model._meta
82+
lookup_fields = field_name.split(LOOKUP_SEP)
83+
# Go through the fields, following all relations.
84+
prev_field = None
85+
for path_part in lookup_fields:
86+
if path_part == "pk":
87+
path_part = opts.pk.name
88+
try:
89+
field = opts.get_field(path_part)
90+
except FieldDoesNotExist:
91+
# Use valid query lookups.
92+
if prev_field and prev_field.get_lookup(path_part):
93+
return field_name
94+
else:
95+
prev_field = field
96+
if hasattr(field, "path_infos"):
97+
# Update opts to follow the relation.
98+
opts = field.path_infos[-1].to_opts
99+
# Otherwise, use the field with icontains.
79100
lookup = 'icontains'
80101
return LOOKUP_SEP.join([field_name, lookup])
81102

@@ -113,15 +134,17 @@ def filter_queryset(self, request, queryset, view):
113134
return queryset
114135

115136
orm_lookups = [
116-
self.construct_search(str(search_field))
137+
self.construct_search(str(search_field), queryset)
117138
for search_field in search_fields
118139
]
119140

120141
base = queryset
121142
conditions = []
122-
for search_term in search_terms:
143+
for term in smart_split(search_terms):
144+
if term.startswith(('"', "'")) and term[0] == term[-1]:
145+
term = unescape_string_literal(term)
123146
queries = [
124-
models.Q(**{orm_lookup: search_term})
147+
models.Q(**{orm_lookup: term})
125148
for orm_lookup in orm_lookups
126149
]
127150
conditions.append(reduce(operator.or_, queries))
@@ -141,7 +164,7 @@ def to_html(self, request, queryset, view):
141164
return ''
142165

143166
term = self.get_search_terms(request)
144-
term = term[0] if term else ''
167+
term = next(term, '')
145168
context = {
146169
'param': self.search_param,
147170
'term': term

tests/test_filters.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from rest_framework import filters, generics, serializers
1313
from rest_framework.compat import coreschema
14+
from rest_framework.exceptions import ValidationError
1415
from rest_framework.test import APIRequestFactory
1516

1617
factory = APIRequestFactory()
@@ -50,7 +51,8 @@ class Meta:
5051

5152

5253
class SearchFilterTests(TestCase):
53-
def setUp(self):
54+
@classmethod
55+
def setUpTestData(cls):
5456
# Sequence of title/text is:
5557
#
5658
# z abc
@@ -66,6 +68,10 @@ def setUp(self):
6668
)
6769
SearchFilterModel(title=title, text=text).save()
6870

71+
72+
SearchFilterModel(title='A title', text='The long text').save()
73+
SearchFilterModel(title='The title', text='The "text').save()
74+
6975
def test_search(self):
7076
class SearchListView(generics.ListAPIView):
7177
queryset = SearchFilterModel.objects.all()
@@ -177,6 +183,7 @@ class SearchListView(generics.ListAPIView):
177183

178184
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'})
179185
response = view(request)
186+
print(response.data)
180187
assert response.data == [
181188
{'id': 3, 'title': 'zzz', 'text': 'cde'}
182189
]
@@ -186,9 +193,21 @@ def test_search_field_with_null_characters(self):
186193
request = factory.get('/?search=\0as%00d\x00f')
187194
request = view.initialize_request(request)
188195

189-
terms = filters.SearchFilter().get_search_terms(request)
196+
with self.assertRaises(ValidationError):
197+
filters.SearchFilter().get_search_terms(request)
190198

191-
assert terms == ['asdf']
199+
def test_search_field_with_custom_lookup(self):
200+
class SearchListView(generics.ListAPIView):
201+
queryset = SearchFilterModel.objects.all()
202+
serializer_class = SearchFilterSerializer
203+
filter_backends = (filters.SearchFilter,)
204+
search_fields = ('text__iendswith',)
205+
view = SearchListView.as_view()
206+
request = factory.get('/', {'search': 'c'})
207+
response = view(request)
208+
assert response.data == [
209+
{'id': 1, 'title': 'z', 'text': 'abc'},
210+
]
192211

193212
def test_search_field_with_additional_transforms(self):
194213
from django.test.utils import register_lookup
@@ -225,6 +244,32 @@ def as_sql(self, compiler, connection):
225244
{'id': 2, 'title': 'zz', 'text': 'bcd'},
226245
]
227246

247+
def test_search_field_with_escapes(self):
248+
class SearchListView(generics.ListAPIView):
249+
queryset = SearchFilterModel.objects.all()
250+
serializer_class = SearchFilterSerializer
251+
filter_backends = (filters.SearchFilter,)
252+
search_fields = ('title', 'text',)
253+
view = SearchListView.as_view()
254+
request = factory.get('/', {'search': '"\\\"text"'})
255+
response = view(request)
256+
assert response.data == [
257+
{'id': 12, 'title': 'The title', 'text': 'The "text'},
258+
]
259+
260+
def test_search_field_with_quotes(self):
261+
class SearchListView(generics.ListAPIView):
262+
queryset = SearchFilterModel.objects.all()
263+
serializer_class = SearchFilterSerializer
264+
filter_backends = (filters.SearchFilter,)
265+
search_fields = ('title', 'text',)
266+
view = SearchListView.as_view()
267+
request = factory.get('/', {'search': '"long text"'})
268+
response = view(request)
269+
assert response.data == [
270+
{'id': 11, 'title': 'A title', 'text': 'The long text'},
271+
]
272+
228273

229274
class AttributeModel(models.Model):
230275
label = models.CharField(max_length=32)
@@ -267,6 +312,13 @@ def test_must_call_distinct_restores_meta_for_each_field(self):
267312
["%sattribute__label" % prefix, "%stitle" % prefix]
268313
)
269314

315+
def test_custom_lookup_to_related_model(self):
316+
# In this test case the attribute of the fk model comes first in the
317+
# list of search fields.
318+
filter_ = filters.SearchFilter()
319+
assert 'attribute__label__icontains' == filter_.construct_search('attribute__label', SearchFilterModelFk._meta)
320+
assert 'attribute__label__iendswith' == filter_.construct_search('attribute__label__iendswith', SearchFilterModelFk._meta)
321+
270322

271323
class SearchFilterModelM2M(models.Model):
272324
title = models.CharField(max_length=20)
@@ -368,7 +420,7 @@ class SearchListView(generics.ListAPIView):
368420
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
369421

370422
view = SearchListView.as_view()
371-
request = factory.get('/', {'search': 'Lennon,1979'})
423+
request = factory.get('/', {'search': 'Lennon 1979'})
372424
response = view(request)
373425
assert len(response.data) == 1
374426

0 commit comments

Comments
 (0)