|
1 | 1 | from django.conf.urls import url
|
2 | 2 | from django.contrib.auth.models import User
|
| 3 | +from django.http import HttpRequest |
3 | 4 | from django.test import override_settings
|
4 | 5 |
|
5 | 6 | from rest_framework.authentication import TokenAuthentication
|
6 | 7 | from rest_framework.authtoken.models import Token
|
| 8 | +from rest_framework.request import is_form_media_type |
| 9 | +from rest_framework.response import Response |
7 | 10 | from rest_framework.test import APITestCase
|
8 | 11 | from rest_framework.views import APIView
|
9 | 12 |
|
| 13 | + |
| 14 | +class PostView(APIView): |
| 15 | + def post(self, request): |
| 16 | + return Response(data=request.data, status=200) |
| 17 | + |
| 18 | + |
10 | 19 | urlpatterns = [
|
11 |
| - url(r'^$', APIView.as_view(authentication_classes=(TokenAuthentication,))), |
| 20 | + url(r'^auth$', APIView.as_view(authentication_classes=(TokenAuthentication,))), |
| 21 | + url(r'^post$', PostView.as_view()), |
12 | 22 | ]
|
13 | 23 |
|
14 | 24 |
|
15 |
| -class MyMiddleware(object): |
| 25 | +class RequestUserMiddleware(object): |
| 26 | + def __init__(self, get_response): |
| 27 | + self.get_response = get_response |
16 | 28 |
|
17 |
| - def process_response(self, request, response): |
| 29 | + def __call__(self, request): |
| 30 | + response = self.get_response(request) |
18 | 31 | assert hasattr(request, 'user'), '`user` is not set on request'
|
19 |
| - assert request.user.is_authenticated(), '`user` is not authenticated' |
| 32 | + assert request.user.is_authenticated, '`user` is not authenticated' |
| 33 | + |
| 34 | + return response |
| 35 | + |
| 36 | + |
| 37 | +class RequestPOSTMiddleware(object): |
| 38 | + def __init__(self, get_response): |
| 39 | + self.get_response = get_response |
| 40 | + |
| 41 | + def __call__(self, request): |
| 42 | + assert isinstance(request, HttpRequest) |
| 43 | + |
| 44 | + # Parse body with underlying Django request |
| 45 | + request.body |
| 46 | + |
| 47 | + # Process request with DRF view |
| 48 | + response = self.get_response(request) |
| 49 | + |
| 50 | + # Ensure request.POST is set as appropriate |
| 51 | + if is_form_media_type(request.content_type): |
| 52 | + assert request.POST == {'foo': ['bar']} |
| 53 | + else: |
| 54 | + assert request.POST == {} |
| 55 | + |
20 | 56 | return response
|
21 | 57 |
|
22 | 58 |
|
23 | 59 | @override_settings(ROOT_URLCONF='tests.test_middleware')
|
24 | 60 | class TestMiddleware(APITestCase):
|
| 61 | + |
| 62 | + @override_settings(MIDDLEWARE=('tests.test_middleware.RequestUserMiddleware',)) |
25 | 63 | def test_middleware_can_access_user_when_processing_response(self):
|
26 | 64 | user = User. objects. create_user( 'john', '[email protected]', 'password')
|
27 | 65 | key = 'abcd1234'
|
28 | 66 | Token.objects.create(key=key, user=user)
|
29 | 67 |
|
30 |
| - with self.settings( |
31 |
| - MIDDLEWARE_CLASSES=('tests.test_middleware.MyMiddleware',) |
32 |
| - ): |
33 |
| - auth = 'Token ' + key |
34 |
| - self.client.get('/', HTTP_AUTHORIZATION=auth) |
| 68 | + self.client.get('/auth', HTTP_AUTHORIZATION='Token %s' % key) |
| 69 | + |
| 70 | + @override_settings(MIDDLEWARE=('tests.test_middleware.RequestPOSTMiddleware',)) |
| 71 | + def test_middleware_can_access_request_post_when_processing_response(self): |
| 72 | + response = self.client.post('/post', {'foo': 'bar'}) |
| 73 | + assert response.status_code == 200 |
| 74 | + |
| 75 | + response = self.client.post('/post', {'foo': 'bar'}, format='json') |
| 76 | + assert response.status_code == 200 |
0 commit comments