Skip to content

Commit 5390215

Browse files
authored
Merge pull request #54 from tannewt/test_read_before_response
Don't trust send works. Do one recv before creating Response
2 parents 531e845 + af04194 commit 5390215

File tree

3 files changed

+62
-8
lines changed

3 files changed

+62
-8
lines changed

adafruit_requests.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,10 @@ class _SendFailed(Exception):
8888
"""Custom exception to abort sending a request."""
8989

9090

91+
class OutOfRetries(Exception):
92+
"""Raised when requests has retried to make a request unsuccessfully."""
93+
94+
9195
class Response:
9296
"""The response from a request, contains all the headers/content"""
9397

@@ -570,13 +574,27 @@ def request(
570574
while retry_count < 2:
571575
retry_count += 1
572576
socket = self._get_socket(host, port, proto, timeout=timeout)
577+
ok = True
573578
try:
574579
self._send_request(socket, host, method, path, headers, data, json)
575-
break
576580
except _SendFailed:
577-
self._close_socket(socket)
578-
if retry_count > 1:
579-
raise
581+
ok = False
582+
if ok:
583+
# Read the H of "HTTP/1.1" to make sure the socket is alive. send can appear to work
584+
# even when the socket is closed.
585+
if hasattr(socket, "recv"):
586+
result = socket.recv(1)
587+
else:
588+
result = bytearray(1)
589+
socket.recv_into(result)
590+
if result == b"H":
591+
# Things seem to be ok so break with socket set.
592+
break
593+
self._close_socket(socket)
594+
socket = None
595+
596+
if not socket:
597+
raise OutOfRetries()
580598

581599
resp = Response(socket, self) # our response
582600
if "location" in resp.headers and 300 <= resp.status_code <= 399:

tests/legacy_test.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,13 +115,13 @@ def test_second_send_fails():
115115
def test_first_read_fails():
116116
mocket.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
117117
sock = mocket.Mocket(b"")
118+
sock2 = mocket.Mocket(headers + encoded)
118119
mocket.socket.call_count = 0 # Reset call count
119-
mocket.socket.side_effect = [sock]
120+
mocket.socket.side_effect = [sock, sock2]
120121

121122
adafruit_requests.set_socket(mocket, mocket.interface)
122123

123-
with pytest.raises(RuntimeError):
124-
r = adafruit_requests.get("http://" + host + "/testwifi/index.html")
124+
r = adafruit_requests.get("http://" + host + "/testwifi/index.html")
125125

126126
sock.send.assert_has_calls(
127127
[mock.call(b"testwifi/index.html"),]
@@ -131,10 +131,15 @@ def test_first_read_fails():
131131
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
132132
)
133133

134+
sock2.send.assert_has_calls(
135+
[mock.call(b"Host: "), mock.call(host.encode("utf-8")), mock.call(b"\r\n"),]
136+
)
137+
134138
sock.connect.assert_called_once_with((ip, 80))
139+
sock2.connect.assert_called_once_with((ip, 80))
135140
# Make sure that the socket is closed after the first receive fails.
136141
sock.close.assert_called_once()
137-
assert mocket.socket.call_count == 1
142+
assert mocket.socket.call_count == 2
138143

139144

140145
def test_second_tls_connect_fails():

tests/reuse_test.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,34 @@ def test_second_send_fails():
170170
sock.close.assert_called_once()
171171
assert sock2.close.call_count == 0
172172
assert pool.socket.call_count == 2
173+
174+
175+
def test_second_send_lies_recv_fails():
176+
pool = mocket.MocketPool()
177+
pool.getaddrinfo.return_value = ((None, None, None, None, (ip, 80)),)
178+
sock = mocket.Mocket(response)
179+
sock2 = mocket.Mocket(response)
180+
pool.socket.side_effect = [sock, sock2]
181+
182+
ssl = mocket.SSLContext()
183+
184+
s = adafruit_requests.Session(pool, ssl)
185+
r = s.get("https://" + host + path)
186+
187+
sock.send.assert_has_calls(
188+
[mock.call(b"testwifi/index.html"),]
189+
)
190+
191+
sock.send.assert_has_calls(
192+
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"), mock.call(b"\r\n"),]
193+
)
194+
assert r.text == str(text, "utf-8")
195+
196+
s.get("https://" + host + path + "2")
197+
198+
sock.connect.assert_called_once_with((host, 443))
199+
sock2.connect.assert_called_once_with((host, 443))
200+
# Make sure that the socket is closed after send fails.
201+
sock.close.assert_called_once()
202+
assert sock2.close.call_count == 0
203+
assert pool.socket.call_count == 2

0 commit comments

Comments
 (0)