Skip to content

Commit 1e1ad58

Browse files
committed
Added Websocket class and SWITCHING_PROTOCOLS_101
1 parent ebb7ca7 commit 1e1ad58

File tree

3 files changed

+287
-2
lines changed

3 files changed

+287
-2
lines changed

adafruit_httpserver/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,13 @@
5959
JSONResponse,
6060
Redirect,
6161
SSEResponse,
62+
Websocket,
6263
)
6364
from .route import Route
6465
from .server import Server
6566
from .status import (
6667
Status,
68+
SWITCHING_PROTOCOLS_101,
6769
OK_200,
6870
CREATED_201,
6971
ACCEPTED_202,

adafruit_httpserver/response.py

Lines changed: 283 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616

1717
import os
1818
import json
19-
from errno import EAGAIN, ECONNRESET
19+
from binascii import b2a_base64
20+
import hashlib
21+
from errno import EAGAIN, ECONNRESET, ETIMEDOUT, ENOTCONN
2022

2123
from .exceptions import (
2224
BackslashInPathError,
@@ -25,7 +27,13 @@
2527
)
2628
from .mime_types import MIMETypes
2729
from .request import Request
28-
from .status import Status, OK_200, TEMPORARY_REDIRECT_307, PERMANENT_REDIRECT_308
30+
from .status import (
31+
Status,
32+
SWITCHING_PROTOCOLS_101,
33+
OK_200,
34+
TEMPORARY_REDIRECT_307,
35+
PERMANENT_REDIRECT_308,
36+
)
2937
from .headers import Headers
3038

3139

@@ -497,3 +505,276 @@ def close(self):
497505
"""
498506
self._send_bytes(self._request.connection, b"event: close\n")
499507
self._close_connection()
508+
509+
510+
class Websocket(Response): # pylint: disable=too-few-public-methods
511+
"""
512+
Specialized version of `Response` class for creating a websocket connection.
513+
514+
Allows two way communication between the client and the server.
515+
516+
Keep in mind, that in order to send and receive messages, the socket must be kept open.
517+
This means that you have to store the response object somewhere, so you can send events
518+
to it and close it later.
519+
520+
**It is very important to close the connection manually, it will not be done automatically.**
521+
522+
Example::
523+
524+
ws = None
525+
526+
@server.route(path, method)
527+
def route_func(request: Request):
528+
529+
# Store the response object somewhere in global scope
530+
global ws
531+
ws = Websocket(request)
532+
533+
return ws
534+
535+
...
536+
537+
# Receive message from client
538+
message = ws.receive()
539+
540+
# Later, when you want to send an event
541+
ws.send_message("Simple message")
542+
543+
# Close the connection
544+
ws.close()
545+
"""
546+
547+
GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
548+
FIN = 0b10000000 # FIN bit indicating the final fragment
549+
550+
# opcodes
551+
CONT = 0 # Continuation frame, TODO: Currently not supported
552+
TEXT = 1 # Frame contains UTF-8 text
553+
BINARY = 2 # Frame contains binary data
554+
CLOSE = 8 # Frame closes the connection
555+
PING = 9 # Frame is a ping, expecting a pong
556+
PONG = 10 # Frame is a pong, in response to a ping
557+
558+
@staticmethod
559+
def _check_request_initiates_handshake(request: Request):
560+
if any(
561+
[
562+
"websocket" not in request.headers.get("Upgrade", "").lower(),
563+
"upgrade" not in request.headers.get("Connection", "").lower(),
564+
"Sec-WebSocket-Key" not in request.headers,
565+
]
566+
):
567+
raise ValueError("Request does not initiate websocket handshake")
568+
569+
@staticmethod
570+
def _process_sec_websocket_key(request: Request) -> str:
571+
key = request.headers.get("Sec-WebSocket-Key")
572+
573+
if key is None:
574+
raise ValueError("Request does not have Sec-WebSocket-Key header")
575+
576+
response_key = hashlib.new('sha1', key.encode())
577+
response_key.update(Websocket.GUID)
578+
579+
return b2a_base64(response_key.digest()).strip().decode()
580+
581+
def __init__( # pylint: disable=too-many-arguments
582+
self,
583+
request: Request,
584+
headers: Union[Headers, Dict[str, str]] = None,
585+
buffer_size: int = 1024,
586+
) -> None:
587+
"""
588+
:param Request request: Request object
589+
:param Headers headers: Headers to be sent with the response.
590+
:param int buffer_size: Size of the buffer used to send and receive messages.
591+
"""
592+
self._check_request_initiates_handshake(request)
593+
594+
sec_accept_key = self._process_sec_websocket_key(request)
595+
596+
super().__init__(
597+
request=request,
598+
status=SWITCHING_PROTOCOLS_101,
599+
headers=headers,
600+
)
601+
self._headers.setdefault("Upgrade", "websocket")
602+
self._headers.setdefault("Connection", "Upgrade")
603+
self._headers.setdefault("Sec-WebSocket-Accept", sec_accept_key)
604+
self._headers.setdefault("Content-Type", None)
605+
self._buffer_size = buffer_size
606+
self.closed = False
607+
608+
request.connection.setblocking(False)
609+
610+
611+
@staticmethod
612+
def _parse_frame_header(header):
613+
fin = header[0] & Websocket.FIN
614+
opcode = header[0] & 0b00001111
615+
has_mask = header[1] & 0b10000000
616+
length = header[1] & 0b01111111
617+
618+
if length == 0b01111110:
619+
length = -2
620+
elif length == 0b01111111:
621+
length = -8
622+
623+
return fin, opcode, has_mask, length
624+
625+
def _read_frame(self):
626+
buffer = bytearray(self._buffer_size)
627+
628+
header_length = self._request.connection.recv_into(buffer, 2)
629+
header_bytes = buffer[:header_length]
630+
631+
fin, opcode, has_mask, length = self._parse_frame_header(header_bytes)
632+
633+
# TODO: Handle continuation frames, currently not supported
634+
if fin != Websocket.FIN and opcode == Websocket.CONT:
635+
return Websocket.CONT, None
636+
637+
payload = bytes()
638+
if fin == Websocket.FIN and opcode == Websocket.CLOSE:
639+
return Websocket.CLOSE, payload
640+
641+
if length < 0:
642+
length = self._request.connection.recv_into(buffer, -length)
643+
length = int.from_bytes(buffer[:length], 'big')
644+
645+
if has_mask:
646+
mask_length = self._request.connection.recv_into(buffer, 4)
647+
mask = buffer[:mask_length]
648+
649+
while 0 < length:
650+
payload_length = self._request.connection.recv_into(buffer, length)
651+
payload += buffer[:min(payload_length, length)]
652+
length -= min(payload_length, length)
653+
654+
if has_mask:
655+
payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload))
656+
657+
return opcode, payload
658+
659+
def _handle_frame(self, opcode: int, payload: bytes):
660+
# TODO: Handle continuation frames, currently not supported
661+
if opcode == Websocket.CONT:
662+
return None
663+
664+
if opcode == Websocket.CLOSE:
665+
self.close()
666+
return None
667+
668+
if opcode == Websocket.PONG:
669+
return None
670+
elif opcode == Websocket.PING:
671+
self.send_message(payload, Websocket.PONG)
672+
return payload
673+
674+
try:
675+
payload = payload.decode() if opcode == Websocket.TEXT else payload
676+
except UnicodeError as error:
677+
print("Payload UnicodeError: ", error, payload)
678+
pass
679+
680+
return payload
681+
682+
def receive(self, fail_silently: bool = False) -> Union[str, bytes, None]:
683+
"""
684+
Receive a message from the client.
685+
686+
:param bool fail_silently: If True, no error will be raised if the connection is closed.
687+
"""
688+
if self.closed:
689+
if fail_silently:
690+
return None
691+
raise RuntimeError("Websocket connection is closed, cannot receive messages")
692+
693+
try:
694+
opcode, payload = self._read_frame()
695+
frame_data = self._handle_frame(opcode, payload)
696+
697+
return frame_data
698+
except OSError as error:
699+
if error.errno == EAGAIN: # No messages available
700+
return None
701+
if error.errno == ETIMEDOUT: # Connection timed out
702+
return None
703+
if error.errno == ENOTCONN: # Client disconnected without closing connection
704+
self.close()
705+
return None
706+
raise error
707+
708+
@staticmethod
709+
def _prepare_frame(opcode: int, message: bytes) -> bytearray:
710+
frame = bytearray()
711+
712+
frame.append(Websocket.FIN | opcode) # Setting FIN bit
713+
714+
payload_length = len(message)
715+
716+
# Message under 126 bytes, use 1 byte for length
717+
if payload_length < 126:
718+
frame.append(payload_length)
719+
720+
# Message between 126 and 65535 bytes, use 2 bytes for length
721+
elif payload_length < 65536:
722+
frame.append(126)
723+
frame.extend(payload_length.to_bytes(2, 'big'))
724+
725+
# Message over 65535 bytes, use 8 bytes for length
726+
else:
727+
frame.append(127)
728+
frame.extend(payload_length.to_bytes(8, 'big'))
729+
730+
frame.extend(message)
731+
return frame
732+
733+
def send_message(
734+
self,
735+
message: Union[str, bytes],
736+
opcode: int = None,
737+
fail_silently: bool = False
738+
):
739+
"""
740+
Send a message to the client.
741+
742+
:param str message: Message to be sent.
743+
:param int opcode: Opcode of the message. Defaults to TEXT if message is a string and
744+
BINARY for bytes.
745+
:param bool fail_silently: If True, no error will be raised if the connection is closed.
746+
"""
747+
if self.closed:
748+
if fail_silently:
749+
return None
750+
raise RuntimeError("Websocket connection is closed, cannot send message")
751+
752+
determined_opcode = opcode or (
753+
Websocket.TEXT if isinstance(message, str) else Websocket.BINARY
754+
)
755+
756+
if determined_opcode == Websocket.TEXT:
757+
message = message.encode()
758+
759+
frame = self._prepare_frame(determined_opcode, message)
760+
761+
try:
762+
self._send_bytes(self._request.connection, frame)
763+
except BrokenPipeError as error:
764+
if fail_silently:
765+
return None
766+
raise error
767+
768+
def _send(self) -> None:
769+
self._send_headers()
770+
771+
def close(self):
772+
"""
773+
Close the connection.
774+
775+
**Always call this method when you are done sending events.**
776+
"""
777+
if not self.closed:
778+
self.send_message(b'', Websocket.CLOSE, fail_silently=True)
779+
self._close_connection()
780+
self.closed = True

adafruit_httpserver/status.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def __eq__(self, other: "Status"):
3131
return self.code == other.code and self.text == other.text
3232

3333

34+
SWITCHING_PROTOCOLS_101 = Status(101, "Switching Protocols")
35+
3436
OK_200 = Status(200, "OK")
3537

3638
CREATED_201 = Status(201, "Created")

0 commit comments

Comments
 (0)