Skip to content

Fix DHCP socket leak #122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 40 additions & 31 deletions adafruit_wiznet5k/adafruit_wiznet5k.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@
_MR_RST = const(0x80) # Mode Register RST
# Socket mode register
_SNMR_CLOSE = const(0x00)
SNMR_TCP = const(0x21)
_SNMR_TCP = const(0x21)
SNMR_UDP = const(0x02)
_SNMR_IPRAW = const(0x03)
_SNMR_MACRAW = const(0x04)
Expand Down Expand Up @@ -492,7 +492,7 @@ def ifconfig(

# *** Public Socket Methods ***

def socket_available(self, socket_num: int, sock_type: int = SNMR_TCP) -> int:
def socket_available(self, socket_num: int, sock_type: int = _SNMR_TCP) -> int:
"""
Number of bytes available to be read from the socket.

Expand All @@ -514,7 +514,7 @@ def socket_available(self, socket_num: int, sock_type: int = SNMR_TCP) -> int:
self._sock_num_in_range(socket_num)

number_of_bytes = self._get_rx_rcv_size(socket_num)
if self.read_snsr(socket_num) == SNMR_UDP:
if self._read_snsr(socket_num) == SNMR_UDP:
number_of_bytes -= 8 # Subtract UDP header from packet size.
if number_of_bytes < 0:
raise ValueError("Negative number of bytes found on socket.")
Expand All @@ -533,14 +533,14 @@ def socket_status(self, socket_num: int) -> int:

:return int: The connection status.
"""
return self.read_snsr(socket_num)
return self._read_snsr(socket_num)

def socket_connect(
self,
socket_num: int,
dest: IpAddress4Raw,
port: int,
conn_mode: int = SNMR_TCP,
conn_mode: int = _SNMR_TCP,
) -> int:
"""
Open and verify a connection from a socket to a destination IPv4 address
Expand All @@ -567,11 +567,11 @@ def socket_connect(
# initialize a socket and set the mode
self.socket_open(socket_num, conn_mode=conn_mode)
# set socket destination IP and port
self.write_sndipr(socket_num, dest)
self.write_sndport(socket_num, port)
self.write_sncr(socket_num, _CMD_SOCK_CONNECT)
self._write_sndipr(socket_num, dest)
self._write_sndport(socket_num, port)
self._write_sncr(socket_num, _CMD_SOCK_CONNECT)

if conn_mode == SNMR_TCP:
if conn_mode == _SNMR_TCP:
# wait for tcp connection establishment
while self.socket_status(socket_num) != SNSR_SOCK_ESTABLISHED:
time.sleep(0.001)
Expand Down Expand Up @@ -638,7 +638,7 @@ def release_socket(self, socket_number):
WIZNET5K._sockets_reserved[socket_number - 1] = False

def socket_listen(
self, socket_num: int, port: int, conn_mode: int = SNMR_TCP
self, socket_num: int, port: int, conn_mode: int = _SNMR_TCP
) -> None:
"""
Listen on a socket's port.
Expand All @@ -665,15 +665,15 @@ def socket_listen(
self.socket_open(socket_num, conn_mode=conn_mode)
self.src_port = 0
# Send listen command
self.write_sncr(socket_num, _CMD_SOCK_LISTEN)
self._write_sncr(socket_num, _CMD_SOCK_LISTEN)
# Wait until ready
status = SNSR_SOCK_CLOSED
while status not in (
SNSR_SOCK_LISTEN,
SNSR_SOCK_ESTABLISHED,
_SNSR_SOCK_UDP,
):
status = self.read_snsr(socket_num)
status = self._read_snsr(socket_num)
if status == SNSR_SOCK_CLOSED:
raise RuntimeError("Listening socket closed.")

Expand Down Expand Up @@ -703,7 +703,7 @@ def socket_accept(self, socket_num: int) -> Tuple[int, Tuple[str, int]]:
)
return next_socknum, (dest_ip, dest_port)

def socket_open(self, socket_num: int, conn_mode: int = SNMR_TCP) -> None:
def socket_open(self, socket_num: int, conn_mode: int = _SNMR_TCP) -> None:
"""
Open an IP socket.

Expand All @@ -720,7 +720,7 @@ def socket_open(self, socket_num: int, conn_mode: int = SNMR_TCP) -> None:
self._sock_num_in_range(socket_num)
self._check_link_status()
debug_msg("*** Opening socket {}".format(socket_num), self._debug)
if self.read_snsr(socket_num) not in (
if self._read_snsr(socket_num) not in (
SNSR_SOCK_CLOSED,
SNSR_SOCK_TIME_WAIT,
SNSR_SOCK_FIN_WAIT,
Expand All @@ -732,22 +732,22 @@ def socket_open(self, socket_num: int, conn_mode: int = SNMR_TCP) -> None:
debug_msg("* Opening W5k Socket, protocol={}".format(conn_mode), self._debug)
time.sleep(0.00025)

self.write_snmr(socket_num, conn_mode)
self._write_snmr(socket_num, conn_mode)
self.write_snir(socket_num, 0xFF)

if self.src_port > 0:
# write to socket source port
self.write_sock_port(socket_num, self.src_port)
self._write_sock_port(socket_num, self.src_port)
else:
s_port = randint(49152, 65535)
while s_port in self._src_ports_in_use:
s_port = randint(49152, 65535)
self.write_sock_port(socket_num, s_port)
self._write_sock_port(socket_num, s_port)
self._src_ports_in_use[socket_num] = s_port

# open socket
self.write_sncr(socket_num, _CMD_SOCK_OPEN)
if self.read_snsr(socket_num) not in [_SNSR_SOCK_INIT, _SNSR_SOCK_UDP]:
self._write_sncr(socket_num, _CMD_SOCK_OPEN)
if self._read_snsr(socket_num) not in [_SNSR_SOCK_INIT, _SNSR_SOCK_UDP]:
raise RuntimeError("Could not open socket in TCP or UDP mode.")

def socket_close(self, socket_num: int) -> None:
Expand All @@ -760,14 +760,14 @@ def socket_close(self, socket_num: int) -> None:
"""
debug_msg("*** Closing socket {}".format(socket_num), self._debug)
self._sock_num_in_range(socket_num)
self.write_sncr(socket_num, _CMD_SOCK_CLOSE)
self._write_sncr(socket_num, _CMD_SOCK_CLOSE)
debug_msg(" Waiting for socket to close…", self._debug)
timeout = time.monotonic() + 5.0
while self.read_snsr(socket_num) != SNSR_SOCK_CLOSED:
while self._read_snsr(socket_num) != SNSR_SOCK_CLOSED:
if time.monotonic() > timeout:
raise RuntimeError(
"Wiznet5k failed to close socket, status = {}.".format(
self.read_snsr(socket_num)
self._read_snsr(socket_num)
)
)
time.sleep(0.0001)
Expand All @@ -783,7 +783,7 @@ def socket_disconnect(self, socket_num: int) -> None:
"""
debug_msg("*** Disconnecting socket {}".format(socket_num), self._debug)
self._sock_num_in_range(socket_num)
self.write_sncr(socket_num, _CMD_SOCK_DISCON)
self._write_sncr(socket_num, _CMD_SOCK_DISCON)

def socket_read(self, socket_num: int, length: int) -> Tuple[int, bytes]:
"""
Expand Down Expand Up @@ -819,7 +819,7 @@ def socket_read(self, socket_num: int, length: int) -> Tuple[int, bytes]:
# After reading the received data, update Sn_RX_RD register.
pointer = (pointer + bytes_on_socket) & 0xFFFF
self._write_snrx_rd(socket_num, pointer)
self.write_sncr(socket_num, _CMD_SOCK_RECV)
self._write_sncr(socket_num, _CMD_SOCK_RECV)
else:
# no data on socket
if self._read_snmr(socket_num) in (
Expand Down Expand Up @@ -906,7 +906,7 @@ def socket_write(
# update sn_tx_wr to the value + data size
pointer = (pointer + bytes_to_write) & 0xFFFF
self._write_sntx_wr(socket_num, pointer)
self.write_sncr(socket_num, _CMD_SOCK_SEND)
self._write_sncr(socket_num, _CMD_SOCK_SEND)

# check data was transferred correctly
while not self.read_snir(socket_num) & _SNIR_SEND_OK:
Expand Down Expand Up @@ -1057,6 +1057,11 @@ def _check_link_status(self):
if not self.link_status:
raise ConnectionError("The Ethernet connection is down.")

@staticmethod
def _read_socket_reservations() -> list[int]:
"""Return the list of reserved sockets."""
return WIZNET5K._sockets_reserved

def _read_mr(self) -> int:
"""Read from the Mode Register (MR)."""
return int.from_bytes(self._read(_REG_MR[self._chip_type], 0x00), "big")
Expand Down Expand Up @@ -1175,18 +1180,22 @@ def _read_sndipr(self, sock) -> bytes:
)
return bytes(data)

def write_sndipr(self, sock: int, ip_addr: bytes) -> None:
def _write_sndipr(self, sock: int, ip_addr: bytes) -> None:
"""Write to socket destination IP Address."""
for offset, value in enumerate(ip_addr):
self._write_socket_register(
sock, _REG_SNDIPR[self._chip_type] + offset, value
)

def write_sndport(self, sock: int, port: int) -> None:
def _read_sndport(self, sock: int) -> int:
"""Read socket destination port."""
return self._read_two_byte_sock_reg(sock, _REG_SNDPORT[self._chip_type])

def _write_sndport(self, sock: int, port: int) -> None:
"""Write to socket destination port."""
self._write_two_byte_sock_reg(sock, _REG_SNDPORT[self._chip_type], port)

def read_snsr(self, sock: int) -> int:
def _read_snsr(self, sock: int) -> int:
"""Read Socket n Status Register."""
return self._read_socket_register(sock, _REG_SNSR[self._chip_type])

Expand All @@ -1202,15 +1211,15 @@ def _read_snmr(self, sock: int) -> int:
"""Read the socket MR register."""
return self._read_socket_register(sock, _REG_SNMR)

def write_snmr(self, sock: int, protocol: int) -> None:
def _write_snmr(self, sock: int, protocol: int) -> None:
"""Write to Socket n Mode Register."""
self._write_socket_register(sock, _REG_SNMR, protocol)

def write_sock_port(self, sock: int, port: int) -> None:
def _write_sock_port(self, sock: int, port: int) -> None:
"""Write to the socket port number."""
self._write_two_byte_sock_reg(sock, _REG_SNPORT[self._chip_type], port)

def write_sncr(self, sock: int, data: int) -> None:
def _write_sncr(self, sock: int, data: int) -> None:
"""Write to socket command register."""
self._write_socket_register(sock, _REG_SNCR[self._chip_type], data)
# Wait for command to complete before continuing.
Expand Down
Loading