Skip to content

Commit 313c36f

Browse files
committed
Merge pull request #2242 from tomchristie/hyperlinked-pk-optimization
Hyperlinked PK optimization.
2 parents 8ad0b83 + 1e336ef commit 313c36f

File tree

4 files changed

+67
-35
lines changed

4 files changed

+67
-35
lines changed

rest_framework/relations.py

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,20 @@ def get_queryset(self):
8484
queryset = queryset.all()
8585
return queryset
8686

87-
def get_iterable(self, instance, source_attrs):
88-
relationship = get_attribute(instance, source_attrs)
89-
return relationship.all() if (hasattr(relationship, 'all')) else relationship
87+
def use_pk_only_optimization(self):
88+
return False
89+
90+
def get_attribute(self, instance):
91+
if self.use_pk_only_optimization() and self.source_attrs:
92+
# Optimized case, return a mock object only containing the pk attribute.
93+
try:
94+
instance = get_attribute(instance, self.source_attrs[:-1])
95+
return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1]))
96+
except AttributeError:
97+
pass
98+
99+
# Standard case, return the object instance.
100+
return get_attribute(instance, self.source_attrs)
90101

91102
@property
92103
def choices(self):
@@ -120,6 +131,9 @@ class PrimaryKeyRelatedField(RelatedField):
120131
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
121132
}
122133

134+
def use_pk_only_optimization(self):
135+
return True
136+
123137
def to_internal_value(self, data):
124138
try:
125139
return self.get_queryset().get(pk=data)
@@ -128,32 +142,6 @@ def to_internal_value(self, data):
128142
except (TypeError, ValueError):
129143
self.fail('incorrect_type', data_type=type(data).__name__)
130144

131-
def get_attribute(self, instance):
132-
# We customize `get_attribute` here for performance reasons.
133-
# For relationships the instance will already have the pk of
134-
# the related object. We return this directly instead of returning the
135-
# object itself, which would require a database lookup.
136-
try:
137-
instance = get_attribute(instance, self.source_attrs[:-1])
138-
return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1]))
139-
except AttributeError:
140-
return get_attribute(instance, self.source_attrs)
141-
142-
def get_iterable(self, instance, source_attrs):
143-
# For consistency with `get_attribute` we're using `serializable_value()`
144-
# here. Typically there won't be any difference, but some custom field
145-
# types might return a non-primitive value for the pk otherwise.
146-
#
147-
# We could try to get smart with `values_list('pk', flat=True)`, which
148-
# would be better in some case, but would actually end up with *more*
149-
# queries if the developer is using `prefetch_related` across the
150-
# relationship.
151-
relationship = super(PrimaryKeyRelatedField, self).get_iterable(instance, source_attrs)
152-
return [
153-
PKOnlyObject(pk=item.serializable_value('pk'))
154-
for item in relationship
155-
]
156-
157145
def to_representation(self, value):
158146
return value.pk
159147

@@ -184,6 +172,9 @@ def __init__(self, view_name=None, **kwargs):
184172

185173
super(HyperlinkedRelatedField, self).__init__(**kwargs)
186174

175+
def use_pk_only_optimization(self):
176+
return self.lookup_field == 'pk'
177+
187178
def get_object(self, view_name, view_args, view_kwargs):
188179
"""
189180
Return the object corresponding to a matched URL.
@@ -285,6 +276,11 @@ def __init__(self, view_name=None, **kwargs):
285276
kwargs['source'] = '*'
286277
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
287278

279+
def use_pk_only_optimization(self):
280+
# We have the complete object instance already. We don't need
281+
# to run the 'only get the pk for this relationship' code.
282+
return False
283+
288284

289285
class SlugRelatedField(RelatedField):
290286
"""
@@ -349,7 +345,8 @@ def to_internal_value(self, data):
349345
]
350346

351347
def get_attribute(self, instance):
352-
return self.child_relation.get_iterable(instance, self.source_attrs)
348+
relationship = get_attribute(instance, self.source_attrs)
349+
return relationship.all() if (hasattr(relationship, 'all')) else relationship
353350

354351
def to_representation(self, iterable):
355352
return [

tests/test_relations_hyperlink.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,14 @@ def test_many_to_many_retrieve(self):
8989
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
9090
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
9191
]
92-
self.assertEqual(serializer.data, expected)
92+
with self.assertNumQueries(4):
93+
self.assertEqual(serializer.data, expected)
94+
95+
def test_many_to_many_retrieve_prefetch_related(self):
96+
queryset = ManyToManySource.objects.all().prefetch_related('targets')
97+
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
98+
with self.assertNumQueries(2):
99+
serializer.data
93100

94101
def test_reverse_many_to_many_retrieve(self):
95102
queryset = ManyToManyTarget.objects.all()
@@ -99,7 +106,8 @@ def test_reverse_many_to_many_retrieve(self):
99106
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
100107
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
101108
]
102-
self.assertEqual(serializer.data, expected)
109+
with self.assertNumQueries(4):
110+
self.assertEqual(serializer.data, expected)
103111

104112
def test_many_to_many_update(self):
105113
data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
@@ -197,7 +205,8 @@ def test_foreign_key_retrieve(self):
197205
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
198206
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
199207
]
200-
self.assertEqual(serializer.data, expected)
208+
with self.assertNumQueries(1):
209+
self.assertEqual(serializer.data, expected)
201210

202211
def test_reverse_foreign_key_retrieve(self):
203212
queryset = ForeignKeyTarget.objects.all()
@@ -206,7 +215,8 @@ def test_reverse_foreign_key_retrieve(self):
206215
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
207216
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
208217
]
209-
self.assertEqual(serializer.data, expected)
218+
with self.assertNumQueries(3):
219+
self.assertEqual(serializer.data, expected)
210220

211221
def test_foreign_key_update(self):
212222
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}

tests/test_relations_pk.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,12 @@ def test_many_to_many_retrieve(self):
7171
with self.assertNumQueries(4):
7272
self.assertEqual(serializer.data, expected)
7373

74+
def test_many_to_many_retrieve_prefetch_related(self):
75+
queryset = ManyToManySource.objects.all().prefetch_related('targets')
76+
serializer = ManyToManySourceSerializer(queryset, many=True)
77+
with self.assertNumQueries(2):
78+
serializer.data
79+
7480
def test_reverse_many_to_many_retrieve(self):
7581
queryset = ManyToManyTarget.objects.all()
7682
serializer = ManyToManyTargetSerializer(queryset, many=True)
@@ -188,6 +194,12 @@ def test_reverse_foreign_key_retrieve(self):
188194
with self.assertNumQueries(3):
189195
self.assertEqual(serializer.data, expected)
190196

197+
def test_reverse_foreign_key_retrieve_prefetch_related(self):
198+
queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
199+
serializer = ForeignKeyTargetSerializer(queryset, many=True)
200+
with self.assertNumQueries(2):
201+
serializer.data
202+
191203
def test_foreign_key_update(self):
192204
data = {'id': 1, 'name': 'source-1', 'target': 2}
193205
instance = ForeignKeySource.objects.get(pk=1)

tests/test_relations_slug.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,14 @@ def test_foreign_key_retrieve(self):
5454
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
5555
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
5656
]
57-
self.assertEqual(serializer.data, expected)
57+
with self.assertNumQueries(4):
58+
self.assertEqual(serializer.data, expected)
59+
60+
def test_foreign_key_retrieve_select_related(self):
61+
queryset = ForeignKeySource.objects.all().select_related('target')
62+
serializer = ForeignKeySourceSerializer(queryset, many=True)
63+
with self.assertNumQueries(1):
64+
serializer.data
5865

5966
def test_reverse_foreign_key_retrieve(self):
6067
queryset = ForeignKeyTarget.objects.all()
@@ -65,6 +72,12 @@ def test_reverse_foreign_key_retrieve(self):
6572
]
6673
self.assertEqual(serializer.data, expected)
6774

75+
def test_reverse_foreign_key_retrieve_prefetch_related(self):
76+
queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
77+
serializer = ForeignKeyTargetSerializer(queryset, many=True)
78+
with self.assertNumQueries(2):
79+
serializer.data
80+
6881
def test_foreign_key_update(self):
6982
data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
7083
instance = ForeignKeySource.objects.get(pk=1)

0 commit comments

Comments
 (0)