Skip to content

Fix socket.close() behaviour #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Mar 20, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions adafruit_wiznet5k/adafruit_wiznet5k_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def __init__(
"""
if family != AF_INET:
raise RuntimeError("Only AF_INET family supported by W5K modules.")
self._socket_closed = False
self._sock_type = type
self._buffer = b""
self._timeout = _default_socket_timeout
Expand All @@ -251,6 +252,17 @@ def __exit__(self, exc_type, exc_val, exc_tb) -> None:
if time.monotonic() - stamp > 1000:
raise RuntimeError("Failed to close socket")

# This works around problems with using a class method as a decorator.
def _check_socket_closed(func): # pylint: disable=no-self-argument
"""Decorator to check whether the socket object has been closed."""

def wrapper(self, *args, **kwargs):
if self._socket_closed: # pylint: disable=protected-access
raise RuntimeError("The socket has been closed.")
return func(self, *args, **kwargs) # pylint: disable=not-callable

return wrapper

@property
def _status(self) -> int:
"""
Expand Down Expand Up @@ -288,6 +300,7 @@ def _connected(self) -> bool:
self.close()
return result

@_check_socket_closed
def getpeername(self) -> Tuple[str, int]:
"""
Return the remote address to which the socket is connected.
Expand All @@ -298,6 +311,7 @@ def getpeername(self) -> Tuple[str, int]:
self._socknum
)

@_check_socket_closed
def bind(self, address: Tuple[Optional[str], int]) -> None:
"""
Bind the socket to address. The socket must not already be bound.
Expand Down Expand Up @@ -343,6 +357,7 @@ def _bind(self, address: Tuple[Optional[str], int]) -> None:
)
self._buffer = b""

@_check_socket_closed
def listen(self, backlog: int = 0) -> None:
"""
Enable a server to accept connections.
Expand All @@ -354,6 +369,7 @@ def listen(self, backlog: int = 0) -> None:
_the_interface.socket_listen(self._socknum, self._listen_port)
self._buffer = b""

@_check_socket_closed
def accept(
self,
) -> Tuple[socket, Tuple[str, int]]:
Expand Down Expand Up @@ -388,6 +404,7 @@ def accept(
raise RuntimeError("Failed to open new listening socket")
return client_sock, addr

@_check_socket_closed
def connect(self, address: Tuple[str, int]) -> None:
"""
Connect to a remote socket at address.
Expand All @@ -407,6 +424,7 @@ def connect(self, address: Tuple[str, int]) -> None:
raise RuntimeError("Failed to connect to host ", address[0])
self._buffer = b""

@_check_socket_closed
def send(self, data: Union[bytes, bytearray]) -> int:
"""
Send data to the socket. The socket must be connected to a remote socket.
Expand All @@ -422,6 +440,7 @@ def send(self, data: Union[bytes, bytearray]) -> int:
gc.collect()
return bytes_sent

@_check_socket_closed
def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
"""
Send data to the socket. The socket should not be connected to a remote socket, since the
Expand All @@ -445,6 +464,7 @@ def sendto(self, data: bytearray, *flags_and_or_address: any) -> int:
self.connect(address)
return self.send(data)

@_check_socket_closed
def recv(
# pylint: disable=too-many-branches
self,
Expand Down Expand Up @@ -500,6 +520,7 @@ def _embed_recv(
gc.collect()
return ret

@_check_socket_closed
def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]]:
"""
Receive data from the socket. The return value is a pair (bytes, address) where bytes is
Expand All @@ -520,6 +541,7 @@ def recvfrom(self, bufsize: int, flags: int = 0) -> Tuple[bytes, Tuple[str, int]
),
)

@_check_socket_closed
def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
"""
Receive up to nbytes bytes from the socket, storing the data into a buffer
Expand All @@ -538,6 +560,7 @@ def recv_into(self, buffer: bytearray, nbytes: int = 0, flags: int = 0) -> int:
buffer[:nbytes] = bytes_received
return nbytes

@_check_socket_closed
def recvfrom_into(
self, buffer: bytearray, nbytes: int = 0, flags: int = 0
) -> Tuple[int, Tuple[str, int]]:
Expand Down Expand Up @@ -596,11 +619,13 @@ def _disconnect(self) -> None:
raise RuntimeError("Socket must be a TCP socket.")
_the_interface.socket_disconnect(self._socknum)

@_check_socket_closed
def close(self) -> None:
"""
Mark the socket closed. Once that happens, all future operations on the socket object
will fail. The remote end will receive no more data.
"""
self._socket_closed = True
_the_interface.socket_close(self._socknum)

def _available(self) -> int:
Expand All @@ -611,6 +636,7 @@ def _available(self) -> int:
"""
return _the_interface.socket_available(self._socknum, self._sock_type)

@_check_socket_closed
def settimeout(self, value: Optional[float]) -> None:
"""
Set a timeout on blocking socket operations. The value argument can be a
Expand All @@ -627,6 +653,7 @@ def settimeout(self, value: Optional[float]) -> None:
else:
raise ValueError("Timeout must be None, 0.0 or a positive numeric value.")

@_check_socket_closed
def gettimeout(self) -> Optional[float]:
"""
Return the timeout in seconds (float) associated with socket operations, or None if no
Expand All @@ -636,6 +663,7 @@ def gettimeout(self) -> Optional[float]:
"""
return self._timeout

@_check_socket_closed
def setblocking(self, flag: bool) -> None:
"""
Set blocking or non-blocking mode of the socket: if flag is false, the socket is set
Expand All @@ -658,6 +686,7 @@ def setblocking(self, flag: bool) -> None:
else:
raise TypeError("Flag must be a boolean.")

@_check_socket_closed
def getblocking(self) -> bool:
"""
Return True if socket is in blocking mode, False if in non-blocking.
Expand All @@ -669,16 +698,19 @@ def getblocking(self) -> bool:
return self.gettimeout() == 0

@property
@_check_socket_closed
def family(self) -> int:
"""Socket family (always 0x03 in this implementation)."""
return 3

@property
@_check_socket_closed
def type(self):
"""Socket type."""
return self._sock_type

@property
@_check_socket_closed
def proto(self):
"""Socket protocol (always 0x00 in this implementation)."""
return 0