Skip to content

Check for send to return 0 #36

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 73 additions & 46 deletions adafruit_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,26 +377,28 @@ def __init__(self, socket_pool, ssl_context=None):
self._last_response = None

def _free_socket(self, socket):

if socket not in self._open_sockets.values():
raise RuntimeError("Socket not from session")
self._socket_free[socket] = True

def _close_socket(self, sock):
sock.close()
del self._socket_free[sock]
key = None
for k in self._open_sockets:
if self._open_sockets[k] == sock:
key = k
break
if key:
del self._open_sockets[key]

def _free_sockets(self):
free_sockets = []
for sock in self._socket_free:
if self._socket_free[sock]:
sock.close()
free_sockets.append(sock)
for sock in free_sockets:
del self._socket_free[sock]
key = None
for k in self._open_sockets:
if self._open_sockets[k] == sock:
key = k
break
if key:
del self._open_sockets[key]
self._close_socket(sock)

def _get_socket(self, host, port, proto, *, timeout=1):
key = (host, port, proto)
Expand Down Expand Up @@ -440,6 +442,61 @@ def _get_socket(self, host, port, proto, *, timeout=1):
self._socket_free[sock] = False
return sock

@staticmethod
def _send(socket, data):
total_sent = 0
while total_sent < len(data):
sent = socket.send(data[total_sent:])
if sent is None:
sent = len(data)
if sent == 0:
raise RuntimeError("Connection closed")
total_sent += sent

def _send_request(self, socket, host, method, path, headers, data, json):
# pylint: disable=too-many-arguments
self._send(socket, bytes(method, "utf-8"))
self._send(socket, b" /")
self._send(socket, bytes(path, "utf-8"))
self._send(socket, b" HTTP/1.1\r\n")
if "Host" not in headers:
self._send(socket, b"Host: ")
self._send(socket, bytes(host, "utf-8"))
self._send(socket, b"\r\n")
if "User-Agent" not in headers:
self._send(socket, b"User-Agent: Adafruit CircuitPython\r\n")
# Iterate over keys to avoid tuple alloc
for k in headers:
self._send(socket, k.encode())
self._send(socket, b": ")
self._send(socket, headers[k].encode())
self._send(socket, b"\r\n")
if json is not None:
assert data is None
# pylint: disable=import-outside-toplevel
try:
import json as json_module
except ImportError:
import ujson as json_module
data = json_module.dumps(json)
self._send(socket, b"Content-Type: application/json\r\n")
if data:
if isinstance(data, dict):
self._send(
socket, b"Content-Type: application/x-www-form-urlencoded\r\n"
)
_post_data = ""
for k in data:
_post_data = "{}&{}={}".format(_post_data, k, data[k])
data = _post_data[1:]
self._send(socket, b"Content-Length: %d\r\n" % len(data))
self._send(socket, b"\r\n")
if data:
if isinstance(data, bytearray):
self._send(socket, bytes(data))
else:
self._send(socket, bytes(data, "utf-8"))

# pylint: disable=too-many-branches, too-many-statements, unused-argument, too-many-arguments, too-many-locals
def request(
self, method, url, data=None, json=None, headers=None, stream=False, timeout=60
Expand Down Expand Up @@ -476,42 +533,11 @@ def request(
self._last_response = None

socket = self._get_socket(host, port, proto, timeout=timeout)
socket.send(
b"%s /%s HTTP/1.1\r\n" % (bytes(method, "utf-8"), bytes(path, "utf-8"))
)
if "Host" not in headers:
socket.send(b"Host: %s\r\n" % bytes(host, "utf-8"))
if "User-Agent" not in headers:
socket.send(b"User-Agent: Adafruit CircuitPython\r\n")
# Iterate over keys to avoid tuple alloc
for k in headers:
socket.send(k.encode())
socket.send(b": ")
socket.send(headers[k].encode())
socket.send(b"\r\n")
if json is not None:
assert data is None
# pylint: disable=import-outside-toplevel
try:
import json as json_module
except ImportError:
import ujson as json_module
data = json_module.dumps(json)
socket.send(b"Content-Type: application/json\r\n")
if data:
if isinstance(data, dict):
socket.send(b"Content-Type: application/x-www-form-urlencoded\r\n")
_post_data = ""
for k in data:
_post_data = "{}&{}={}".format(_post_data, k, data[k])
data = _post_data[1:]
socket.send(b"Content-Length: %d\r\n" % len(data))
socket.send(b"\r\n")
if data:
if isinstance(data, bytearray):
socket.send(bytes(data))
else:
socket.send(bytes(data, "utf-8"))
try:
self._send_request(socket, host, method, path, headers, data, json)
except:
self._close_socket(socket)
raise

resp = Response(socket, self) # our response
if "location" in resp.headers and 300 <= resp.status_code <= 399:
Expand Down Expand Up @@ -557,6 +583,7 @@ def __init__(self, socket, tls_mode):
self.settimeout = socket.settimeout
self.send = socket.send
self.recv = socket.recv
self.close = socket.close

def connect(self, address):
"""connect wrapper to add non-standard mode parameter"""
Expand Down
10 changes: 8 additions & 2 deletions tests/chunk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,16 @@ def test_get_text():
r = s.get("http://" + host + path)

sock.connect.assert_called_once_with((ip, 80))

sock.send.assert_has_calls(
[
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
mock.call(b"Host: wifitest.adafruit.com\r\n"),
mock.call(b"GET"),
mock.call(b" /"),
mock.call(b"testwifi/index.html"),
mock.call(b" HTTP/1.1\r\n"),
]
)
sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),]
)
assert r.text == str(text, "utf-8")
7 changes: 6 additions & 1 deletion tests/header_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@ def test_json():
sock = mocket.Mocket(response_headers)
pool.socket.return_value = sock
sent = []
sock.send.side_effect = sent.append

def _send(data):
sent.append(data)
return len(data)

sock.send.side_effect = _send

s = adafruit_requests.Session(pool)
headers = {"user-agent": "blinka/1.0.0"}
Expand Down
5 changes: 4 additions & 1 deletion tests/legacy_mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@ def __init__(self, response):
self.settimeout = mock.Mock()
self.close = mock.Mock()
self.connect = mock.Mock()
self.send = mock.Mock()
self.send = mock.Mock(side_effect=self._send)
self.readline = mock.Mock(side_effect=self._readline)
self.recv = mock.Mock(side_effect=self._recv)
self._response = response
self._position = 0

def _send(self, data):
return len(data)

def _readline(self):
i = self._response.find(b"\r\n", self._position)
r = self._response[self._position : i + 2]
Expand Down
5 changes: 4 additions & 1 deletion tests/mocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,16 @@ def __init__(self, response):
self.settimeout = mock.Mock()
self.close = mock.Mock()
self.connect = mock.Mock()
self.send = mock.Mock()
self.send = mock.Mock(side_effect=self._send)
self.readline = mock.Mock(side_effect=self._readline)
self.recv = mock.Mock(side_effect=self._recv)
self.recv_into = mock.Mock(side_effect=self._recv_into)
self._response = response
self._position = 0

def _send(self, data):
return len(data)

def _readline(self):
i = self._response.find(b"\r\n", self._position)
r = self._response[self._position : i + 2]
Expand Down
11 changes: 10 additions & 1 deletion tests/post_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,17 @@ def test_method():
s = adafruit_requests.Session(pool)
r = s.post("http://" + host + "/post")
sock.connect.assert_called_once_with((ip, 80))

sock.send.assert_has_calls(
[
mock.call(b"POST"),
mock.call(b" /"),
mock.call(b"post"),
mock.call(b" HTTP/1.1\r\n"),
]
)
sock.send.assert_has_calls(
[mock.call(b"POST /post HTTP/1.1\r\n"), mock.call(b"Host: httpbin.org\r\n")]
[mock.call(b"Host: "), mock.call(b"httpbin.org"),]
)


Expand Down
20 changes: 16 additions & 4 deletions tests/protocol_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,18 @@ def test_get_https_text():
r = s.get("https://" + host + path)

sock.connect.assert_called_once_with((host, 443))

sock.send.assert_has_calls(
[
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
mock.call(b"Host: wifitest.adafruit.com\r\n"),
mock.call(b"GET"),
mock.call(b" /"),
mock.call(b"testwifi/index.html"),
mock.call(b" HTTP/1.1\r\n"),
]
)
sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),]
)
assert r.text == str(text, "utf-8")

# Close isn't needed but can be called to release the socket early.
Expand All @@ -54,10 +60,16 @@ def test_get_http_text():
r = s.get("http://" + host + path)

sock.connect.assert_called_once_with((ip, 80))

sock.send.assert_has_calls(
[
mock.call(b"GET /testwifi/index.html HTTP/1.1\r\n"),
mock.call(b"Host: wifitest.adafruit.com\r\n"),
mock.call(b"GET"),
mock.call(b" /"),
mock.call(b"testwifi/index.html"),
mock.call(b" HTTP/1.1\r\n"),
]
)
sock.send.assert_has_calls(
[mock.call(b"Host: "), mock.call(b"wifitest.adafruit.com"),]
)
assert r.text == str(text, "utf-8")
Loading