Skip to content
This repository was archived by the owner on Sep 16, 2024. It is now read-only.

Added PyBytes libraries as frozen for devices with stricter memory co… #135

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 268 additions & 0 deletions esp32/frozen/OTA.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
import network
import socket
import ssl
import machine
import ujson
import uhashlib
import ubinascii
import gc
import pycom
import os
from binascii import hexlify

# Try to get version number
# try:
# from OTA_VERSION import VERSION
# except ImportError:
# VERSION = '1.0.0'


class OTA():
# The following two methods need to be implemented in a subclass for the
# specific transport mechanism e.g. WiFi

def connect(self):
raise NotImplementedError()

def get_data(self, req, dest_path=None, hash=False):
raise NotImplementedError()

# OTA methods

def get_current_version(self):
return os.uname().release

def get_update_manifest(self):
current_version = self.get_current_version()
sysname = os.uname().sysname
wmac = hexlify(machine.unique_id()).decode('ascii')
request_template = "manifest.json?current_ver={}&sysname={}&wmac={}"
req = request_template.format(current_version, sysname, wmac)
manifest_data = self.get_data(req).decode()
manifest = ujson.loads(manifest_data)
gc.collect()
return manifest

def reboot(self):
machine.reset()

def update(self):
try:
manifest = self.get_update_manifest()
except Exception as e:
print('Error reading the manifest, aborting: {}'.format(e))
return 0

if manifest is None:
print("Already on the latest version")
return 1

# Download new files and verify hashes
for f in manifest['new'] + manifest['update']:
# Upto 5 retries
for _ in range(5):
try:
self.get_file(f)
break
except Exception as e:
print(e)
msg = "Error downloading `{}` retrying..."
print(msg.format(f['URL']))
return 0
else:
raise Exception("Failed to download `{}`".format(f['URL']))

# Backup old files
# only once all files have been successfully downloaded
for f in manifest['update']:
self.backup_file(f)

# Rename new files to proper name
for f in manifest['new'] + manifest['update']:
new_path = "{}.new".format(f['dst_path'])
dest_path = "{}".format(f['dst_path'])

os.rename(new_path, dest_path)

# `Delete` files no longer required
# This actually makes a backup of the files incase we need to roll back
for f in manifest['delete']:
self.delete_file(f)

# Flash firmware
if "firmware" in manifest:
self.write_firmware(manifest['firmware'])

# Save version number
# try:
# self.backup_file({"dst_path": "/flash/OTA_VERSION.py"})
# except OSError:
# pass # There isnt a previous file to backup
# with open("/flash/OTA_VERSION.py", 'w') as fp:
# fp.write("VERSION = '{}'".format(manifest['version']))
# from OTA_VERSION import VERSION

return 2

def get_file(self, f):
new_path = "{}.new".format(f['dst_path'])

# If a .new file exists from a previously failed update delete it
try:
os.remove(new_path)
except OSError:
pass # The file didnt exist

# Download new file with a .new extension to not overwrite the existing
# file until the hash is verified.
hash = self.get_data(f['URL'].split("/", 3)[-1],
dest_path=new_path,
hash=True)

# Hash mismatch
if hash != f['hash']:
print(hash, f['hash'])
msg = "Downloaded file's hash does not match expected hash"
raise Exception(msg)

def backup_file(self, f):
bak_path = "{}.bak".format(f['dst_path'])
dest_path = "{}".format(f['dst_path'])

# Delete previous backup if it exists
try:
os.remove(bak_path)
except OSError:
pass # There isnt a previous backup

# Backup current file
os.rename(dest_path, bak_path)

def delete_file(self, f):
bak_path = "/{}.bak_del".format(f)
dest_path = "/{}".format(f)

# Delete previous delete backup if it exists
try:
os.remove(bak_path)
except OSError:
pass # There isnt a previous delete backup

# Backup current file
os.rename(dest_path, bak_path)

def write_firmware(self, f):
hash = self.get_data(f['URL'].split("/", 3)[-1],
hash=True,
firmware=True)
# TODO: Add verification when released in future firmware


class WiFiOTA(OTA):
def __init__(self, ssid, password, ip, port):
self.SSID = ssid
self.password = password
self.ip = ip
self.port = port

def connect(self):
self.wlan = network.WLAN(mode=network.WLAN.STA)
if not self.wlan.isconnected() or self.wlan.ssid() != self.SSID:
for net in self.wlan.scan():
if net.ssid == self.SSID:
self.wlan.connect(self.SSID, auth=(network.WLAN.WPA2,
self.password))
while not self.wlan.isconnected():
machine.idle() # save power while waiting
break
else:
raise Exception("Cannot find network '{}'".format(self.SSID))
else:
# Already connected to the correct WiFi
pass

def _http_get(self, path, host):
req_fmt = 'GET /{} HTTP/1.0\r\nHost: {}\r\n\r\n'
req = bytes(req_fmt.format(path, host), 'utf8')
return req

def get_data(self, req, dest_path=None, hash=False, firmware=False):
h = None

useSSL = int(self.port) == 443

# Connect to server
print("Requesting: {} to {}:{} with SSL? {}".format(req, self.ip, self.port, useSSL))
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
s.connect(socket.getaddrinfo(self.ip, self.port)[0][-1])
if (int(self.port) == 443):
print("Wrapping socket")
s = ssl.wrap_socket(s)

print("Sending request")
# Request File
s.sendall(self._http_get(req, "{}:{}".format(self.ip, self.port)))

try:
content = bytearray()
fp = None
if dest_path is not None:
print('dest_path {}'.format(dest_path))
if firmware:
raise Exception("Cannot write firmware to a file")
fp = open(dest_path, 'wb')

if firmware:
print('start')
pycom.ota_start()

h = uhashlib.sha1()

# Get data from server
result = s.recv(50)

start_writing = False
while (len(result) > 0):
# Ignore the HTTP headers
if not start_writing:
if "\r\n\r\n" in result:
start_writing = True
result = result.decode().split("\r\n\r\n")[1].encode()

if start_writing:
if firmware:
pycom.ota_write(result)
elif fp is None:
content.extend(result)
else:
fp.write(result)

if hash:
h.update(result)

result = s.recv(50)

s.close()

if fp is not None:
fp.close()
if firmware:
pycom.ota_finish()

except Exception as e:
gc.mem_free()
# Since only one hash operation is allowed at Once
# ensure we close it if there is an error
if h is not None:
h.digest()
raise e

hash_val = ubinascii.hexlify(h.digest()).decode()

if dest_path is None:
if hash:
return (bytes(content), hash_val)
else:
return bytes(content)
elif hash:
return hash_val
89 changes: 89 additions & 0 deletions esp32/frozen/mqtt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import time
import mqtt_core


class MQTTClient:

DELAY = 5
# This errors are thrown by connect function, I wouldn't be able to find
# anywhere a complete list of these error codes
ERRORS = {
'-1': 'MQTTClient: Can\'t connect to MQTT server',
'-4': 'MQTTClient: Bad credentials'
}

def __init__(self, client_id, server, port=0, user=None, password=None, keepalive=0, ssl=False,
ssl_params={}, reconnect=True):
self.__reconnect = reconnect
self.__mqtt = mqtt_core.MQTTClient(client_id, server, port, user, password, keepalive,
ssl, ssl_params)

def getError(self, x):
"""Return a human readable error instead of its code number"""
message = str(x)
return self.ERRORS.get(str(x), 'Unknown error ' + message)

def connect(self, clean_session=True):
i = 0
while 1:
try:
return self.__mqtt.connect(clean_session)
except OSError as e:
print(self.getError(e))
if (not self.__reconnect):
raise Exception('Reconnection Disabled.')
i += 1
self.delay(i)

def set_callback(self, f):
self.__mqtt.set_callback(f)

def subscribe(self, topic, qos=0):
self.__mqtt.subscribe(topic, qos)

def check_msg(self):
while 1:
try:
return self.__mqtt.check_msg()
except OSError as e:
print("Error check_msg", e)

if (not self.__reconnect):
raise Exception('Reconnection Disabled.')
self.reconnect()

def delay(self, i):
time.sleep(self.DELAY)

def reconnect(self):
print("Reconnecting...")
i = 0
while 1:
try:
return self.__mqtt.connect(True)
except OSError as e:
print('reconnect error', e)
i += 1
self.delay(i)

def publish(self, topic, msg, retain=False, qos=0):
while 1:
try:
return self.__mqtt.publish(topic, msg, retain, qos)
except OSError as e:
print("Error publish", e)

if (not self.__reconnect):
raise Exception('Reconnection Disabled.')
self.reconnect()

def wait_msg(self):
while 1:
try:
return self.__mqtt.wait_msg()
except OSError as e:
print("Error wait_msg {}".format(e))

if (not self.__reconnect):
raise Exception('Reconnection Disabled.')
self.reconnect()
Loading