diff --git a/umqtt.simple/example_pub.py b/umqtt.simple/example_pub.py index c15d07642..4e6d71857 100644 --- a/umqtt.simple/example_pub.py +++ b/umqtt.simple/example_pub.py @@ -1,13 +1,15 @@ from umqtt.simple import MQTTClient +#from umqtt_debug import DebugMQTTClient as MQTTClient # Test reception e.g. with: # mosquitto_sub -t foo_topic -def main(server="localhost"): +def main(server="localhost", topic="foo_topic", msg="hello", qos=0): c = MQTTClient("umqtt_client", server) c.connect() - c.publish(b"foo_topic", b"hello") + c.publish(topic.encode('utf-8'), msg.encode('utf-8'), qos=int(qos)) c.disconnect() if __name__ == "__main__": - main() + import sys + main(*sys.argv[1:]) diff --git a/umqtt.simple/example_sub.py b/umqtt.simple/example_sub.py index f08818488..a31d46305 100644 --- a/umqtt.simple/example_sub.py +++ b/umqtt.simple/example_sub.py @@ -1,5 +1,6 @@ import time from umqtt.simple import MQTTClient +#from umqtt_debug import DebugMQTTClient as MQTTClient # Publish test messages e.g. with: # mosquitto_pub -t foo_topic -m hello @@ -7,14 +8,14 @@ # Received messages from subscriptions will be delivered to this callback def sub_cb(topic, msg): print((topic, msg)) - -def main(server="localhost"): + +def main(server="localhost", topic="foo_topic", qos=0, blocking=True): c = MQTTClient("umqtt_client", server) c.set_callback(sub_cb) c.connect() - c.subscribe(b"foo_topic") + c.subscribe(topic.encode('utf-8'), int(qos)) while True: - if True: + if blocking: # Blocking wait for message c.wait_msg() else: @@ -27,4 +28,5 @@ def main(server="localhost"): c.disconnect() if __name__ == "__main__": - main() + import sys + main(*sys.argv[1:]) diff --git a/umqtt.simple/umqtt/simple.py b/umqtt.simple/umqtt/simple.py index 8216fa5e1..fdcc93ffe 100644 --- a/umqtt.simple/umqtt/simple.py +++ b/umqtt.simple/umqtt/simple.py @@ -1,6 +1,4 @@ import usocket as socket -import ustruct as struct -from ubinascii import hexlify class MQTTException(Exception): pass @@ -28,7 +26,8 @@ def __init__(self, client_id, server, port=0, user=None, password=None, keepaliv self.lw_retain = False def _send_str(self, s): - self.sock.write(struct.pack("!H", len(s))) + assert len(s) < 65536 + self.sock.write(len(s).to_bytes(2, 'big')) self.sock.write(s) def _recv_len(self): @@ -41,6 +40,15 @@ def _recv_len(self): return n sh += 7 + def _varlen_encode(self, value, buf, offset=0): + assert value < 268435456 # 2**28, i.e. max. four 7-bit bytes + while value > 0x7f: + buf[offset] = (value & 0x7f) | 0x80 + value >>= 7 + offset += 1 + buf[offset] = value + return offset + 1 + def set_callback(self, f): self.cb = f @@ -59,40 +67,37 @@ def connect(self, clean_session=True): if self.ssl: import ussl self.sock = ussl.wrap_socket(self.sock, **self.ssl_params) - premsg = bytearray(b"\x10\0\0\0\0\0") - msg = bytearray(b"\x04MQTT\x04\x02\0\0") + premsg = bytearray(b"\x10\0\0\0\0") + msg = bytearray(b"\0\x04MQTT\x04\0\0\0") sz = 10 + 2 + len(self.client_id) - msg[6] = clean_session << 1 + msg[7] = clean_session << 1 if self.user is not None: - sz += 2 + len(self.user) + 2 + len(self.pswd) - msg[6] |= 0xC0 + sz += 2 + len(self.user) + msg[7] |= 1 << 7 + if self.pswd is not None: + sz += 2 + len(self.pswd) + msg[7] |= 1 << 6 if self.keepalive: assert self.keepalive < 65536 - msg[7] |= self.keepalive >> 8 - msg[8] |= self.keepalive & 0x00FF + msg[8] |= self.keepalive >> 8 + msg[9] |= self.keepalive & 0x00FF if self.lw_topic: sz += 2 + len(self.lw_topic) + 2 + len(self.lw_msg) - msg[6] |= 0x4 | (self.lw_qos & 0x1) << 3 | (self.lw_qos & 0x2) << 3 - msg[6] |= self.lw_retain << 5 - - i = 1 - while sz > 0x7f: - premsg[i] = (sz & 0x7f) | 0x80 - sz >>= 7 - i += 1 - premsg[i] = sz + msg[7] |= 0x4 | (self.lw_qos & 0x1) << 3 | (self.lw_qos & 0x2) << 3 + msg[7] |= self.lw_retain << 5 - self.sock.write(premsg, i + 2) + plen = self._varlen_encode(sz, premsg, 1) + self.sock.write(premsg, plen) self.sock.write(msg) - #print(hex(len(msg)), hexlify(msg, ":")) self._send_str(self.client_id) if self.lw_topic: self._send_str(self.lw_topic) self._send_str(self.lw_msg) if self.user is not None: self._send_str(self.user) - self._send_str(self.pswd) + if self.pswd is not None: + self._send_str(self.pswd) resp = self.sock.read(4) assert resp[0] == 0x20 and resp[1] == 0x02 if resp[3] != 0: @@ -107,26 +112,18 @@ def ping(self): self.sock.write(b"\xc0\0") def publish(self, topic, msg, retain=False, qos=0): - pkt = bytearray(b"\x30\0\0\0") + pkt = bytearray(b"\x30\0\0\0\0") pkt[0] |= qos << 1 | retain sz = 2 + len(topic) + len(msg) if qos > 0: sz += 2 - assert sz < 2097152 - i = 1 - while sz > 0x7f: - pkt[i] = (sz & 0x7f) | 0x80 - sz >>= 7 - i += 1 - pkt[i] = sz - #print(hex(len(pkt)), hexlify(pkt, ":")) - self.sock.write(pkt, i + 1) + plen = self._varlen_encode(sz, pkt, 1) + self.sock.write(pkt, plen) self._send_str(topic) if qos > 0: self.pid += 1 pid = self.pid - struct.pack_into("!H", pkt, 0, pid) - self.sock.write(pkt, 2) + self.sock.write(pid.to_bytes(2, 'big')) self.sock.write(msg) if qos == 1: while 1: @@ -134,8 +131,7 @@ def publish(self, topic, msg, retain=False, qos=0): if op == 0x40: sz = self.sock.read(1) assert sz == b"\x02" - rcv_pid = self.sock.read(2) - rcv_pid = rcv_pid[0] << 8 | rcv_pid[1] + rcv_pid = int.from_bytes(self.sock.read(2), 'big') if pid == rcv_pid: return elif qos == 2: @@ -143,19 +139,19 @@ def publish(self, topic, msg, retain=False, qos=0): def subscribe(self, topic, qos=0): assert self.cb is not None, "Subscribe callback is not set" - pkt = bytearray(b"\x82\0\0\0") + pkt = bytearray(b"\x82\0\0\0\0\0\0") self.pid += 1 - struct.pack_into("!BH", pkt, 1, 2 + 2 + len(topic) + 1, self.pid) - #print(hex(len(pkt)), hexlify(pkt, ":")) - self.sock.write(pkt) + sz = 2 + 2 + len(topic) + 1 + plen = self._varlen_encode(sz, pkt, 1) + pkt[plen:plen + 2] = self.pid.to_bytes(2, 'big') + self.sock.write(pkt, plen + 2) self._send_str(topic) self.sock.write(qos.to_bytes(1, "little")) while 1: op = self.wait_msg() if op == 0x90: resp = self.sock.read(4) - #print(resp) - assert resp[1] == pkt[2] and resp[2] == pkt[3] + assert resp[1] == pkt[plen] and resp[2] == pkt[plen + 1] if resp[3] == 0x80: raise MQTTException(resp[3]) return @@ -179,20 +175,17 @@ def wait_msg(self): if op & 0xf0 != 0x30: return op sz = self._recv_len() - topic_len = self.sock.read(2) - topic_len = (topic_len[0] << 8) | topic_len[1] + topic_len = int.from_bytes(self.sock.read(2), 'big') topic = self.sock.read(topic_len) sz -= topic_len + 2 if op & 6: - pid = self.sock.read(2) - pid = pid[0] << 8 | pid[1] + pid = int.from_bytes(self.sock.read(2), 'big') sz -= 2 msg = self.sock.read(sz) self.cb(topic, msg) if op & 6 == 2: - pkt = bytearray(b"\x40\x02\0\0") - struct.pack_into("!H", pkt, 2, pid) - self.sock.write(pkt) + self.sock.write(b"\x40\x02") + self.sock.write(pid.to_bytes(2, 'big')) elif op & 6 == 4: assert 0 diff --git a/umqtt.simple/umqtt_debug.py b/umqtt.simple/umqtt_debug.py new file mode 100644 index 000000000..e70aea743 --- /dev/null +++ b/umqtt.simple/umqtt_debug.py @@ -0,0 +1,38 @@ +from umqtt.simple import MQTTClient +from ubinascii import hexlify + + +class debug_socket: + def __init__(self, sock): + self.sock = sock + + def read(self, len): + data = self.sock.read(len) + print("RECV:", hexlify(data, ':').decode()) + return data + + def write(self, data, len=None): + print("SEND:", hexlify(data, ':').decode()) + if len is None: + return self.sock.write(data) + else: + return self.sock.write(data, len) + + def __getattr__(self, name): + return getattr(self.sock, name) + + +class DebugMQTTClient(MQTTClient): + def __init__(self, *args, **kw): + super().__init__(*args, **kw) + self._sock = None + + @property + def sock(self): + return self._sock + + @sock.setter + def sock(self, val): + if val: + val = debug_socket(val) + self._sock = val