|
| 1 | +# This file is part of the msgpack-rpc module. |
| 2 | +# Copyright (c) 2023 Arduino SA |
| 3 | +# This Source Code Form is subject to the terms of the Mozilla Public |
| 4 | +# License, v. 2.0. If a copy of the MPL was not distributed with this |
| 5 | +# file, You can obtain one at https://mozilla.org/MPL/2.0/. |
| 6 | +# |
| 7 | +# MessagePack RPC protocol implementation for MicroPython. |
| 8 | +# https://github.com/msgpack-rpc/msgpack-rpc/blob/master/spec.md |
| 9 | + |
| 10 | +import logging |
| 11 | +import openamp |
| 12 | +import msgpack |
| 13 | +from micropython import const |
| 14 | +from io import BytesIO |
| 15 | +from time import sleep_ms, ticks_ms, ticks_diff |
| 16 | + |
| 17 | +_MSG_TYPE_REQUEST = const(0) |
| 18 | +_MSG_TYPE_RESPONSE = const(1) |
| 19 | +_MSG_TYPE_NOTIFY = const(2) |
| 20 | + |
| 21 | + |
| 22 | +def log_level_enabled(level): |
| 23 | + return logging.getLogger().isEnabledFor(level) |
| 24 | + |
| 25 | + |
| 26 | +class Future: |
| 27 | + def __init__(self, msgid, msgbuf, fname, fargs): |
| 28 | + self.msgid = msgid |
| 29 | + self.msgbuf = msgbuf |
| 30 | + self.fname = fname |
| 31 | + self.fargs = fargs |
| 32 | + |
| 33 | + def join(self, timeout=0): |
| 34 | + if log_level_enabled(logging.DEBUG): |
| 35 | + logging.debug(f"join {self.fname}()") |
| 36 | + |
| 37 | + if timeout > 0: |
| 38 | + t = ticks_ms() |
| 39 | + |
| 40 | + while self.msgid not in self.msgbuf: |
| 41 | + if timeout > 0 and ticks_diff(ticks_ms(), t) > timeout: |
| 42 | + raise OSError(f"Timeout joining function {self.fname}") |
| 43 | + sleep_ms(100) |
| 44 | + |
| 45 | + obj = self.msgbuf.pop(self.msgid) |
| 46 | + if obj[2] is not None: |
| 47 | + raise (OSError(obj[2])) |
| 48 | + |
| 49 | + if log_level_enabled(logging.DEBUG): |
| 50 | + logging.debug(f"call {self.fname}({self.fargs}) => {obj}") |
| 51 | + return obj[3] |
| 52 | + |
| 53 | + |
| 54 | +class MsgPackIO: |
| 55 | + def __init__(self): |
| 56 | + self.stream = BytesIO() |
| 57 | + |
| 58 | + def feed(self, data): |
| 59 | + offset = self.stream.tell() |
| 60 | + self.stream.write(data) |
| 61 | + self.stream.seek(offset) |
| 62 | + |
| 63 | + def readable(self): |
| 64 | + if self.stream.read(1): |
| 65 | + offset = self.stream.tell() |
| 66 | + self.stream.seek(offset - 1) |
| 67 | + return True |
| 68 | + return False |
| 69 | + |
| 70 | + def truncate(self): |
| 71 | + if self.readable(): |
| 72 | + offset = self.stream.tell() |
| 73 | + self.stream = BytesIO(self.stream.getvalue()[offset:]) |
| 74 | + |
| 75 | + def __iter__(self): |
| 76 | + return self |
| 77 | + |
| 78 | + def __next__(self): |
| 79 | + offset = self.stream.tell() |
| 80 | + try: |
| 81 | + obj = msgpack.unpack(self.stream) |
| 82 | + self.truncate() |
| 83 | + return obj |
| 84 | + except Exception: |
| 85 | + self.stream.seek(offset) |
| 86 | + raise StopIteration |
| 87 | + |
| 88 | + |
| 89 | +class MsgPackRPC: |
| 90 | + def __init__(self, streaming=False): |
| 91 | + """ |
| 92 | + Create a MsgPack RPC object. |
| 93 | + streaming: If True, messages can span multiple buffers, otherwise a buffer contains |
| 94 | + exactly one full message. Note streaming mode is slower, so it should be disabled |
| 95 | + if it's not needed. |
| 96 | + """ |
| 97 | + self.epts = {} |
| 98 | + self.msgid = 0 |
| 99 | + self.msgbuf = {} |
| 100 | + self.msgio = MsgPackIO() if streaming else None |
| 101 | + self.servers = [] |
| 102 | + |
| 103 | + def _bind_callback(self, src_addr, name): |
| 104 | + if log_level_enabled(logging.INFO): |
| 105 | + logging.info(f'New service announcement src: {src_addr} name: "{name}"') |
| 106 | + self.epts[name] = openamp.RPMsg(name, dst_addr=src_addr, callback=self._recv_callback) |
| 107 | + self.epts[name].send(b"\x00") |
| 108 | + |
| 109 | + def _recv_callback(self, src_addr, data): |
| 110 | + if log_level_enabled(logging.DEBUG): |
| 111 | + logging.debug(f"Received message on endpoint: {src_addr} data: {bytes(data)}") |
| 112 | + |
| 113 | + if self.msgio is None: |
| 114 | + obj = msgpack.unpackb(data) |
| 115 | + self._process_unpacked_obj(obj) |
| 116 | + else: |
| 117 | + self.msgio.feed(data) |
| 118 | + for obj in self.msgio: |
| 119 | + self._process_unpacked_obj(obj) |
| 120 | + |
| 121 | + def _process_unpacked_obj(self, obj): |
| 122 | + if obj[0] == _MSG_TYPE_RESPONSE: |
| 123 | + self.msgbuf[obj[1]] = obj |
| 124 | + elif obj[0] == _MSG_TYPE_REQUEST: |
| 125 | + self._dispatch(obj[1], obj[2], obj[-1]) |
| 126 | + if log_level_enabled(logging.DEBUG): |
| 127 | + logging.debug(f"Unpacked {type(obj)} val: {obj}") |
| 128 | + |
| 129 | + def _send_msg(self, msgid, msgtype, fname, fargs, **kwargs): |
| 130 | + timeout = kwargs.pop("timeout", 1000) |
| 131 | + endpoint = kwargs.pop("endpoint", "rpc") |
| 132 | + self.epts[endpoint].send(msgpack.packb([msgtype, msgid, fname, fargs]), timeout=timeout) |
| 133 | + if msgtype == _MSG_TYPE_REQUEST: |
| 134 | + self.msgid += 1 |
| 135 | + return Future(msgid, self.msgbuf, fname, fargs) |
| 136 | + |
| 137 | + def _dispatch(self, msgid, fname, fargs): |
| 138 | + func = None |
| 139 | + retobj = None |
| 140 | + error = None |
| 141 | + for obj in self.servers: |
| 142 | + if callable(obj) and obj.__name__ == fname: |
| 143 | + func = obj |
| 144 | + elif hasattr(obj, fname): |
| 145 | + func = getattr(obj, fname) |
| 146 | + if func is not None: |
| 147 | + break |
| 148 | + |
| 149 | + if func is not None: |
| 150 | + retobj = func(*fargs) |
| 151 | + else: |
| 152 | + error = "Unbound function called %s" % (fname) |
| 153 | + |
| 154 | + self._send_msg(msgid, _MSG_TYPE_RESPONSE, error, retobj) |
| 155 | + |
| 156 | + def bind(self, obj): |
| 157 | + """ |
| 158 | + Register an object or a function to be called by the remote processor. |
| 159 | + obj: An object whose methods can be called by remote processors, or a function. |
| 160 | + """ |
| 161 | + self.servers.append(obj) |
| 162 | + |
| 163 | + def start(self, firmware=None, num_channels=2, timeout=3000): |
| 164 | + """ |
| 165 | + Initializes OpenAMP, loads the remote processor's firmware and starts. |
| 166 | + firmware: A path to an elf file stored in the filesystem, or an address to an entry point in flash. |
| 167 | + num_channels: The number of channels to wait for the remote processor to |
| 168 | + create before starting to communicate with it. |
| 169 | + timeout: How long to wait for the remote processor to start, 0 means forever. |
| 170 | + """ |
| 171 | + # Initialize OpenAMP. |
| 172 | + openamp.init(ns_callback=self._bind_callback) |
| 173 | + |
| 174 | + # Keep a reference to the remote processor object, to stop the GC from collecting |
| 175 | + # it, which would call the finaliser and shut down the remote processor while it's |
| 176 | + # still being used. |
| 177 | + self.rproc = openamp.RProc(firmware) |
| 178 | + self.rproc.start() |
| 179 | + |
| 180 | + # Wait for remote processor to announce the end points. |
| 181 | + t = ticks_ms() |
| 182 | + while len(self.epts) != num_channels: |
| 183 | + if timeout > 0 and ticks_diff(ticks_ms(), t) > timeout: |
| 184 | + raise OSError("timeout waiting for the remote processor to start") |
| 185 | + sleep_ms(10) |
| 186 | + |
| 187 | + # Introduce a brief delay to allow the M4 sufficient time |
| 188 | + # to bind remote functions before invoking them. |
| 189 | + sleep_ms(100) |
| 190 | + |
| 191 | + def call(self, fname, *args, **kwargs): |
| 192 | + """ |
| 193 | + Synchronous call. The client is blocked until the RPC is finished. |
| 194 | + """ |
| 195 | + return self.call_async(fname, *args, *kwargs).join() |
| 196 | + |
| 197 | + def call_async(self, fname, *args, **kwargs): |
| 198 | + """ |
| 199 | + Asynchronous call. The client returns a Future object immediately. |
| 200 | + """ |
| 201 | + return self._send_msg(self.msgid, _MSG_TYPE_REQUEST, fname, list(args), *kwargs) |
0 commit comments