* Support asynchronous WiFi scan * Fixed buffer overflow in getChipId * ESP compatibility fixes * fixup! ESP compatibility fixes
423 lines
11 KiB
C++
423 lines
11 KiB
C++
/*
|
|
ArduinoOTA.cpp - Simple Arduino IDE OTA handler
|
|
Modified 2022 Earle F. Philhower, III. All rights reserved.
|
|
|
|
Taken from the ESP8266 core libraries, (c) various authors.
|
|
|
|
This library is free software; you can redistribute it and/or
|
|
modify it under the terms of the GNU Lesser General Public
|
|
License as published by the Free Software Foundation; either
|
|
version 2.1 of the License, or (at your option) any later version.
|
|
|
|
This library is distributed in the hope that it will be useful,
|
|
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
|
|
Lesser General Public License for more details.
|
|
|
|
You should have received a copy of the GNU Lesser General Public
|
|
License along with this library; if not, write to the Free Software
|
|
Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
|
|
*/
|
|
|
|
#include <functional>
|
|
#include <WiFiUdp.h>
|
|
#include "ArduinoOTA.h"
|
|
#include "MD5Builder.h"
|
|
#include <PicoOTA.h>
|
|
#include <StreamString.h>
|
|
|
|
#include "lwip/udp.h"
|
|
#include "include/UdpContext.h"
|
|
|
|
#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_MDNS)
|
|
#include <LEAmDNS.h>
|
|
#endif
|
|
|
|
// Destination for OTA diagnostic messages; every `#ifdef OTA_DEBUG` section in
// this file prints through it.
//
// The ESP8266 code this was taken from only defined OTA_DEBUG (as
// DEBUG_ESP_PORT) when both DEBUG_ESP_OTA and DEBUG_ESP_PORT were set. Here it
// is unconditionally routed to Serial, so those debug sections are always
// compiled in.
// NOTE(review): consider gating this behind a debug-port build option so the
// library does not always print to the sketch's Serial port.
#define OTA_DEBUG Serial
|
|
|
|
// The constructor has no work of its own; all runtime setup (hostname, port,
// UDP listener, mDNS) happens in begin().
ArduinoOTAClass::ArduinoOTAClass() = default;
|
|
|
|
/**
 * Drop this object's reference to the UDP context created by begin(), if any.
 */
ArduinoOTAClass::~ArduinoOTAClass() {
    if (_udp_ota == nullptr) {
        return;
    }
    _udp_ota->unref();
    _udp_ota = nullptr;
}
|
|
|
|
/**
 * Register a callback run when a transfer begins: after the updater has been
 * started and the host acknowledged, just before connecting back for the data.
 */
void ArduinoOTAClass::onStart(THandlerFunction fn) {
    this->_start_callback = fn;
}
|
|
|
|
/**
 * Register a callback run after a transfer has been finalized successfully,
 * before the optional reboot.
 */
void ArduinoOTAClass::onEnd(THandlerFunction fn) {
    this->_end_callback = fn;
}
|
|
|
|
/**
 * Register a progress callback, invoked as (bytes received so far, total image
 * size), starting with (0, size) when the transfer begins.
 */
void ArduinoOTAClass::onProgress(THandlerFunction_Progress fn) {
    this->_progress_callback = fn;
}
|
|
|
|
/**
 * Register a callback that receives an OTA error code (e.g. OTA_AUTH_ERROR,
 * OTA_BEGIN_ERROR, OTA_CONNECT_ERROR, OTA_RECEIVE_ERROR, OTA_END_ERROR) when
 * an update attempt fails.
 */
void ArduinoOTAClass::onError(THandlerFunction_Error fn) {
    this->_error_callback = fn;
}
|
|
|
|
/**
 * Choose the UDP port to listen on.
 *
 * Ignored after begin(), once a port has already been set, or when port is 0.
 * begin() uses 2040 if no port was ever set.
 */
void ArduinoOTAClass::setPort(uint16_t port) {
    const bool canSet = !_initialized && _port == 0 && port != 0;
    if (canSet) {
        _port = port;
    }
}
|
|
|
|
void ArduinoOTAClass::setHostname(const char * hostname) {
|
|
if (!_initialized && !_hostname.length() && hostname) {
|
|
_hostname = hostname;
|
|
}
|
|
}
|
|
|
|
/**
 * Return a copy of the configured hostname. This is empty until setHostname()
 * or begin() assigns one.
 */
String ArduinoOTAClass::getHostname() {
    return this->_hostname;
}
|
|
|
|
void ArduinoOTAClass::setPassword(const char * password) {
|
|
if (!_initialized && !_password.length() && password) {
|
|
MD5Builder passmd5;
|
|
passmd5.begin();
|
|
passmd5.add(password);
|
|
passmd5.calculate();
|
|
_password = passmd5.toString();
|
|
}
|
|
}
|
|
|
|
void ArduinoOTAClass::setPasswordHash(const char * password) {
|
|
if (!_initialized && !_password.length() && password) {
|
|
_password = password;
|
|
}
|
|
}
|
|
|
|
void ArduinoOTAClass::setRebootOnSuccess(bool reboot) {
|
|
_rebootOnSuccess = reboot;
|
|
}
|
|
|
|
void ArduinoOTAClass::begin(bool useMDNS) {
|
|
if (_initialized) {
|
|
return;
|
|
}
|
|
|
|
_useMDNS = useMDNS;
|
|
|
|
if (!_hostname.length()) {
|
|
char tmp[2 * PICO_UNIQUE_BOARD_ID_SIZE_BYTES + 6];
|
|
sprintf(tmp, "pico-%s", rp2040.getChipID());
|
|
_hostname = tmp;
|
|
}
|
|
if (!_port) {
|
|
_port = 2040;
|
|
}
|
|
|
|
if (_udp_ota) {
|
|
_udp_ota->unref();
|
|
_udp_ota = 0;
|
|
}
|
|
|
|
_udp_ota = new UdpContext;
|
|
_udp_ota->ref();
|
|
|
|
if (!_udp_ota->listen(IP_ADDR_ANY, _port)) {
|
|
return;
|
|
}
|
|
_udp_ota->onRx(std::bind(&ArduinoOTAClass::_onRx, this));
|
|
|
|
#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_MDNS)
|
|
if (_useMDNS) {
|
|
MDNS.begin(_hostname.c_str());
|
|
|
|
if (_password.length()) {
|
|
MDNS.enableArduino(_port, true);
|
|
} else {
|
|
MDNS.enableArduino(_port);
|
|
}
|
|
}
|
|
#endif
|
|
_initialized = true;
|
|
_state = OTA_IDLE;
|
|
#ifdef OTA_DEBUG
|
|
OTA_DEBUG.printf("OTA server at: %s.local:%u\n", _hostname.c_str(), _port);
|
|
#endif
|
|
}
|
|
|
|
int ArduinoOTAClass::parseInt() {
|
|
char data[16];
|
|
uint8_t index;
|
|
char value;
|
|
while (_udp_ota->peek() == ' ') {
|
|
_udp_ota->read();
|
|
}
|
|
for (index = 0; index < sizeof(data); ++index) {
|
|
value = _udp_ota->peek();
|
|
if (value < '0' || value > '9') {
|
|
data[index] = '\0';
|
|
return atoi(data);
|
|
}
|
|
data[index] = _udp_ota->read();
|
|
}
|
|
return 0;
|
|
}
|
|
|
|
/**
 * Collect bytes from the current UDP packet until `end`, a NUL byte, or the end
 * of the data (read() < 0). The terminating byte is consumed but not included
 * in the returned string.
 */
String ArduinoOTAClass::readStringUntil(char end) {
    String out;
    for (int c = _udp_ota->read(); c >= 0 && c != '\0' && c != end; c = _udp_ota->read()) {
        out += static_cast<char>(c);
    }
    return out;
}
|
|
|
|
void ArduinoOTAClass::_onRx() {
|
|
if (!_udp_ota->next()) {
|
|
return;
|
|
}
|
|
IPAddress ota_ip;
|
|
|
|
if (_state == OTA_IDLE) {
|
|
int cmd = parseInt();
|
|
if (cmd != U_FLASH && cmd != U_FS) {
|
|
return;
|
|
}
|
|
_ota_ip = _udp_ota->getRemoteAddress();
|
|
_cmd = cmd;
|
|
_ota_port = parseInt();
|
|
_ota_udp_port = _udp_ota->getRemotePort();
|
|
_size = parseInt();
|
|
_udp_ota->read();
|
|
_md5 = readStringUntil('\n');
|
|
_md5.trim();
|
|
if (_md5.length() != 32) {
|
|
return;
|
|
}
|
|
|
|
ota_ip = _ota_ip;
|
|
|
|
if (_password.length()) {
|
|
MD5Builder nonce_md5;
|
|
nonce_md5.begin();
|
|
nonce_md5.add(String(micros()));
|
|
nonce_md5.calculate();
|
|
_nonce = nonce_md5.toString();
|
|
|
|
char auth_req[38];
|
|
sprintf(auth_req, "AUTH %s", _nonce.c_str());
|
|
_udp_ota->append((const char *)auth_req, strlen(auth_req));
|
|
_udp_ota->send(ota_ip, _ota_udp_port);
|
|
_state = OTA_WAITAUTH;
|
|
return;
|
|
} else {
|
|
_state = OTA_RUNUPDATE;
|
|
}
|
|
} else if (_state == OTA_WAITAUTH) {
|
|
int cmd = parseInt();
|
|
if (cmd != U_AUTH) {
|
|
_state = OTA_IDLE;
|
|
return;
|
|
}
|
|
_udp_ota->read();
|
|
String cnonce = readStringUntil(' ');
|
|
String response = readStringUntil('\n');
|
|
if (cnonce.length() != 32 || response.length() != 32) {
|
|
_state = OTA_IDLE;
|
|
return;
|
|
}
|
|
|
|
String challenge = _password + ':' + String(_nonce) + ':' + cnonce;
|
|
MD5Builder _challengemd5;
|
|
_challengemd5.begin();
|
|
_challengemd5.add(challenge);
|
|
_challengemd5.calculate();
|
|
String result = _challengemd5.toString();
|
|
|
|
ota_ip = _ota_ip;
|
|
// if(result.equalsConstantTime(response)) {
|
|
if (result.equals(response)) {
|
|
_state = OTA_RUNUPDATE;
|
|
} else {
|
|
_udp_ota->append("Authentication Failed", 21);
|
|
_udp_ota->send(ota_ip, _ota_udp_port);
|
|
if (_error_callback) {
|
|
_error_callback(OTA_AUTH_ERROR);
|
|
}
|
|
_state = OTA_IDLE;
|
|
}
|
|
}
|
|
|
|
while (_udp_ota->next()) {
|
|
_udp_ota->flush();
|
|
}
|
|
}
|
|
|
|
void ArduinoOTAClass::_runUpdate() {
|
|
IPAddress ota_ip = _ota_ip;
|
|
|
|
if (!LittleFS.begin()) {
|
|
#ifdef OTA_DEBUG
|
|
OTA_DEBUG.println("LittleFS Begin Error");
|
|
#endif
|
|
_udp_ota->append("ERR: ", 5);
|
|
_udp_ota->append("No Filesystem", 13);
|
|
_udp_ota->send(ota_ip, _ota_udp_port);
|
|
delay(100);
|
|
_udp_ota->listen(IP_ADDR_ANY, _port);
|
|
_state = OTA_IDLE;
|
|
return;
|
|
}
|
|
|
|
if (!Update.begin(_size, _cmd)) {
|
|
#ifdef OTA_DEBUG
|
|
OTA_DEBUG.println("Update Begin Error");
|
|
#endif
|
|
if (_error_callback) {
|
|
_error_callback(OTA_BEGIN_ERROR);
|
|
}
|
|
|
|
StreamString ss;
|
|
Update.printError(ss);
|
|
_udp_ota->append("ERR: ", 5);
|
|
_udp_ota->append(ss.c_str(), ss.length());
|
|
_udp_ota->send(ota_ip, _ota_udp_port);
|
|
delay(100);
|
|
_udp_ota->listen(IP_ADDR_ANY, _port);
|
|
_state = OTA_IDLE;
|
|
return;
|
|
}
|
|
|
|
_udp_ota->append("OK", 2);
|
|
_udp_ota->send(ota_ip, _ota_udp_port);
|
|
delay(100);
|
|
|
|
Update.setMD5(_md5.c_str());
|
|
|
|
if (_start_callback) {
|
|
_start_callback();
|
|
}
|
|
if (_progress_callback) {
|
|
_progress_callback(0, _size);
|
|
}
|
|
|
|
WiFiClient client;
|
|
if (!client.connect(_ota_ip, _ota_port)) {
|
|
#ifdef OTA_DEBUG
|
|
OTA_DEBUG.printf("Connect Failed\n");
|
|
#endif
|
|
_udp_ota->listen(IP_ADDR_ANY, _port);
|
|
if (_error_callback) {
|
|
_error_callback(OTA_CONNECT_ERROR);
|
|
}
|
|
_state = OTA_IDLE;
|
|
}
|
|
// OTA sends little packets
|
|
client.setNoDelay(true);
|
|
|
|
uint32_t written, total = 0;
|
|
while (!Update.isFinished() && (client.connected() || client.available())) {
|
|
int waited = 1000;
|
|
while (!client.available() && waited--) {
|
|
delay(1);
|
|
}
|
|
if (!waited) {
|
|
#ifdef OTA_DEBUG
|
|
OTA_DEBUG.printf("Receive Failed\n");
|
|
#endif
|
|
_udp_ota->listen(IP_ADDR_ANY, _port);
|
|
if (_error_callback) {
|
|
_error_callback(OTA_RECEIVE_ERROR);
|
|
}
|
|
_state = OTA_IDLE;
|
|
}
|
|
written = Update.write(client);
|
|
if (written > 0) {
|
|
client.print(written, DEC);
|
|
total += written;
|
|
if (_progress_callback) {
|
|
_progress_callback(total, _size);
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
if (Update.end()) {
|
|
// Ensure last count packet has been sent out and not combined with the final OK
|
|
client.flush();
|
|
delay(1000);
|
|
client.print("OK");
|
|
client.flush();
|
|
delay(1000);
|
|
client.stop();
|
|
#ifdef OTA_DEBUG
|
|
OTA_DEBUG.printf("Update Success\n");
|
|
#endif
|
|
if (_end_callback) {
|
|
_end_callback();
|
|
}
|
|
if (_rebootOnSuccess) {
|
|
#ifdef OTA_DEBUG
|
|
OTA_DEBUG.printf("Rebooting...\n");
|
|
#endif
|
|
LittleFS.end();
|
|
//let serial/network finish tasks that might be given in _end_callback
|
|
delay(100);
|
|
rp2040.reboot();
|
|
}
|
|
} else {
|
|
_udp_ota->listen(IP_ADDR_ANY, _port);
|
|
if (_error_callback) {
|
|
_error_callback(OTA_END_ERROR);
|
|
}
|
|
Update.printError(client);
|
|
#ifdef OTA_DEBUG
|
|
Update.printError(OTA_DEBUG);
|
|
#endif
|
|
_state = OTA_IDLE;
|
|
}
|
|
}
|
|
|
|
void ArduinoOTAClass::end() {
|
|
_initialized = false;
|
|
_udp_ota->unref();
|
|
_udp_ota = 0;
|
|
#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_MDNS)
|
|
if (_useMDNS) {
|
|
MDNS.end();
|
|
}
|
|
#endif
|
|
_state = OTA_IDLE;
|
|
#ifdef OTA_DEBUG
|
|
OTA_DEBUG.printf("OTA server stopped.\n");
|
|
#endif
|
|
}
|
|
//this needs to be called in the loop()
|
|
void ArduinoOTAClass::handle() {
|
|
if (_state == OTA_RUNUPDATE) {
|
|
_runUpdate();
|
|
_state = OTA_IDLE;
|
|
}
|
|
|
|
#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_MDNS)
|
|
if (_useMDNS) {
|
|
MDNS.update(); //handle MDNS update as well, given that ArduinoOTA relies on it anyways
|
|
}
|
|
#endif
|
|
}
|
|
|
|
int ArduinoOTAClass::getCommand() {
|
|
return _cmd;
|
|
}
|
|
|
|
#if !defined(NO_GLOBAL_INSTANCES) && !defined(NO_GLOBAL_ARDUINOOTA)
|
|
ArduinoOTAClass ArduinoOTA;
|
|
#endif
|