Skip to content

umqtt.simple: refactor packet de/encoding and fix remaining length encoding (fixes #284) #303

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions umqtt.simple/example_pub.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from umqtt.simple import MQTTClient
#from umqtt_debug import DebugMQTTClient as MQTTClient

# Test reception e.g. with:
# mosquitto_sub -t foo_topic

def main(server="localhost"):
def main(server="localhost", topic="foo_topic", msg="hello", qos=0):
c = MQTTClient("umqtt_client", server)
c.connect()
c.publish(b"foo_topic", b"hello")
c.publish(topic.encode('utf-8'), msg.encode('utf-8'), qos=int(qos))
c.disconnect()

if __name__ == "__main__":
main()
import sys
main(*sys.argv[1:])
12 changes: 7 additions & 5 deletions umqtt.simple/example_sub.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import time
from umqtt.simple import MQTTClient
#from umqtt_debug import DebugMQTTClient as MQTTClient

# Publish test messages e.g. with:
# mosquitto_pub -t foo_topic -m hello

# Received messages from subscriptions will be delivered to this callback
def sub_cb(topic, msg):
print((topic, msg))

def main(server="localhost"):
def main(server="localhost", topic="foo_topic", qos=0, blocking=True):
c = MQTTClient("umqtt_client", server)
c.set_callback(sub_cb)
c.connect()
c.subscribe(b"foo_topic")
c.subscribe(topic.encode('utf-8'), int(qos))
while True:
if True:
if blocking:
# Blocking wait for message
c.wait_msg()
else:
Expand All @@ -27,4 +28,5 @@ def main(server="localhost"):
c.disconnect()

if __name__ == "__main__":
main()
import sys
main(*sys.argv[1:])
91 changes: 42 additions & 49 deletions umqtt.simple/umqtt/simple.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import usocket as socket
import ustruct as struct
from ubinascii import hexlify

class MQTTException(Exception):
pass
Expand Down Expand Up @@ -28,7 +26,8 @@ def __init__(self, client_id, server, port=0, user=None, password=None, keepaliv
self.lw_retain = False

def _send_str(self, s):
self.sock.write(struct.pack("!H", len(s)))
assert len(s) < 65536
self.sock.write(len(s).to_bytes(2, 'big'))
self.sock.write(s)

def _recv_len(self):
Expand All @@ -41,6 +40,15 @@ def _recv_len(self):
return n
sh += 7

def _varlen_encode(self, value, buf, offset=0):
assert value < 268435456 # 2**28, i.e. max. four 7-bit bytes
while value > 0x7f:
buf[offset] = (value & 0x7f) | 0x80
value >>= 7
offset += 1
buf[offset] = value
return offset + 1

def set_callback(self, f):
self.cb = f

Expand All @@ -59,40 +67,37 @@ def connect(self, clean_session=True):
if self.ssl:
import ussl
self.sock = ussl.wrap_socket(self.sock, **self.ssl_params)
premsg = bytearray(b"\x10\0\0\0\0\0")
msg = bytearray(b"\x04MQTT\x04\x02\0\0")
premsg = bytearray(b"\x10\0\0\0\0")
msg = bytearray(b"\0\x04MQTT\x04\0\0\0")

sz = 10 + 2 + len(self.client_id)
msg[6] = clean_session << 1
msg[7] = clean_session << 1
if self.user is not None:
sz += 2 + len(self.user) + 2 + len(self.pswd)
msg[6] |= 0xC0
sz += 2 + len(self.user)
msg[7] |= 1 << 7
if self.pswd is not None:
sz += 2 + len(self.pswd)
msg[7] |= 1 << 6
if self.keepalive:
assert self.keepalive < 65536
msg[7] |= self.keepalive >> 8
msg[8] |= self.keepalive & 0x00FF
msg[8] |= self.keepalive >> 8
msg[9] |= self.keepalive & 0x00FF
if self.lw_topic:
sz += 2 + len(self.lw_topic) + 2 + len(self.lw_msg)
msg[6] |= 0x4 | (self.lw_qos & 0x1) << 3 | (self.lw_qos & 0x2) << 3
msg[6] |= self.lw_retain << 5

i = 1
while sz > 0x7f:
premsg[i] = (sz & 0x7f) | 0x80
sz >>= 7
i += 1
premsg[i] = sz
msg[7] |= 0x4 | (self.lw_qos & 0x1) << 3 | (self.lw_qos & 0x2) << 3
msg[7] |= self.lw_retain << 5

self.sock.write(premsg, i + 2)
plen = self._varlen_encode(sz, premsg, 1)
self.sock.write(premsg, plen)
self.sock.write(msg)
#print(hex(len(msg)), hexlify(msg, ":"))
self._send_str(self.client_id)
if self.lw_topic:
self._send_str(self.lw_topic)
self._send_str(self.lw_msg)
if self.user is not None:
self._send_str(self.user)
self._send_str(self.pswd)
if self.pswd is not None:
self._send_str(self.pswd)
resp = self.sock.read(4)
assert resp[0] == 0x20 and resp[1] == 0x02
if resp[3] != 0:
Expand All @@ -107,55 +112,46 @@ def ping(self):
self.sock.write(b"\xc0\0")

def publish(self, topic, msg, retain=False, qos=0):
pkt = bytearray(b"\x30\0\0\0")
pkt = bytearray(b"\x30\0\0\0\0")
pkt[0] |= qos << 1 | retain
sz = 2 + len(topic) + len(msg)
if qos > 0:
sz += 2
assert sz < 2097152
i = 1
while sz > 0x7f:
pkt[i] = (sz & 0x7f) | 0x80
sz >>= 7
i += 1
pkt[i] = sz
#print(hex(len(pkt)), hexlify(pkt, ":"))
self.sock.write(pkt, i + 1)
plen = self._varlen_encode(sz, pkt, 1)
self.sock.write(pkt, plen)
self._send_str(topic)
if qos > 0:
self.pid += 1
pid = self.pid
struct.pack_into("!H", pkt, 0, pid)
self.sock.write(pkt, 2)
self.sock.write(pid.to_bytes(2, 'big'))
self.sock.write(msg)
if qos == 1:
while 1:
op = self.wait_msg()
if op == 0x40:
sz = self.sock.read(1)
assert sz == b"\x02"
rcv_pid = self.sock.read(2)
rcv_pid = rcv_pid[0] << 8 | rcv_pid[1]
rcv_pid = int.from_bytes(self.sock.read(2), 'big')
if pid == rcv_pid:
return
elif qos == 2:
assert 0

def subscribe(self, topic, qos=0):
assert self.cb is not None, "Subscribe callback is not set"
pkt = bytearray(b"\x82\0\0\0")
pkt = bytearray(b"\x82\0\0\0\0\0\0")
self.pid += 1
struct.pack_into("!BH", pkt, 1, 2 + 2 + len(topic) + 1, self.pid)
#print(hex(len(pkt)), hexlify(pkt, ":"))
self.sock.write(pkt)
sz = 2 + 2 + len(topic) + 1
plen = self._varlen_encode(sz, pkt, 1)
pkt[plen:plen + 2] = self.pid.to_bytes(2, 'big')
self.sock.write(pkt, plen + 2)
self._send_str(topic)
self.sock.write(qos.to_bytes(1, "little"))
while 1:
op = self.wait_msg()
if op == 0x90:
resp = self.sock.read(4)
#print(resp)
assert resp[1] == pkt[2] and resp[2] == pkt[3]
assert resp[1] == pkt[plen] and resp[2] == pkt[plen + 1]
if resp[3] == 0x80:
raise MQTTException(resp[3])
return
Expand All @@ -179,20 +175,17 @@ def wait_msg(self):
if op & 0xf0 != 0x30:
return op
sz = self._recv_len()
topic_len = self.sock.read(2)
topic_len = (topic_len[0] << 8) | topic_len[1]
topic_len = int.from_bytes(self.sock.read(2), 'big')
topic = self.sock.read(topic_len)
sz -= topic_len + 2
if op & 6:
pid = self.sock.read(2)
pid = pid[0] << 8 | pid[1]
pid = int.from_bytes(self.sock.read(2), 'big')
sz -= 2
msg = self.sock.read(sz)
self.cb(topic, msg)
if op & 6 == 2:
pkt = bytearray(b"\x40\x02\0\0")
struct.pack_into("!H", pkt, 2, pid)
self.sock.write(pkt)
self.sock.write(b"\x40\x02")
self.sock.write(pid.to_bytes(2, 'big'))
elif op & 6 == 4:
assert 0

Expand Down
38 changes: 38 additions & 0 deletions umqtt.simple/umqtt_debug.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from umqtt.simple import MQTTClient
from ubinascii import hexlify


class debug_socket:
def __init__(self, sock):
self.sock = sock

def read(self, len):
data = self.sock.read(len)
print("RECV:", hexlify(data, ':').decode())
return data

def write(self, data, len=None):
print("SEND:", hexlify(data, ':').decode())
if len is None:
return self.sock.write(data)
else:
return self.sock.write(data, len)

def __getattr__(self, name):
return getattr(self.sock, name)


class DebugMQTTClient(MQTTClient):
def __init__(self, *args, **kw):
super().__init__(*args, **kw)
self._sock = None

@property
def sock(self):
return self._sock

@sock.setter
def sock(self, val):
if val:
val = debug_socket(val)
self._sock = val