1
1
import asyncio
2
+ import dataclasses
2
3
import http
3
4
import logging
4
5
import socket
@@ -117,8 +118,8 @@ def select_subprotocol(ws, subprotocols):
117
118
"server rejected WebSocket connection: HTTP 500" ,
118
119
)
119
120
120
- async def test_process_request (self ):
121
- """Server runs process_request before processing the handshake."""
121
+ async def test_process_request_returns_none (self ):
122
+ """Server runs process_request and continues the handshake."""
122
123
123
124
def process_request (ws , request ):
124
125
self .assertIsInstance (request , Request )
@@ -128,8 +129,8 @@ def process_request(ws, request):
128
129
async with run_client (server ) as client :
129
130
await self .assertEval (client , "ws.process_request_ran" , "True" )
130
131
131
- async def test_async_process_request (self ):
132
- """Server runs async process_request before processing the handshake."""
132
+ async def test_async_process_request_returns_none (self ):
133
+ """Server runs async process_request and continues the handshake."""
133
134
134
135
async def process_request (ws , request ):
135
136
self .assertIsInstance (request , Request )
@@ -139,7 +140,7 @@ async def process_request(ws, request):
139
140
async with run_client (server ) as client :
140
141
await self .assertEval (client , "ws.process_request_ran" , "True" )
141
142
142
- async def test_process_request_abort_handshake (self ):
143
+ async def test_process_request_returns_response (self ):
143
144
"""Server aborts handshake if process_request returns a response."""
144
145
145
146
def process_request (ws , request ):
@@ -154,7 +155,7 @@ def process_request(ws, request):
154
155
"server rejected WebSocket connection: HTTP 403" ,
155
156
)
156
157
157
- async def test_async_process_request_abort_handshake (self ):
158
+ async def test_async_process_request_returns_response (self ):
158
159
"""Server aborts handshake if async process_request returns a response."""
159
160
160
161
async def process_request (ws , request ):
@@ -199,8 +200,8 @@ async def process_request(ws, request):
199
200
"server rejected WebSocket connection: HTTP 500" ,
200
201
)
201
202
202
- async def test_process_response (self ):
203
- """Server runs process_response after processing the handshake."""
203
+ async def test_process_response_returns_none (self ):
204
+ """Server runs process_response but keeps the handshake response ."""
204
205
205
206
def process_response (ws , request , response ):
206
207
self .assertIsInstance (request , Request )
@@ -211,8 +212,8 @@ def process_response(ws, request, response):
211
212
async with run_client (server ) as client :
212
213
await self .assertEval (client , "ws.process_response_ran" , "True" )
213
214
214
- async def test_async_process_response (self ):
215
- """Server runs async process_response after processing the handshake."""
215
+ async def test_async_process_response_returns_none (self ):
216
+ """Server runs async process_response but keeps the handshake response ."""
216
217
217
218
async def process_response (ws , request , response ):
218
219
self .assertIsInstance (request , Request )
@@ -223,29 +224,49 @@ async def process_response(ws, request, response):
223
224
async with run_client (server ) as client :
224
225
await self .assertEval (client , "ws.process_response_ran" , "True" )
225
226
226
- async def test_process_response_override_response (self ):
227
- """Server runs process_response and overrides the handshake response."""
227
+ async def test_process_response_modifies_response (self ):
228
+ """Server runs process_response and modifies the handshake response."""
228
229
229
230
def process_response (ws , request , response ):
230
- response .headers ["X-ProcessResponse-Ran " ] = "true "
231
+ response .headers ["X-ProcessResponse" ] = "OK "
231
232
232
233
async with run_server (process_response = process_response ) as server :
233
234
async with run_client (server ) as client :
234
- self .assertEqual (
235
- client .response .headers ["X-ProcessResponse-Ran" ], "true"
236
- )
235
+ self .assertEqual (client .response .headers ["X-ProcessResponse" ], "OK" )
237
236
238
- async def test_async_process_response_override_response (self ):
239
- """Server runs async process_response and overrides the handshake response."""
237
+ async def test_async_process_response_modifies_response (self ):
238
+ """Server runs async process_response and modifies the handshake response."""
240
239
241
240
async def process_response (ws , request , response ):
242
- response .headers ["X-ProcessResponse-Ran " ] = "true "
241
+ response .headers ["X-ProcessResponse" ] = "OK "
243
242
244
243
async with run_server (process_response = process_response ) as server :
245
244
async with run_client (server ) as client :
246
- self .assertEqual (
247
- client .response .headers ["X-ProcessResponse-Ran" ], "true"
248
- )
245
+ self .assertEqual (client .response .headers ["X-ProcessResponse" ], "OK" )
246
+
247
+ async def test_process_response_replaces_response (self ):
248
+ """Server runs process_response and replaces the handshake response."""
249
+
250
+ def process_response (ws , request , response ):
251
+ headers = response .headers .copy ()
252
+ headers ["X-ProcessResponse" ] = "OK"
253
+ return dataclasses .replace (response , headers = headers )
254
+
255
+ async with run_server (process_response = process_response ) as server :
256
+ async with run_client (server ) as client :
257
+ self .assertEqual (client .response .headers ["X-ProcessResponse" ], "OK" )
258
+
259
+ async def test_async_process_response_replaces_response (self ):
260
+ """Server runs async process_response and replaces the handshake response."""
261
+
262
+ async def process_response (ws , request , response ):
263
+ headers = response .headers .copy ()
264
+ headers ["X-ProcessResponse" ] = "OK"
265
+ return dataclasses .replace (response , headers = headers )
266
+
267
+ async with run_server (process_response = process_response ) as server :
268
+ async with run_client (server ) as client :
269
+ self .assertEqual (client .response .headers ["X-ProcessResponse" ], "OK" )
249
270
250
271
async def test_process_response_raises_exception (self ):
251
272
"""Server returns an error if process_response raises an exception."""
0 commit comments