|
16 | 16 |
|
17 | 17 | import os
|
18 | 18 | import json
|
19 |
| -from errno import EAGAIN, ECONNRESET |
| 19 | +from binascii import b2a_base64 |
| 20 | +import hashlib |
| 21 | +from errno import EAGAIN, ECONNRESET, ETIMEDOUT, ENOTCONN |
20 | 22 |
|
21 | 23 | from .exceptions import (
|
22 | 24 | BackslashInPathError,
|
|
25 | 27 | )
|
26 | 28 | from .mime_types import MIMETypes
|
27 | 29 | from .request import Request
|
28 |
| -from .status import Status, OK_200, TEMPORARY_REDIRECT_307, PERMANENT_REDIRECT_308 |
| 30 | +from .status import ( |
| 31 | + Status, |
| 32 | + SWITCHING_PROTOCOLS_101, |
| 33 | + OK_200, |
| 34 | + TEMPORARY_REDIRECT_307, |
| 35 | + PERMANENT_REDIRECT_308, |
| 36 | +) |
29 | 37 | from .headers import Headers
|
30 | 38 |
|
31 | 39 |
|
@@ -497,3 +505,276 @@ def close(self):
|
497 | 505 | """
|
498 | 506 | self._send_bytes(self._request.connection, b"event: close\n")
|
499 | 507 | self._close_connection()
|
| 508 | + |
| 509 | + |
| 510 | +class Websocket(Response): # pylint: disable=too-few-public-methods |
| 511 | + """ |
| 512 | + Specialized version of `Response` class for creating a websocket connection. |
| 513 | +
|
| 514 | + Allows two way communication between the client and the server. |
| 515 | +
|
| 516 | + Keep in mind, that in order to send and receive messages, the socket must be kept open. |
| 517 | + This means that you have to store the response object somewhere, so you can send events |
| 518 | + to it and close it later. |
| 519 | +
|
| 520 | + **It is very important to close the connection manually, it will not be done automatically.** |
| 521 | +
|
| 522 | + Example:: |
| 523 | +
|
| 524 | + ws = None |
| 525 | +
|
| 526 | + @server.route(path, method) |
| 527 | + def route_func(request: Request): |
| 528 | +
|
| 529 | + # Store the response object somewhere in global scope |
| 530 | + global ws |
| 531 | + ws = Websocket(request) |
| 532 | +
|
| 533 | + return ws |
| 534 | +
|
| 535 | + ... |
| 536 | +
|
| 537 | + # Receive message from client |
| 538 | + message = ws.receive() |
| 539 | +
|
| 540 | + # Later, when you want to send an event |
| 541 | + ws.send_message("Simple message") |
| 542 | +
|
| 543 | + # Close the connection |
| 544 | + ws.close() |
| 545 | + """ |
| 546 | + |
| 547 | + GUID = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11" |
| 548 | + FIN = 0b10000000 # FIN bit indicating the final fragment |
| 549 | + |
| 550 | + # opcodes |
| 551 | + CONT = 0 # Continuation frame, TODO: Currently not supported |
| 552 | + TEXT = 1 # Frame contains UTF-8 text |
| 553 | + BINARY = 2 # Frame contains binary data |
| 554 | + CLOSE = 8 # Frame closes the connection |
| 555 | + PING = 9 # Frame is a ping, expecting a pong |
| 556 | + PONG = 10 # Frame is a pong, in response to a ping |
| 557 | + |
| 558 | + @staticmethod |
| 559 | + def _check_request_initiates_handshake(request: Request): |
| 560 | + if any( |
| 561 | + [ |
| 562 | + "websocket" not in request.headers.get("Upgrade", "").lower(), |
| 563 | + "upgrade" not in request.headers.get("Connection", "").lower(), |
| 564 | + "Sec-WebSocket-Key" not in request.headers, |
| 565 | + ] |
| 566 | + ): |
| 567 | + raise ValueError("Request does not initiate websocket handshake") |
| 568 | + |
| 569 | + @staticmethod |
| 570 | + def _process_sec_websocket_key(request: Request) -> str: |
| 571 | + key = request.headers.get("Sec-WebSocket-Key") |
| 572 | + |
| 573 | + if key is None: |
| 574 | + raise ValueError("Request does not have Sec-WebSocket-Key header") |
| 575 | + |
| 576 | + response_key = hashlib.new('sha1', key.encode()) |
| 577 | + response_key.update(Websocket.GUID) |
| 578 | + |
| 579 | + return b2a_base64(response_key.digest()).strip().decode() |
| 580 | + |
| 581 | + def __init__( # pylint: disable=too-many-arguments |
| 582 | + self, |
| 583 | + request: Request, |
| 584 | + headers: Union[Headers, Dict[str, str]] = None, |
| 585 | + buffer_size: int = 1024, |
| 586 | + ) -> None: |
| 587 | + """ |
| 588 | + :param Request request: Request object |
| 589 | + :param Headers headers: Headers to be sent with the response. |
| 590 | + :param int buffer_size: Size of the buffer used to send and receive messages. |
| 591 | + """ |
| 592 | + self._check_request_initiates_handshake(request) |
| 593 | + |
| 594 | + sec_accept_key = self._process_sec_websocket_key(request) |
| 595 | + |
| 596 | + super().__init__( |
| 597 | + request=request, |
| 598 | + status=SWITCHING_PROTOCOLS_101, |
| 599 | + headers=headers, |
| 600 | + ) |
| 601 | + self._headers.setdefault("Upgrade", "websocket") |
| 602 | + self._headers.setdefault("Connection", "Upgrade") |
| 603 | + self._headers.setdefault("Sec-WebSocket-Accept", sec_accept_key) |
| 604 | + self._headers.setdefault("Content-Type", None) |
| 605 | + self._buffer_size = buffer_size |
| 606 | + self.closed = False |
| 607 | + |
| 608 | + request.connection.setblocking(False) |
| 609 | + |
| 610 | + |
| 611 | + @staticmethod |
| 612 | + def _parse_frame_header(header): |
| 613 | + fin = header[0] & Websocket.FIN |
| 614 | + opcode = header[0] & 0b00001111 |
| 615 | + has_mask = header[1] & 0b10000000 |
| 616 | + length = header[1] & 0b01111111 |
| 617 | + |
| 618 | + if length == 0b01111110: |
| 619 | + length = -2 |
| 620 | + elif length == 0b01111111: |
| 621 | + length = -8 |
| 622 | + |
| 623 | + return fin, opcode, has_mask, length |
| 624 | + |
| 625 | + def _read_frame(self): |
| 626 | + buffer = bytearray(self._buffer_size) |
| 627 | + |
| 628 | + header_length = self._request.connection.recv_into(buffer, 2) |
| 629 | + header_bytes = buffer[:header_length] |
| 630 | + |
| 631 | + fin, opcode, has_mask, length = self._parse_frame_header(header_bytes) |
| 632 | + |
| 633 | + # TODO: Handle continuation frames, currently not supported |
| 634 | + if fin != Websocket.FIN and opcode == Websocket.CONT: |
| 635 | + return Websocket.CONT, None |
| 636 | + |
| 637 | + payload = bytes() |
| 638 | + if fin == Websocket.FIN and opcode == Websocket.CLOSE: |
| 639 | + return Websocket.CLOSE, payload |
| 640 | + |
| 641 | + if length < 0: |
| 642 | + length = self._request.connection.recv_into(buffer, -length) |
| 643 | + length = int.from_bytes(buffer[:length], 'big') |
| 644 | + |
| 645 | + if has_mask: |
| 646 | + mask_length = self._request.connection.recv_into(buffer, 4) |
| 647 | + mask = buffer[:mask_length] |
| 648 | + |
| 649 | + while 0 < length: |
| 650 | + payload_length = self._request.connection.recv_into(buffer, length) |
| 651 | + payload += buffer[:min(payload_length, length)] |
| 652 | + length -= min(payload_length, length) |
| 653 | + |
| 654 | + if has_mask: |
| 655 | + payload = bytes(x ^ mask[i % 4] for i, x in enumerate(payload)) |
| 656 | + |
| 657 | + return opcode, payload |
| 658 | + |
| 659 | + def _handle_frame(self, opcode: int, payload: bytes): |
| 660 | + # TODO: Handle continuation frames, currently not supported |
| 661 | + if opcode == Websocket.CONT: |
| 662 | + return None |
| 663 | + |
| 664 | + if opcode == Websocket.CLOSE: |
| 665 | + self.close() |
| 666 | + return None |
| 667 | + |
| 668 | + if opcode == Websocket.PONG: |
| 669 | + return None |
| 670 | + elif opcode == Websocket.PING: |
| 671 | + self.send_message(payload, Websocket.PONG) |
| 672 | + return payload |
| 673 | + |
| 674 | + try: |
| 675 | + payload = payload.decode() if opcode == Websocket.TEXT else payload |
| 676 | + except UnicodeError as error: |
| 677 | + print("Payload UnicodeError: ", error, payload) |
| 678 | + pass |
| 679 | + |
| 680 | + return payload |
| 681 | + |
| 682 | + def receive(self, fail_silently: bool = False) -> Union[str, bytes, None]: |
| 683 | + """ |
| 684 | + Receive a message from the client. |
| 685 | +
|
| 686 | + :param bool fail_silently: If True, no error will be raised if the connection is closed. |
| 687 | + """ |
| 688 | + if self.closed: |
| 689 | + if fail_silently: |
| 690 | + return None |
| 691 | + raise RuntimeError("Websocket connection is closed, cannot receive messages") |
| 692 | + |
| 693 | + try: |
| 694 | + opcode, payload = self._read_frame() |
| 695 | + frame_data = self._handle_frame(opcode, payload) |
| 696 | + |
| 697 | + return frame_data |
| 698 | + except OSError as error: |
| 699 | + if error.errno == EAGAIN: # No messages available |
| 700 | + return None |
| 701 | + if error.errno == ETIMEDOUT: # Connection timed out |
| 702 | + return None |
| 703 | + if error.errno == ENOTCONN: # Client disconnected without closing connection |
| 704 | + self.close() |
| 705 | + return None |
| 706 | + raise error |
| 707 | + |
| 708 | + @staticmethod |
| 709 | + def _prepare_frame(opcode: int, message: bytes) -> bytearray: |
| 710 | + frame = bytearray() |
| 711 | + |
| 712 | + frame.append(Websocket.FIN | opcode) # Setting FIN bit |
| 713 | + |
| 714 | + payload_length = len(message) |
| 715 | + |
| 716 | + # Message under 126 bytes, use 1 byte for length |
| 717 | + if payload_length < 126: |
| 718 | + frame.append(payload_length) |
| 719 | + |
| 720 | + # Message between 126 and 65535 bytes, use 2 bytes for length |
| 721 | + elif payload_length < 65536: |
| 722 | + frame.append(126) |
| 723 | + frame.extend(payload_length.to_bytes(2, 'big')) |
| 724 | + |
| 725 | + # Message over 65535 bytes, use 8 bytes for length |
| 726 | + else: |
| 727 | + frame.append(127) |
| 728 | + frame.extend(payload_length.to_bytes(8, 'big')) |
| 729 | + |
| 730 | + frame.extend(message) |
| 731 | + return frame |
| 732 | + |
| 733 | + def send_message( |
| 734 | + self, |
| 735 | + message: Union[str, bytes], |
| 736 | + opcode: int = None, |
| 737 | + fail_silently: bool = False |
| 738 | + ): |
| 739 | + """ |
| 740 | + Send a message to the client. |
| 741 | +
|
| 742 | + :param str message: Message to be sent. |
| 743 | + :param int opcode: Opcode of the message. Defaults to TEXT if message is a string and |
| 744 | + BINARY for bytes. |
| 745 | + :param bool fail_silently: If True, no error will be raised if the connection is closed. |
| 746 | + """ |
| 747 | + if self.closed: |
| 748 | + if fail_silently: |
| 749 | + return None |
| 750 | + raise RuntimeError("Websocket connection is closed, cannot send message") |
| 751 | + |
| 752 | + determined_opcode = opcode or ( |
| 753 | + Websocket.TEXT if isinstance(message, str) else Websocket.BINARY |
| 754 | + ) |
| 755 | + |
| 756 | + if determined_opcode == Websocket.TEXT: |
| 757 | + message = message.encode() |
| 758 | + |
| 759 | + frame = self._prepare_frame(determined_opcode, message) |
| 760 | + |
| 761 | + try: |
| 762 | + self._send_bytes(self._request.connection, frame) |
| 763 | + except BrokenPipeError as error: |
| 764 | + if fail_silently: |
| 765 | + return None |
| 766 | + raise error |
| 767 | + |
| 768 | + def _send(self) -> None: |
| 769 | + self._send_headers() |
| 770 | + |
| 771 | + def close(self): |
| 772 | + """ |
| 773 | + Close the connection. |
| 774 | +
|
| 775 | + **Always call this method when you are done sending events.** |
| 776 | + """ |
| 777 | + if not self.closed: |
| 778 | + self.send_message(b'', Websocket.CLOSE, fail_silently=True) |
| 779 | + self._close_connection() |
| 780 | + self.closed = True |
0 commit comments