Skip to content

Commit 1f0f0ad

Browse files
committed
Align SearchFilter behaviour to django.contrib.admin
1 parent 4b747c6 commit 1f0f0ad

File tree

3 files changed

+91
-17
lines changed

3 files changed

+91
-17
lines changed

docs/api-guide/filtering.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,12 +213,12 @@ This will allow the client to filter the items in the list by making queries suc
213213
You can also perform a related lookup on a ForeignKey or ManyToManyField with the lookup API double-underscore notation:
214214

215215
search_fields = ['username', 'email', 'profile__profession']
216-
216+
217217
For [JSONField][JSONField] and [HStoreField][HStoreField] fields you can filter based on nested values within the data structure using the same double-underscore notation:
218218

219219
search_fields = ['data__breed', 'data__owner__other_pets__0__name']
220220

221-
By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace and/or comma separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched.
221+
By default, searches will use case-insensitive partial matches. The search parameter may contain multiple search terms, which should be whitespace separated. If multiple search terms are used then objects will be returned in the list only if all the provided terms are matched.
222222

223223
The search behavior may be restricted by prepending various characters to the `search_fields`.
224224

rest_framework/filters.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,17 @@
66
import warnings
77
from functools import reduce
88

9-
from django.core.exceptions import ImproperlyConfigured
9+
from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
1010
from django.db import models
1111
from django.db.models.constants import LOOKUP_SEP
1212
from django.template import loader
1313
from django.utils.encoding import force_str
14+
from django.utils.text import smart_split, unescape_string_literal
1415
from django.utils.translation import gettext_lazy as _
1516

1617
from rest_framework import RemovedInDRF317Warning
1718
from rest_framework.compat import coreapi, coreschema
19+
from rest_framework.fields import CharField
1820
from rest_framework.settings import api_settings
1921

2022

@@ -64,18 +66,37 @@ def get_search_fields(self, view, request):
6466
def get_search_terms(self, request):
6567
"""
6668
Search terms are set by a ?search=... query parameter,
67-
and may be comma and/or whitespace delimited.
69+
and may be whitespace delimited.
6870
"""
69-
params = request.query_params.get(self.search_param, '')
70-
params = params.replace('\x00', '') # strip null characters
71-
params = params.replace(',', ' ')
72-
return params.split()
71+
value = request.query_params.get(self.search_param, '')
72+
field = CharField(trim_whitespace=False, allow_blank=True)
73+
return field.run_validation(value)
7374

74-
def construct_search(self, field_name):
75+
def construct_search(self, field_name, queryset):
7576
lookup = self.lookup_prefixes.get(field_name[0])
7677
if lookup:
7778
field_name = field_name[1:]
7879
else:
80+
# Use field_name if it includes a lookup.
81+
opts = queryset.model._meta
82+
lookup_fields = field_name.split(LOOKUP_SEP)
83+
# Go through the fields, following all relations.
84+
prev_field = None
85+
for path_part in lookup_fields:
86+
if path_part == "pk":
87+
path_part = opts.pk.name
88+
try:
89+
field = opts.get_field(path_part)
90+
except FieldDoesNotExist:
91+
# Use valid query lookups.
92+
if prev_field and prev_field.get_lookup(path_part):
93+
return field_name
94+
else:
95+
prev_field = field
96+
if hasattr(field, "path_infos"):
97+
# Update opts to follow the relation.
98+
opts = field.path_infos[-1].to_opts
99+
# Otherwise, use the field with icontains.
79100
lookup = 'icontains'
80101
return LOOKUP_SEP.join([field_name, lookup])
81102

@@ -113,15 +134,17 @@ def filter_queryset(self, request, queryset, view):
113134
return queryset
114135

115136
orm_lookups = [
116-
self.construct_search(str(search_field))
137+
self.construct_search(str(search_field), queryset)
117138
for search_field in search_fields
118139
]
119140

120141
base = queryset
121142
conditions = []
122-
for search_term in search_terms:
143+
for term in smart_split(search_terms):
144+
if term.startswith(('"', "'")) and term[0] == term[-1]:
145+
term = unescape_string_literal(term)
123146
queries = [
124-
models.Q(**{orm_lookup: search_term})
147+
models.Q(**{orm_lookup: term})
125148
for orm_lookup in orm_lookups
126149
]
127150
conditions.append(reduce(operator.or_, queries))
@@ -141,7 +164,7 @@ def to_html(self, request, queryset, view):
141164
return ''
142165

143166
term = self.get_search_terms(request)
144-
term = term[0] if term else ''
167+
term = next(term, '')
145168
context = {
146169
'param': self.search_param,
147170
'term': term

tests/test_filters.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from rest_framework import filters, generics, serializers
1313
from rest_framework.compat import coreschema
14+
from rest_framework.exceptions import ValidationError
1415
from rest_framework.test import APIRequestFactory
1516

1617
factory = APIRequestFactory()
@@ -50,7 +51,8 @@ class Meta:
5051

5152

5253
class SearchFilterTests(TestCase):
53-
def setUp(self):
54+
@classmethod
55+
def setUpTestData(cls):
5456
# Sequence of title/text is:
5557
#
5658
# z abc
@@ -66,6 +68,9 @@ def setUp(self):
6668
)
6769
SearchFilterModel(title=title, text=text).save()
6870

71+
SearchFilterModel(title='A title', text='The long text').save()
72+
SearchFilterModel(title='The title', text='The "text').save()
73+
6974
def test_search(self):
7075
class SearchListView(generics.ListAPIView):
7176
queryset = SearchFilterModel.objects.all()
@@ -177,6 +182,7 @@ class SearchListView(generics.ListAPIView):
177182

178183
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'})
179184
response = view(request)
185+
print(response.data)
180186
assert response.data == [
181187
{'id': 3, 'title': 'zzz', 'text': 'cde'}
182188
]
@@ -186,9 +192,21 @@ def test_search_field_with_null_characters(self):
186192
request = factory.get('/?search=\0as%00d\x00f')
187193
request = view.initialize_request(request)
188194

189-
terms = filters.SearchFilter().get_search_terms(request)
195+
with self.assertRaises(ValidationError):
196+
filters.SearchFilter().get_search_terms(request)
190197

191-
assert terms == ['asdf']
198+
def test_search_field_with_custom_lookup(self):
199+
class SearchListView(generics.ListAPIView):
200+
queryset = SearchFilterModel.objects.all()
201+
serializer_class = SearchFilterSerializer
202+
filter_backends = (filters.SearchFilter,)
203+
search_fields = ('text__iendswith',)
204+
view = SearchListView.as_view()
205+
request = factory.get('/', {'search': 'c'})
206+
response = view(request)
207+
assert response.data == [
208+
{'id': 1, 'title': 'z', 'text': 'abc'},
209+
]
192210

193211
def test_search_field_with_additional_transforms(self):
194212
from django.test.utils import register_lookup
@@ -225,6 +243,32 @@ def as_sql(self, compiler, connection):
225243
{'id': 2, 'title': 'zz', 'text': 'bcd'},
226244
]
227245

246+
def test_search_field_with_escapes(self):
247+
class SearchListView(generics.ListAPIView):
248+
queryset = SearchFilterModel.objects.all()
249+
serializer_class = SearchFilterSerializer
250+
filter_backends = (filters.SearchFilter,)
251+
search_fields = ('title', 'text',)
252+
view = SearchListView.as_view()
253+
request = factory.get('/', {'search': '"\\\"text"'})
254+
response = view(request)
255+
assert response.data == [
256+
{'id': 12, 'title': 'The title', 'text': 'The "text'},
257+
]
258+
259+
def test_search_field_with_quotes(self):
260+
class SearchListView(generics.ListAPIView):
261+
queryset = SearchFilterModel.objects.all()
262+
serializer_class = SearchFilterSerializer
263+
filter_backends = (filters.SearchFilter,)
264+
search_fields = ('title', 'text',)
265+
view = SearchListView.as_view()
266+
request = factory.get('/', {'search': '"long text"'})
267+
response = view(request)
268+
assert response.data == [
269+
{'id': 11, 'title': 'A title', 'text': 'The long text'},
270+
]
271+
228272

229273
class AttributeModel(models.Model):
230274
label = models.CharField(max_length=32)
@@ -267,6 +311,13 @@ def test_must_call_distinct_restores_meta_for_each_field(self):
267311
["%sattribute__label" % prefix, "%stitle" % prefix]
268312
)
269313

314+
def test_custom_lookup_to_related_model(self):
315+
# In this test case the attribute of the fk model comes first in the
316+
# list of search fields.
317+
filter_ = filters.SearchFilter()
318+
assert 'attribute__label__icontains' == filter_.construct_search('attribute__label', SearchFilterModelFk._meta)
319+
assert 'attribute__label__iendswith' == filter_.construct_search('attribute__label__iendswith', SearchFilterModelFk._meta)
320+
270321

271322
class SearchFilterModelM2M(models.Model):
272323
title = models.CharField(max_length=20)
@@ -368,7 +419,7 @@ class SearchListView(generics.ListAPIView):
368419
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
369420

370421
view = SearchListView.as_view()
371-
request = factory.get('/', {'search': 'Lennon,1979'})
422+
request = factory.get('/', {'search': 'Lennon 1979'})
372423
response = view(request)
373424
assert len(response.data) == 1
374425

0 commit comments

Comments
 (0)