Skip to content

Commit df957c8

Browse files
committed
Fix and tests for ScopedRateThrottle. Closes #935
1 parent 6cc4fe5 commit df957c8

File tree

2 files changed

+121
-5
lines changed

2 files changed

+121
-5
lines changed

rest_framework/tests/test_throttling.py

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from django.core.cache import cache
88
from django.test.client import RequestFactory
99
from rest_framework.views import APIView
10-
from rest_framework.throttling import UserRateThrottle
10+
from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle
1111
from rest_framework.response import Response
1212

1313

@@ -36,8 +36,6 @@ def get(self, request):
3636

3737

3838
class ThrottlingTests(TestCase):
39-
urls = 'rest_framework.tests.test_throttling'
40-
4139
def setUp(self):
4240
"""
4341
Reset the cache so that no throttles will be active
@@ -141,3 +139,108 @@ def test_next_rate_remains_constant_if_followed(self):
141139
(60, None),
142140
(80, None)
143141
))
142+
143+
144+
class ScopedRateThrottleTests(TestCase):
145+
"""
146+
Tests for ScopedRateThrottle.
147+
"""
148+
149+
def setUp(self):
150+
class XYScopedRateThrottle(ScopedRateThrottle):
151+
TIMER_SECONDS = 0
152+
THROTTLE_RATES = {'x': '3/min', 'y': '1/min'}
153+
timer = lambda self: self.TIMER_SECONDS
154+
155+
class XView(APIView):
156+
throttle_classes = (XYScopedRateThrottle,)
157+
throttle_scope = 'x'
158+
159+
def get(self, request):
160+
return Response('x')
161+
162+
class YView(APIView):
163+
throttle_classes = (XYScopedRateThrottle,)
164+
throttle_scope = 'y'
165+
166+
def get(self, request):
167+
return Response('y')
168+
169+
class UnscopedView(APIView):
170+
throttle_classes = (XYScopedRateThrottle,)
171+
172+
def get(self, request):
173+
return Response('y')
174+
175+
self.throttle_class = XYScopedRateThrottle
176+
self.factory = RequestFactory()
177+
self.x_view = XView.as_view()
178+
self.y_view = YView.as_view()
179+
self.unscoped_view = UnscopedView.as_view()
180+
181+
def increment_timer(self, seconds=1):
182+
self.throttle_class.TIMER_SECONDS += seconds
183+
184+
def test_scoped_rate_throttle(self):
185+
request = self.factory.get('/')
186+
187+
# Should be able to hit x view 3 times per minute.
188+
response = self.x_view(request)
189+
self.assertEqual(200, response.status_code)
190+
191+
self.increment_timer()
192+
response = self.x_view(request)
193+
self.assertEqual(200, response.status_code)
194+
195+
self.increment_timer()
196+
response = self.x_view(request)
197+
self.assertEqual(200, response.status_code)
198+
199+
self.increment_timer()
200+
response = self.x_view(request)
201+
self.assertEqual(429, response.status_code)
202+
203+
# Should be able to hit y view 1 time per minute.
204+
self.increment_timer()
205+
response = self.y_view(request)
206+
self.assertEqual(200, response.status_code)
207+
208+
self.increment_timer()
209+
response = self.y_view(request)
210+
self.assertEqual(429, response.status_code)
211+
212+
# Ensure throttles properly reset by advancing the rest of the minute
213+
self.increment_timer(55)
214+
215+
# Should still be able to hit x view 3 times per minute.
216+
response = self.x_view(request)
217+
self.assertEqual(200, response.status_code)
218+
219+
self.increment_timer()
220+
response = self.x_view(request)
221+
self.assertEqual(200, response.status_code)
222+
223+
self.increment_timer()
224+
response = self.x_view(request)
225+
self.assertEqual(200, response.status_code)
226+
227+
self.increment_timer()
228+
response = self.x_view(request)
229+
self.assertEqual(429, response.status_code)
230+
231+
# Should still be able to hit y view 1 time per minute.
232+
self.increment_timer()
233+
response = self.y_view(request)
234+
self.assertEqual(200, response.status_code)
235+
236+
self.increment_timer()
237+
response = self.y_view(request)
238+
self.assertEqual(429, response.status_code)
239+
240+
def test_unscoped_view_not_throttled(self):
241+
request = self.factory.get('/')
242+
243+
for idx in range(10):
244+
self.increment_timer()
245+
response = self.unscoped_view(request)
246+
self.assertEqual(200, response.status_code)

rest_framework/throttling.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ class SimpleRateThrottle(BaseThrottle):
4040
"""
4141

4242
timer = time.time
43-
settings = api_settings
4443
cache_format = 'throtte_%(scope)s_%(ident)s'
4544
scope = None
45+
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
4646

4747
def __init__(self):
4848
if not getattr(self, 'rate', None):
@@ -68,7 +68,7 @@ def get_rate(self):
6868
raise ImproperlyConfigured(msg)
6969

7070
try:
71-
return self.settings.DEFAULT_THROTTLE_RATES[self.scope]
71+
return self.THROTTLE_RATES[self.scope]
7272
except KeyError:
7373
msg = "No default throttle rate set for '%s' scope" % self.scope
7474
raise ImproperlyConfigured(msg)
@@ -187,6 +187,19 @@ class ScopedRateThrottle(SimpleRateThrottle):
187187
"""
188188
scope_attr = 'throttle_scope'
189189

190+
def __init__(self):
191+
pass
192+
193+
def allow_request(self, request, view):
194+
self.scope = getattr(view, self.scope_attr, None)
195+
196+
if not self.scope:
197+
return True
198+
199+
self.rate = self.get_rate()
200+
self.num_requests, self.duration = self.parse_rate(self.rate)
201+
return super(ScopedRateThrottle, self).allow_request(request, view)
202+
190203
def get_cache_key(self, request, view):
191204
"""
192205
If `view.throttle_scope` is not set, don't apply this throttle.

0 commit comments

Comments
 (0)