Skip to content

Commit d2982d3

Browse files
author
Me No Dev
committed
Make ArduinoOTA AUTH async
still up to the user to call ArduinoOTA.handle() to start the upload
1 parent 20f372a commit d2982d3

File tree

2 files changed

+147
-96
lines changed

2 files changed

+147
-96
lines changed

libraries/ArduinoOTA/ArduinoOTA.cpp

Lines changed: 139 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,23 @@
1-
#include <ESP8266WiFi.h>
2-
#include <ESP8266mDNS.h>
1+
#define LWIP_OPEN_SRC
2+
#include <functional>
33
#include <WiFiUdp.h>
44
#include "ArduinoOTA.h"
55
#include "MD5Builder.h"
66

7+
extern "C" {
8+
#include "osapi.h"
9+
#include "ets_sys.h"
10+
#include "user_interface.h"
11+
}
12+
13+
#include "lwip/opt.h"
14+
#include "lwip/udp.h"
15+
#include "lwip/inet.h"
16+
#include "lwip/igmp.h"
17+
#include "lwip/mem.h"
18+
#include "include/UdpContext.h"
19+
#include <ESP8266mDNS.h>
20+
721
//#define OTA_DEBUG 1
822

923
ArduinoOTAClass::ArduinoOTAClass()
@@ -16,6 +30,7 @@ ArduinoOTAClass::ArduinoOTAClass()
1630
, _end_callback(NULL)
1731
, _progress_callback(NULL)
1832
, _error_callback(NULL)
33+
, _udp_ota(0)
1934
{
2035
}
2136

@@ -59,7 +74,6 @@ void ArduinoOTAClass::setPassword(const char * password) {
5974
void ArduinoOTAClass::begin() {
6075
if (_initialized)
6176
return;
62-
_initialized = true;
6377

6478
if (!_hostname.length()) {
6579
char tmp[15];
@@ -70,20 +84,136 @@ void ArduinoOTAClass::begin() {
7084
_port = 8266;
7185
}
7286

73-
_udp_ota.begin(_port);
87+
_udp_ota = new UdpContext;
88+
_udp_ota->ref();
89+
90+
if(!_udp_ota->listen(*IP_ADDR_ANY, _port))
91+
return;
92+
_udp_ota->onRx(std::bind(&ArduinoOTAClass::_onRx, this));
7493
MDNS.begin(_hostname.c_str());
7594

7695
if (_password.length()) {
7796
MDNS.enableArduino(_port, true);
7897
} else {
7998
MDNS.enableArduino(_port);
8099
}
100+
_initialized = true;
81101
_state = OTA_IDLE;
82102
#if OTA_DEBUG
83103
Serial.printf("OTA server at: %s.local:%u\n", _hostname.c_str(), _port);
84104
#endif
85105
}
86106

107+
int ArduinoOTAClass::parseInt(){
108+
char data[16];
109+
uint8_t index = 0;
110+
char value;
111+
while(_udp_ota->peek() == ' ') _udp_ota->read();
112+
while(true){
113+
value = _udp_ota->peek();
114+
if(value < '0' || value > '9'){
115+
data[index++] = '\0';
116+
return atoi(data);
117+
}
118+
data[index++] = _udp_ota->read();
119+
}
120+
return 0;
121+
}
122+
123+
String ArduinoOTAClass::readStringUntil(char end){
124+
String res = "";
125+
char value;
126+
while(true){
127+
value = _udp_ota->read();
128+
if(value == '\0' || value == end){
129+
return res;
130+
}
131+
res += value;
132+
}
133+
return res;
134+
}
135+
136+
void ArduinoOTAClass::_onRx(){
137+
if(!_udp_ota->next()) return;
138+
ip_addr_t ota_ip;
139+
140+
if (_state == OTA_IDLE) {
141+
int cmd = parseInt();
142+
if (cmd != U_FLASH && cmd != U_SPIFFS)
143+
return;
144+
_ota_ip = _udp_ota->getRemoteAddress();
145+
_cmd = cmd;
146+
_ota_port = parseInt();
147+
_size = parseInt();
148+
_udp_ota->read();
149+
_md5 = readStringUntil('\n');
150+
_md5.trim();
151+
if(_md5.length() != 32)
152+
return;
153+
154+
ota_ip.addr = (uint32_t)_ota_ip;
155+
156+
if (_password.length()){
157+
MD5Builder nonce_md5;
158+
nonce_md5.begin();
159+
nonce_md5.add(String(micros()));
160+
nonce_md5.calculate();
161+
_nonce = nonce_md5.toString();
162+
163+
char auth_req[38];
164+
sprintf(auth_req, "AUTH %s", _nonce.c_str());
165+
_udp_ota->append((const char *)auth_req, strlen(auth_req));
166+
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
167+
_state = OTA_WAITAUTH;
168+
return;
169+
} else {
170+
_udp_ota->append("OK", 2);
171+
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
172+
_state = OTA_RUNUPDATE;
173+
}
174+
} else if (_state == OTA_WAITAUTH) {
175+
int cmd = parseInt();
176+
if (cmd != U_AUTH) {
177+
_state = OTA_IDLE;
178+
return;
179+
}
180+
_udp_ota->read();
181+
String cnonce = readStringUntil(' ');
182+
String response = readStringUntil('\n');
183+
if (cnonce.length() != 32 || response.length() != 32) {
184+
_state = OTA_IDLE;
185+
return;
186+
}
187+
188+
MD5Builder _passmd5;
189+
_passmd5.begin();
190+
_passmd5.add(_password);
191+
_passmd5.calculate();
192+
String passmd5 = _passmd5.toString();
193+
194+
String challenge = passmd5 + ":" + String(_nonce) + ":" + cnonce;
195+
MD5Builder _challengemd5;
196+
_challengemd5.begin();
197+
_challengemd5.add(challenge);
198+
_challengemd5.calculate();
199+
String result = _challengemd5.toString();
200+
201+
ota_ip.addr = (uint32_t)_ota_ip;
202+
if(result.equals(response)){
203+
_udp_ota->append("OK", 2);
204+
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
205+
_state = OTA_RUNUPDATE;
206+
} else {
207+
_udp_ota->append("Authentication Failed", 21);
208+
_udp_ota->send(&ota_ip, _udp_ota->getRemotePort());
209+
if (_error_callback) _error_callback(OTA_AUTH_ERROR);
210+
_state = OTA_IDLE;
211+
}
212+
}
213+
214+
while(_udp_ota->next()) _udp_ota->flush();
215+
}
216+
87217
void ArduinoOTAClass::_runUpdate() {
88218
if (!Update.begin(_size, _cmd)) {
89219
#if OTA_DEBUG
@@ -92,7 +222,7 @@ void ArduinoOTAClass::_runUpdate() {
92222
if (_error_callback) {
93223
_error_callback(OTA_BEGIN_ERROR);
94224
}
95-
_udp_ota.begin(_port);
225+
_udp_ota->listen(*IP_ADDR_ANY, _port);
96226
_state = OTA_IDLE;
97227
return;
98228
}
@@ -112,7 +242,7 @@ void ArduinoOTAClass::_runUpdate() {
112242
#if OTA_DEBUG
113243
Serial.printf("Connect Failed\n");
114244
#endif
115-
_udp_ota.begin(_port);
245+
_udp_ota->listen(*IP_ADDR_ANY, _port);
116246
if (_error_callback) {
117247
_error_callback(OTA_CONNECT_ERROR);
118248
}
@@ -128,7 +258,7 @@ void ArduinoOTAClass::_runUpdate() {
128258
#if OTA_DEBUG
129259
Serial.printf("Recieve Failed\n");
130260
#endif
131-
_udp_ota.begin(_port);
261+
_udp_ota->listen(*IP_ADDR_ANY, _port);
132262
if (_error_callback) {
133263
_error_callback(OTA_RECIEVE_ERROR);
134264
}
@@ -156,7 +286,7 @@ void ArduinoOTAClass::_runUpdate() {
156286
}
157287
ESP.restart();
158288
} else {
159-
_udp_ota.begin(_port);
289+
_udp_ota->listen(*IP_ADDR_ANY, _port);
160290
if (_error_callback) {
161291
_error_callback(OTA_END_ERROR);
162292
}
@@ -169,94 +299,9 @@ void ArduinoOTAClass::_runUpdate() {
169299
}
170300

171301
void ArduinoOTAClass::handle() {
172-
if (!_udp_ota) {
173-
_udp_ota.begin(_port);
174-
#if OTA_DEBUG
175-
Serial.println("OTA restarted");
176-
#endif
177-
}
178-
179-
if (!_udp_ota.parsePacket()) return;
180-
181-
if (_state == OTA_IDLE) {
182-
int cmd = _udp_ota.parseInt();
183-
if (cmd != U_FLASH && cmd != U_SPIFFS)
184-
return;
185-
_ota_ip = _udp_ota.remoteIP();
186-
_cmd = cmd;
187-
_ota_port = _udp_ota.parseInt();
188-
_size = _udp_ota.parseInt();
189-
_udp_ota.read();
190-
_md5 = _udp_ota.readStringUntil('\n');
191-
_md5.trim();
192-
if(_md5.length() != 32)
193-
return;
194-
195-
#if OTA_DEBUG
196-
Serial.print("Update Start: ip:");
197-
Serial.print(_ota_ip);
198-
Serial.printf(", port:%d, size:%d, md5:%s\n", _ota_port, _size, _md5.c_str());
199-
#endif
200-
201-
_udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort());
202-
if (_password.length()){
203-
MD5Builder nonce_md5;
204-
nonce_md5.begin();
205-
nonce_md5.add(String(micros()));
206-
nonce_md5.calculate();
207-
_nonce = nonce_md5.toString();
208-
_udp_ota.printf("AUTH %s", _nonce.c_str());
209-
_udp_ota.endPacket();
210-
_state = OTA_WAITAUTH;
211-
return;
212-
} else {
213-
_udp_ota.print("OK");
214-
_udp_ota.endPacket();
215-
_state = OTA_RUNUPDATE;
216-
}
217-
} else if (_state == OTA_WAITAUTH) {
218-
int cmd = _udp_ota.parseInt();
219-
if (cmd != U_AUTH) {
220-
_state = OTA_IDLE;
221-
return;
222-
}
223-
_udp_ota.read();
224-
String cnonce = _udp_ota.readStringUntil(' ');
225-
String response = _udp_ota.readStringUntil('\n');
226-
if (cnonce.length() != 32 || response.length() != 32) {
227-
_state = OTA_IDLE;
228-
return;
229-
}
230-
231-
MD5Builder _passmd5;
232-
_passmd5.begin();
233-
_passmd5.add(_password);
234-
_passmd5.calculate();
235-
String passmd5 = _passmd5.toString();
236-
237-
String challenge = passmd5 + ":" + String(_nonce) + ":" + cnonce;
238-
MD5Builder _challengemd5;
239-
_challengemd5.begin();
240-
_challengemd5.add(challenge);
241-
_challengemd5.calculate();
242-
String result = _challengemd5.toString();
243-
244-
if(result.equals(response)){
245-
_udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort());
246-
_udp_ota.print("OK");
247-
_udp_ota.endPacket();
248-
_state = OTA_RUNUPDATE;
249-
} else {
250-
_udp_ota.beginPacket(_ota_ip, _udp_ota.remotePort());
251-
_udp_ota.print("Authentication Failed");
252-
_udp_ota.endPacket();
253-
if (_error_callback) _error_callback(OTA_AUTH_ERROR);
254-
_state = OTA_IDLE;
255-
}
256-
}
257-
258302
if (_state == OTA_RUNUPDATE) {
259303
_runUpdate();
304+
_state = OTA_IDLE;
260305
}
261306
}
262307

libraries/ArduinoOTA/ArduinoOTA.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
#ifndef __ARDUINO_OTA_H
22
#define __ARDUINO_OTA_H
33

4-
class WiFiUDP;
4+
#include <ESP8266WiFi.h>
5+
#include <WiFiUdp.h>
6+
7+
class UdpContext;
58

69
#define OTA_CALLBACK(callback) void (*callback)()
710
#define OTA_CALLBACK_PROGRESS(callback) void (*callback)(unsigned int, unsigned int)
@@ -41,7 +44,7 @@ class ArduinoOTAClass
4144
String _password;
4245
String _hostname;
4346
String _nonce;
44-
WiFiUDP _udp_ota;
47+
UdpContext *_udp_ota;
4548
bool _initialized;
4649
ota_state_t _state;
4750
int _size;
@@ -56,6 +59,9 @@ class ArduinoOTAClass
5659
OTA_CALLBACK_PROGRESS(_progress_callback);
5760

5861
void _runUpdate(void);
62+
void _onRx(void);
63+
int parseInt(void);
64+
String readStringUntil(char end);
5965
};
6066

6167
extern ArduinoOTAClass ArduinoOTA;

0 commit comments

Comments
 (0)