feat(webserver): Middleware with default middleware for cors, authc, curl-like logging (#10750)

* feat(webserver): Middleware with default middleware for cors, authc, curl-like logging

* ci(pre-commit): Apply automatic fixes

---------

Co-authored-by: Rodrigo Garcia <rodrigo.garcia@espressif.com>
Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com>
This commit is contained in:
Mathieu Carbou 2025-01-07 11:00:50 +01:00 committed by GitHub
parent 089cbabf17
commit b07eb175d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 893 additions and 99 deletions

View file

@ -242,7 +242,11 @@ set(ARDUINO_LIBRARY_USB_SRCS
set(ARDUINO_LIBRARY_WebServer_SRCS
libraries/WebServer/src/WebServer.cpp
libraries/WebServer/src/Parsing.cpp
libraries/WebServer/src/detail/mimetable.cpp)
libraries/WebServer/src/detail/mimetable.cpp
libraries/WebServer/src/middleware/MiddlewareChain.cpp
libraries/WebServer/src/middleware/AuthenticationMiddleware.cpp
libraries/WebServer/src/middleware/CorsMiddleware.cpp
libraries/WebServer/src/middleware/LoggingMiddleware.cpp)
set(ARDUINO_LIBRARY_NetworkClientSecure_SRCS
libraries/NetworkClientSecure/src/ssl_client.cpp

View file

@ -0,0 +1,186 @@
/**
* Basic example of using Middlewares with WebServer
*
* Middleware are common request/response processing functions that can be applied globally to all incoming requests or to specific handlers.
* They allow for a common processing thus saving memory and space to avoid duplicating code or states on multiple handlers.
*
* Once the example is flashed (with the correct WiFi credentials), you can test the following scenarios with the listed curl commands:
* - CORS Middleware: answers to OPTIONS requests with the specified CORS headers and also add CORS headers to the response when the request has the Origin header
* - Logging Middleware: logs the request and response to an output in a curl-like format
* - Authentication Middleware: test the authentication with Digest Auth
*
* You can also add your own Middleware by extending the Middleware class and implementing the run method.
* When implementing a Middleware, you can decide when to call the next Middleware in the chain by calling next().
*
* Middleware are execute in order of addition, the ones attached to the server will be executed first.
*/
#include <WiFi.h>
#include <WebServer.h>
#include <Middlewares.h>
// Your AP WiFi Credentials
// ( This is the AP your ESP will broadcast )
const char *ap_ssid = "ESP32_Demo";
const char *ap_password = "";
WebServer server(80);
LoggingMiddleware logger;
CorsMiddleware cors;
AuthenticationMiddleware auth;
void setup(void) {
Serial.begin(115200);
WiFi.softAP(ap_ssid, ap_password);
Serial.print("IP address: ");
Serial.println(WiFi.AP.localIP());
// curl-like output example:
//
// > curl -v -X OPTIONS -H "origin: http://192.168.4.1" http://192.168.4.1/
//
// Connection from 192.168.4.2:51683
// > OPTIONS / HTTP/1.1
// > Host: 192.168.4.1
// > User-Agent: curl/8.10.0
// > Accept: */*
// > origin: http://192.168.4.1
// >
// * Processed in 5 ms
// < HTTP/1.HTTP/1.1 200 OK
// < Content-Type: text/html
// < Access-Control-Allow-Origin: http://192.168.4.1
// < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE
// < Access-Control-Allow-Headers: X-Custom-Header
// < Access-Control-Allow-Credentials: false
// < Access-Control-Max-Age: 600
// < Content-Length: 0
// < Connection: close
// <
logger.setOutput(Serial);
cors.setOrigin("http://192.168.4.1");
cors.setMethods("POST,GET,OPTIONS,DELETE");
cors.setHeaders("X-Custom-Header");
cors.setAllowCredentials(false);
cors.setMaxAge(600);
auth.setUsername("admin");
auth.setPassword("admin");
auth.setRealm("My Super App");
auth.setAuthMethod(DIGEST_AUTH);
auth.setAuthFailureMessage("Authentication Failed");
server.addMiddleware(&logger);
server.addMiddleware(&cors);
// Not authenticated
//
// Test CORS preflight request with:
// > curl -v -X OPTIONS -H "origin: http://192.168.4.1" http://192.168.4.1/
//
// Test cross-domain request with:
// > curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/
//
server.on("/", []() {
server.send(200, "text/plain", "Home");
});
// Authenticated
//
// > curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/protected
//
// Outputs:
//
// * Connection from 192.168.4.2:51750
// > GET /protected HTTP/1.1
// > Host: 192.168.4.1
// > User-Agent: curl/8.10.0
// > Accept: */*
// > origin: http://192.168.4.1
// >
// * Processed in 7 ms
// < HTTP/1.HTTP/1.1 401 Unauthorized
// < Content-Type: text/html
// < Access-Control-Allow-Origin: http://192.168.4.1
// < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE
// < Access-Control-Allow-Headers: X-Custom-Header
// < Access-Control-Allow-Credentials: false
// < Access-Control-Max-Age: 600
// < WWW-Authenticate: Digest realm="My Super App", qop="auth", nonce="ac388a64184e3e102aae6fff1c9e8d76", opaque="e7d158f2b54d25328142d118ff0f932d"
// < Content-Length: 21
// < Connection: close
// <
//
// > curl -v -X GET -H "origin: http://192.168.4.1" --digest -u admin:admin http://192.168.4.1/protected
//
// Outputs:
//
// * Connection from 192.168.4.2:53662
// > GET /protected HTTP/1.1
// > Authorization: Digest username="admin", realm="My Super App", nonce="db9e6824eb2a13bc7b2bf8f3c43db896", uri="/protected", cnonce="NTliZDZiNTcwODM2MzAyY2JjMDBmZGJmNzFiY2ZmNzk=", nc=00000001, qop=auth, response="6ebd145ba0d3496a4a73f5ae79ff5264", opaque="23d739c22810282ff820538cba98bda4"
// > Host: 192.168.4.1
// > User-Agent: curl/8.10.0
// > Accept: */*
// > origin: http://192.168.4.1
// >
// Request handling...
// * Processed in 7 ms
// < HTTP/1.HTTP/1.1 200 OK
// < Content-Type: text/plain
// < Access-Control-Allow-Origin: http://192.168.4.1
// < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE
// < Access-Control-Allow-Headers: X-Custom-Header
// < Access-Control-Allow-Credentials: false
// < Access-Control-Max-Age: 600
// < Content-Length: 9
// < Connection: close
// <
server
.on(
"/protected",
[]() {
Serial.println("Request handling...");
server.send(200, "text/plain", "Protected");
}
)
.addMiddleware(&auth);
// Not found is also handled by global middleware
//
// curl -v -X GET -H "origin: http://192.168.4.1" http://192.168.4.1/inexsting
//
// Outputs:
//
// * Connection from 192.168.4.2:53683
// > GET /inexsting HTTP/1.1
// > Host: 192.168.4.1
// > User-Agent: curl/8.10.0
// > Accept: */*
// > origin: http://192.168.4.1
// >
// * Processed in 16 ms
// < HTTP/1.HTTP/1.1 404 Not Found
// < Content-Type: text/plain
// < Access-Control-Allow-Origin: http://192.168.4.1
// < Access-Control-Allow-Methods: POST,GET,OPTIONS,DELETE
// < Access-Control-Allow-Headers: X-Custom-Header
// < Access-Control-Allow-Credentials: false
// < Access-Control-Max-Age: 600
// < Content-Length: 14
// < Connection: close
// <
server.onNotFound([]() {
server.send(404, "text/plain", "Page not found");
});
server.collectAllHeaders();
server.begin();
Serial.println("HTTP server started");
}
void loop(void) {
server.handleClient();
delay(2); //allow the cpu to switch to other tasks
}

View file

@ -0,0 +1,5 @@
{
"requires": [
"CONFIG_SOC_WIFI_SUPPORTED=y"
]
}

View file

@ -0,0 +1,66 @@
#ifndef MIDDLEWARES_H
#define MIDDLEWARES_H
#include <WebServer.h>
#include <Stream.h>
#include <assert.h>
// curl-like logging middleware
class LoggingMiddleware : public Middleware {
public:
void setOutput(Print &output);
bool run(WebServer &server, Middleware::Callback next) override;
private:
Print *_out = nullptr;
};
class CorsMiddleware : public Middleware {
public:
CorsMiddleware &setOrigin(const char *origin);
CorsMiddleware &setMethods(const char *methods);
CorsMiddleware &setHeaders(const char *headers);
CorsMiddleware &setAllowCredentials(bool credentials);
CorsMiddleware &setMaxAge(uint32_t seconds);
void addCORSHeaders(WebServer &server);
bool run(WebServer &server, Middleware::Callback next) override;
private:
String _origin = F("*");
String _methods = F("*");
String _headers = F("*");
bool _credentials = true;
uint32_t _maxAge = 86400;
};
class AuthenticationMiddleware : public Middleware {
public:
AuthenticationMiddleware &setUsername(const char *username);
AuthenticationMiddleware &setPassword(const char *password);
AuthenticationMiddleware &setPasswordHash(const char *sha1AsBase64orHex);
AuthenticationMiddleware &setCallback(WebServer::THandlerFunctionAuthCheck fn);
AuthenticationMiddleware &setRealm(const char *realm);
AuthenticationMiddleware &setAuthMethod(HTTPAuthMethod method);
AuthenticationMiddleware &setAuthFailureMessage(const char *message);
bool isAllowed(WebServer &server) const;
bool run(WebServer &server, Middleware::Callback next) override;
private:
String _username;
String _password;
bool _hash = false;
WebServer::THandlerFunctionAuthCheck _callback;
const char *_realm = nullptr;
HTTPAuthMethod _method = BASIC_AUTH;
String _authFailMsg;
};
#endif

View file

@ -78,8 +78,14 @@ bool WebServer::_parseRequest(NetworkClient &client) {
String req = client.readStringUntil('\r');
client.readStringUntil('\n');
//reset header value
for (int i = 0; i < _headerKeysCount; ++i) {
_currentHeaders[i].value = String();
if (_collectAllHeaders) {
// clear previous headers
collectAllHeaders();
} else {
// clear previous headers
for (RequestArgument *header = _currentHeaders; header; header = header->next) {
header->value = String();
}
}
// First line of HTTP request looks like "GET /path HTTP/1.1"
@ -154,9 +160,6 @@ bool WebServer::_parseRequest(NetworkClient &client) {
headerValue.trim();
_collectHeader(headerName.c_str(), headerValue.c_str());
log_v("headerName: %s", headerName.c_str());
log_v("headerValue: %s", headerValue.c_str());
if (headerName.equalsIgnoreCase(FPSTR(Content_Type))) {
using namespace mime;
if (headerValue.startsWith(FPSTR(mimeTable[txt].mimeType))) {
@ -254,9 +257,6 @@ bool WebServer::_parseRequest(NetworkClient &client) {
headerValue = req.substring(headerDiv + 2);
_collectHeader(headerName.c_str(), headerValue.c_str());
log_v("headerName: %s", headerName.c_str());
log_v("headerValue: %s", headerValue.c_str());
if (headerName.equalsIgnoreCase("Host")) {
_hostHeader = headerValue;
}
@ -272,12 +272,29 @@ bool WebServer::_parseRequest(NetworkClient &client) {
}
bool WebServer::_collectHeader(const char *headerName, const char *headerValue) {
for (int i = 0; i < _headerKeysCount; i++) {
if (_currentHeaders[i].key.equalsIgnoreCase(headerName)) {
_currentHeaders[i].value = headerValue;
RequestArgument *last = nullptr;
for (RequestArgument *header = _currentHeaders; header; header = header->next) {
if (header->next == nullptr) {
last = header;
}
if (header->key.equalsIgnoreCase(headerName)) {
header->value = headerValue;
log_v("header collected: %s: %s", headerName, headerValue);
return true;
}
}
assert(last);
if (_collectAllHeaders) {
last->next = new RequestArgument();
last->next->key = headerName;
last->next->value = headerValue;
_headerKeysCount++;
log_v("header collected: %s: %s", headerName, headerValue);
return true;
}
log_v("header skipped: %s: %s", headerName, headerValue);
return false;
}

View file

@ -41,31 +41,28 @@ static const char WWW_Authenticate[] = "WWW-Authenticate";
static const char Content_Length[] = "Content-Length";
static const char ETAG_HEADER[] = "If-None-Match";
WebServer::WebServer(IPAddress addr, int port)
: _corsEnabled(false), _server(addr, port), _currentMethod(HTTP_ANY), _currentVersion(0), _currentStatus(HC_NONE), _statusChange(0), _nullDelay(true),
_currentHandler(nullptr), _firstHandler(nullptr), _lastHandler(nullptr), _currentArgCount(0), _currentArgs(nullptr), _postArgsLen(0), _postArgs(nullptr),
_headerKeysCount(0), _currentHeaders(nullptr), _contentLength(0), _clientContentLength(0), _chunked(false) {
WebServer::WebServer(IPAddress addr, int port) : _server(addr, port) {
log_v("WebServer::Webserver(addr=%s, port=%d)", addr.toString().c_str(), port);
}
WebServer::WebServer(int port)
: _corsEnabled(false), _server(port), _currentMethod(HTTP_ANY), _currentVersion(0), _currentStatus(HC_NONE), _statusChange(0), _nullDelay(true),
_currentHandler(nullptr), _firstHandler(nullptr), _lastHandler(nullptr), _currentArgCount(0), _currentArgs(nullptr), _postArgsLen(0), _postArgs(nullptr),
_headerKeysCount(0), _currentHeaders(nullptr), _contentLength(0), _clientContentLength(0), _chunked(false) {
WebServer::WebServer(int port) : _server(port) {
log_v("WebServer::Webserver(port=%d)", port);
}
WebServer::~WebServer() {
_server.close();
if (_currentHeaders) {
delete[] _currentHeaders;
}
_clearRequestHeaders();
_clearResponseHeaders();
delete _chain;
RequestHandler *handler = _firstHandler;
while (handler) {
RequestHandler *next = handler->next();
delete handler;
handler = next;
}
_firstHandler = nullptr;
}
void WebServer::begin() {
@ -436,7 +433,17 @@ void WebServer::handleClient() {
_currentClient.setTimeout(HTTP_MAX_SEND_WAIT); /* / 1000 removed, WifiClient setTimeout changed to ms */
if (_parseRequest(_currentClient)) {
_contentLength = CONTENT_LENGTH_NOT_SET;
_handleRequest();
_responseCode = 0;
_clearResponseHeaders();
// Run server-level middlewares
if (_chain) {
_chain->runChain(*this, [this]() {
return _handleRequest();
});
} else {
_handleRequest();
}
if (_currentClient.isSSE()) {
_currentStatus = HC_WAIT_CLOSE;
@ -495,16 +502,22 @@ void WebServer::stop() {
}
void WebServer::sendHeader(const String &name, const String &value, bool first) {
String headerLine = name;
headerLine += F(": ");
headerLine += value;
headerLine += "\r\n";
RequestArgument *header = new RequestArgument();
header->key = name;
header->value = value;
if (first) {
_responseHeaders = headerLine + _responseHeaders;
if (!_responseHeaders || first) {
header->next = _responseHeaders;
_responseHeaders = header;
} else {
_responseHeaders += headerLine;
RequestArgument *last = _responseHeaders;
while (last->next) {
last = last->next;
}
last->next = header;
}
_responseHeaderCount++;
}
void WebServer::setContentLength(const size_t contentLength) {
@ -529,11 +542,14 @@ void WebServer::enableETag(bool enable, ETagFunction fn) {
}
void WebServer::_prepareHeader(String &response, int code, const char *content_type, size_t contentLength) {
response = String(F("HTTP/1.")) + String(_currentVersion) + ' ';
response += String(code);
response += ' ';
response += _responseCodeToString(code);
response += "\r\n";
_responseCode = code;
response.concat(version());
response.concat(' ');
response.concat(String(code));
response.concat(' ');
response.concat(responseCodeToString(code));
response.concat(F("\r\n"));
using namespace mime;
if (!content_type) {
@ -558,9 +574,14 @@ void WebServer::_prepareHeader(String &response, int code, const char *content_t
}
sendHeader(String(F("Connection")), String(F("close")));
response += _responseHeaders;
response += "\r\n";
_responseHeaders = "";
for (RequestArgument *header = _responseHeaders; header; header = header->next) {
response.concat(header->key);
response.concat(F(": "));
response.concat(header->value);
response.concat(F("\r\n"));
}
response.concat(F("\r\n"));
}
void WebServer::send(int code, const char *content_type, const String &content) {
@ -568,9 +589,6 @@ void WebServer::send(int code, const char *content_type, const String &content)
// Can we assume the following?
//if(code == 200 && content.length() == 0 && _contentLength == CONTENT_LENGTH_NOT_SET)
// _contentLength = CONTENT_LENGTH_UNKNOWN;
if (content.length() == 0) {
log_w("content length is zero");
}
_prepareHeader(header, code, content_type, content.length());
_currentClientWrite(header.c_str(), header.length());
if (content.length()) {
@ -728,39 +746,43 @@ bool WebServer::hasArg(const String &name) const {
}
String WebServer::header(const String &name) const {
for (int i = 0; i < _headerKeysCount; ++i) {
if (_currentHeaders[i].key.equalsIgnoreCase(name)) {
return _currentHeaders[i].value;
for (RequestArgument *current = _currentHeaders; current; current = current->next) {
if (current->key.equalsIgnoreCase(name)) {
return current->value;
}
}
return "";
}
void WebServer::collectHeaders(const char *headerKeys[], const size_t headerKeysCount) {
_headerKeysCount = headerKeysCount + 2;
if (_currentHeaders) {
delete[] _currentHeaders;
}
_currentHeaders = new RequestArgument[_headerKeysCount];
_currentHeaders[0].key = FPSTR(AUTHORIZATION_HEADER);
_currentHeaders[1].key = FPSTR(ETAG_HEADER);
collectAllHeaders();
_collectAllHeaders = false;
_headerKeysCount += headerKeysCount;
RequestArgument *last = _currentHeaders->next;
for (int i = 2; i < _headerKeysCount; i++) {
_currentHeaders[i].key = headerKeys[i - 2];
last->next = new RequestArgument();
last->next->key = headerKeys[i - 2];
last = last->next;
}
}
String WebServer::header(int i) const {
if (i < _headerKeysCount) {
return _currentHeaders[i].value;
RequestArgument *current = _currentHeaders;
while (current && i--) {
current = current->next;
}
return "";
return current ? current->value : emptyString;
}
String WebServer::headerName(int i) const {
if (i < _headerKeysCount) {
return _currentHeaders[i].key;
RequestArgument *current = _currentHeaders;
while (current && i--) {
current = current->next;
}
return "";
return current ? current->key : emptyString;
}
int WebServer::headers() const {
@ -768,12 +790,7 @@ int WebServer::headers() const {
}
bool WebServer::hasHeader(const String &name) const {
for (int i = 0; i < _headerKeysCount; ++i) {
if ((_currentHeaders[i].key.equalsIgnoreCase(name)) && (_currentHeaders[i].value.length() > 0)) {
return true;
}
}
return false;
return header(name).length() > 0;
}
String WebServer::hostHeader() const {
@ -788,16 +805,17 @@ void WebServer::onNotFound(THandlerFunction fn) {
_notFoundHandler = fn;
}
void WebServer::_handleRequest() {
bool WebServer::_handleRequest() {
bool handled = false;
if (!_currentHandler) {
log_e("request handler not found");
} else {
handled = _currentHandler->handle(*this, _currentMethod, _currentUri);
if (_currentHandler) {
handled = _currentHandler->process(*this, _currentMethod, _currentUri);
if (!handled) {
log_e("request handler failed to handle request");
}
}
// DO NOT LOG if _currentHandler == null !!
// This is is valid use case to handle any other requests
// Also, this is just causing log flooding
if (!handled && _notFoundHandler) {
_notFoundHandler();
handled = true;
@ -811,6 +829,7 @@ void WebServer::_handleRequest() {
_finalizeResponse();
}
_currentUri = "";
return handled;
}
void WebServer::_finalizeResponse() {
@ -819,7 +838,7 @@ void WebServer::_finalizeResponse() {
}
}
String WebServer::_responseCodeToString(int code) {
String WebServer::responseCodeToString(int code) {
switch (code) {
case 100: return F("Continue");
case 101: return F("Switching Protocols");
@ -864,3 +883,108 @@ String WebServer::_responseCodeToString(int code) {
default: return F("");
}
}
void WebServer::_clearResponseHeaders() {
_responseHeaderCount = 0;
RequestArgument *current = _responseHeaders;
while (current) {
RequestArgument *next = current->next;
delete current;
current = next;
}
_responseHeaders = nullptr;
}
void WebServer::_clearRequestHeaders() {
_headerKeysCount = 0;
RequestArgument *current = _currentHeaders;
while (current) {
RequestArgument *next = current->next;
delete current;
current = next;
}
_currentHeaders = nullptr;
}
void WebServer::collectAllHeaders() {
_clearRequestHeaders();
_currentHeaders = new RequestArgument();
_currentHeaders->key = FPSTR(AUTHORIZATION_HEADER);
_currentHeaders->next = new RequestArgument();
_currentHeaders->next->key = FPSTR(ETAG_HEADER);
_headerKeysCount = 2;
_collectAllHeaders = true;
}
const String &WebServer::responseHeader(String name) const {
for (RequestArgument *current = _responseHeaders; current; current = current->next) {
if (current->key.equalsIgnoreCase(name)) {
return current->value;
}
}
return emptyString;
}
const String &WebServer::responseHeader(int i) const {
RequestArgument *current = _responseHeaders;
while (current && i--) {
current = current->next;
}
return current ? current->value : emptyString;
}
const String &WebServer::responseHeaderName(int i) const {
RequestArgument *current = _responseHeaders;
while (current && i--) {
current = current->next;
}
return current ? current->key : emptyString;
}
bool WebServer::hasResponseHeader(const String &name) const {
return header(name).length() > 0;
}
int WebServer::clientContentLength() const {
return _clientContentLength;
}
const String WebServer::version() const {
String v;
v.reserve(8);
v.concat(F("HTTP/1."));
v.concat(_currentVersion);
return v;
}
int WebServer::responseCode() const {
return _responseCode;
}
int WebServer::responseHeaders() const {
return _responseHeaderCount;
}
WebServer &WebServer::addMiddleware(Middleware *middleware) {
if (!_chain) {
_chain = new MiddlewareChain();
}
_chain->addMiddleware(middleware);
return *this;
}
WebServer &WebServer::addMiddleware(Middleware::Function fn) {
if (!_chain) {
_chain = new MiddlewareChain();
}
_chain->addMiddleware(fn);
return *this;
}
WebServer &WebServer::removeMiddleware(Middleware *middleware) {
if (_chain) {
_chain->removeMiddleware(middleware);
}
return *this;
}

View file

@ -92,6 +92,7 @@ typedef struct {
void *data; // additional data
} HTTPRaw;
#include "middleware/Middleware.h"
#include "detail/RequestHandler.h"
namespace fs {
@ -158,6 +159,10 @@ public:
void onNotFound(THandlerFunction fn); //called when handler is not assigned
void onFileUpload(THandlerFunction ufn); //handle file uploads
WebServer &addMiddleware(Middleware *middleware);
WebServer &addMiddleware(Middleware::Function fn);
WebServer &removeMiddleware(Middleware *middleware);
String uri() const {
return _currentUri;
}
@ -181,17 +186,23 @@ public:
int args() const; // get arguments count
bool hasArg(const String &name) const; // check if argument exists
void collectHeaders(const char *headerKeys[], const size_t headerKeysCount); // set the request headers to collect
void collectAllHeaders(); // collect all request headers
String header(const String &name) const; // get request header value by name
String header(int i) const; // get request header value by number
String headerName(int i) const; // get request header name by number
int headers() const; // get header count
bool hasHeader(const String &name) const; // check if header exists
int clientContentLength() const {
return _clientContentLength;
} // return "content-length" of incoming HTTP header from "_currentClient"
int clientContentLength() const; // return "content-length" of incoming HTTP header from "_currentClient"
const String version() const; // get the HTTP version string
String hostHeader() const; // get request host header if available or empty String if not
String hostHeader() const; // get request host header if available or empty String if not
int responseCode() const; // get the HTTP response code set
int responseHeaders() const; // get the HTTP response headers count
const String &responseHeader(String name) const; // get the HTTP response header value by name
const String &responseHeader(int i) const; // get the HTTP response header value by number
const String &responseHeaderName(int i) const; // get the HTTP response header name by number
bool hasResponseHeader(const String &name) const; // check if response header exists
// send response to the client
// code - HTTP response code, can be 200 or 404
@ -228,6 +239,8 @@ public:
bool _eTagEnabled = false;
ETagFunction _eTagFunction = nullptr;
static String responseCodeToString(int code);
protected:
virtual size_t _currentClientWrite(const char *b, size_t l) {
return _currentClient.write(b, l);
@ -237,11 +250,10 @@ protected:
}
void _addRequestHandler(RequestHandler *handler);
bool _removeRequestHandler(RequestHandler *handler);
void _handleRequest();
bool _handleRequest();
void _finalizeResponse();
bool _parseRequest(NetworkClient &client);
void _parseArguments(const String &data);
static String _responseCodeToString(int code);
bool _parseForm(NetworkClient &client, const String &boundary, uint32_t len);
bool _parseFormUploadAborted();
void _uploadWriteByte(uint8_t b);
@ -255,48 +267,57 @@ protected:
// for extracting Auth parameters
String _extractParam(String &authReq, const String &param, const char delimit = '"');
void _clearResponseHeaders();
void _clearRequestHeaders();
struct RequestArgument {
String key;
String value;
RequestArgument *next;
};
boolean _corsEnabled;
boolean _corsEnabled = false;
NetworkServer _server;
NetworkClient _currentClient;
HTTPMethod _currentMethod;
HTTPMethod _currentMethod = HTTP_ANY;
String _currentUri;
uint8_t _currentVersion;
HTTPClientStatus _currentStatus;
unsigned long _statusChange;
boolean _nullDelay;
uint8_t _currentVersion = 0;
HTTPClientStatus _currentStatus = HC_NONE;
unsigned long _statusChange = 0;
boolean _nullDelay = true;
RequestHandler *_currentHandler;
RequestHandler *_firstHandler;
RequestHandler *_lastHandler;
THandlerFunction _notFoundHandler;
THandlerFunction _fileUploadHandler;
RequestHandler *_currentHandler = nullptr;
RequestHandler *_firstHandler = nullptr;
RequestHandler *_lastHandler = nullptr;
THandlerFunction _notFoundHandler = nullptr;
THandlerFunction _fileUploadHandler = nullptr;
int _currentArgCount;
RequestArgument *_currentArgs;
int _postArgsLen;
RequestArgument *_postArgs;
int _currentArgCount = 0;
RequestArgument *_currentArgs = nullptr;
int _postArgsLen = 0;
RequestArgument *_postArgs = nullptr;
std::unique_ptr<HTTPUpload> _currentUpload;
std::unique_ptr<HTTPRaw> _currentRaw;
int _headerKeysCount;
RequestArgument *_currentHeaders;
size_t _contentLength;
int _clientContentLength; // "Content-Length" from header of incoming POST or GET request
String _responseHeaders;
int _headerKeysCount = 0;
RequestArgument *_currentHeaders = nullptr;
size_t _contentLength = 0;
int _clientContentLength = 0; // "Content-Length" from header of incoming POST or GET request
RequestArgument *_responseHeaders = nullptr;
String _hostHeader;
bool _chunked;
bool _chunked = false;
String _snonce; // Store noance and opaque for future comparison
String _sopaque;
String _srealm; // Store the Auth realm between Calls
int _responseHeaderCount = 0;
int _responseCode = 0;
bool _collectAllHeaders = false;
MiddlewareChain *_chain = nullptr;
};
#endif //ESP8266WEBSERVER_H

View file

@ -6,7 +6,9 @@
class RequestHandler {
public:
virtual ~RequestHandler() {}
virtual ~RequestHandler() {
delete _chain;
}
/*
note: old handler API for backward compatibility
@ -75,8 +77,14 @@ public:
_next = r;
}
RequestHandler &addMiddleware(Middleware *middleware);
RequestHandler &addMiddleware(Middleware::Function fn);
RequestHandler &removeMiddleware(Middleware *middleware);
bool process(WebServer &server, HTTPMethod requestMethod, String requestUri);
private:
RequestHandler *_next = nullptr;
MiddlewareChain *_chain = nullptr;
protected:
std::vector<String> pathArgs;

View file

@ -10,6 +10,39 @@
using namespace mime;
RequestHandler &RequestHandler::addMiddleware(Middleware *middleware) {
if (!_chain) {
_chain = new MiddlewareChain();
}
_chain->addMiddleware(middleware);
return *this;
}
RequestHandler &RequestHandler::addMiddleware(Middleware::Function fn) {
if (!_chain) {
_chain = new MiddlewareChain();
}
_chain->addMiddleware(fn);
return *this;
}
RequestHandler &RequestHandler::removeMiddleware(Middleware *middleware) {
if (_chain) {
_chain->removeMiddleware(middleware);
}
return *this;
}
bool RequestHandler::process(WebServer &server, HTTPMethod requestMethod, String requestUri) {
if (_chain) {
return _chain->runChain(server, [this, &server, &requestMethod, &requestUri]() {
return handle(server, requestMethod, requestUri);
});
} else {
return handle(server, requestMethod, requestUri);
}
}
class FunctionRequestHandler : public RequestHandler {
public:
FunctionRequestHandler(WebServer::THandlerFunction fn, WebServer::THandlerFunction ufn, const Uri &uri, HTTPMethod method)

View file

@ -0,0 +1,82 @@
#include "Middlewares.h"
AuthenticationMiddleware &AuthenticationMiddleware::setUsername(const char *username) {
_username = username;
_callback = nullptr;
return *this;
}
AuthenticationMiddleware &AuthenticationMiddleware::setPassword(const char *password) {
_password = password;
_hash = false;
_callback = nullptr;
return *this;
}
AuthenticationMiddleware &AuthenticationMiddleware::setPasswordHash(const char *sha1AsBase64orHex) {
_password = sha1AsBase64orHex;
_hash = true;
_callback = nullptr;
return *this;
}
AuthenticationMiddleware &AuthenticationMiddleware::setCallback(WebServer::THandlerFunctionAuthCheck fn) {
assert(fn);
_callback = fn;
_hash = false;
_username = emptyString;
_password = emptyString;
return *this;
}
AuthenticationMiddleware &AuthenticationMiddleware::setRealm(const char *realm) {
_realm = realm;
return *this;
}
AuthenticationMiddleware &AuthenticationMiddleware::setAuthMethod(HTTPAuthMethod method) {
_method = method;
return *this;
}
AuthenticationMiddleware &AuthenticationMiddleware::setAuthFailureMessage(const char *message) {
_authFailMsg = message;
return *this;
}
bool AuthenticationMiddleware::isAllowed(WebServer &server) const {
if (_callback) {
return server.authenticate(_callback);
}
if (!_username.isEmpty() && !_password.isEmpty()) {
if (_hash) {
return server.authenticateBasicSHA1(_username.c_str(), _password.c_str());
} else {
return server.authenticate(_username.c_str(), _password.c_str());
}
}
return true;
}
bool AuthenticationMiddleware::run(WebServer &server, Middleware::Callback next) {
bool authenticationRequired = false;
if (_callback) {
authenticationRequired = !server.authenticate(_callback);
} else if (!_username.isEmpty() && !_password.isEmpty()) {
if (_hash) {
authenticationRequired = !server.authenticateBasicSHA1(_username.c_str(), _password.c_str());
} else {
authenticationRequired = !server.authenticate(_username.c_str(), _password.c_str());
}
}
if (authenticationRequired) {
server.requestAuthentication(_method, _realm, _authFailMsg);
return true;
} else {
return next();
}
}

View file

@ -0,0 +1,47 @@
#include "Middlewares.h"
CorsMiddleware &CorsMiddleware::setOrigin(const char *origin) {
_origin = origin;
return *this;
}
CorsMiddleware &CorsMiddleware::setMethods(const char *methods) {
_methods = methods;
return *this;
}
CorsMiddleware &CorsMiddleware::setHeaders(const char *headers) {
_headers = headers;
return *this;
}
CorsMiddleware &CorsMiddleware::setAllowCredentials(bool credentials) {
_credentials = credentials;
return *this;
}
CorsMiddleware &CorsMiddleware::setMaxAge(uint32_t seconds) {
_maxAge = seconds;
return *this;
}
void CorsMiddleware::addCORSHeaders(WebServer &server) {
server.sendHeader(F("Access-Control-Allow-Origin"), _origin.c_str());
server.sendHeader(F("Access-Control-Allow-Methods"), _methods.c_str());
server.sendHeader(F("Access-Control-Allow-Headers"), _headers.c_str());
server.sendHeader(F("Access-Control-Allow-Credentials"), _credentials ? F("true") : F("false"));
server.sendHeader(F("Access-Control-Max-Age"), String(_maxAge).c_str());
}
bool CorsMiddleware::run(WebServer &server, Middleware::Callback next) {
// Origin header ? => CORS handling
if (server.hasHeader(F("Origin"))) {
addCORSHeaders(server);
// check if this is a preflight request => handle it and return
if (server.method() == HTTP_OPTIONS) {
server.send(200);
return true;
}
}
return next();
}

View file

@ -0,0 +1,74 @@
#include "Middlewares.h"
void LoggingMiddleware::setOutput(Print &output) {
_out = &output;
}
bool LoggingMiddleware::run(WebServer &server, Middleware::Callback next) {
if (_out == nullptr) {
return next();
}
_out->print(F("* Connection from "));
_out->print(server.client().remoteIP().toString());
_out->print(F(":"));
_out->println(server.client().remotePort());
_out->print(F("> "));
const HTTPMethod method = server.method();
if (method == HTTP_ANY) {
_out->print(F("HTTP_ANY"));
} else {
_out->print(http_method_str(method));
}
_out->print(F(" "));
_out->print(server.uri());
_out->print(F(" "));
_out->println(server.version());
int n = server.headers();
for (int i = 0; i < n; i++) {
String v = server.header(i);
if (!v.isEmpty()) {
// because these 2 are always there, eventually empty: "Authorization", "If-None-Match"
_out->print(F("> "));
_out->print(server.headerName(i));
_out->print(F(": "));
_out->println(server.header(i));
}
}
_out->println(F(">"));
uint32_t elapsed = millis();
const bool ret = next();
elapsed = millis() - elapsed;
if (ret) {
_out->print(F("* Processed in "));
_out->print(elapsed);
_out->println(F(" ms"));
_out->print(F("< "));
_out->print(F("HTTP/1."));
_out->print(server.version());
_out->print(F(" "));
_out->print(server.responseCode());
_out->print(F(" "));
_out->println(WebServer::responseCodeToString(server.responseCode()));
n = server.responseHeaders();
for (int i = 0; i < n; i++) {
_out->print(F("< "));
_out->print(server.responseHeaderName(i));
_out->print(F(": "));
_out->println(server.responseHeader(i));
}
_out->println(F("<"));
} else {
_out->println(F("* Not processed!"));
}
return ret;
}

View file

@ -0,0 +1,54 @@
#ifndef MIDDLEWARE_H
#define MIDDLEWARE_H
#include <assert.h>
#include <functional>
class MiddlewareChain;
class WebServer;
class Middleware {
public:
typedef std::function<bool(void)> Callback;
typedef std::function<bool(WebServer &server, Callback next)> Function;
virtual ~Middleware() {}
virtual bool run(WebServer &server, Callback next) {
return next();
};
private:
friend MiddlewareChain;
Middleware *_next = nullptr;
bool _freeOnRemoval = false;
};
class MiddlewareFunction : public Middleware {
public:
MiddlewareFunction(Middleware::Function fn) : _fn(fn) {}
bool run(WebServer &server, Middleware::Callback next) override {
return _fn(server, next);
}
private:
Middleware::Function _fn;
};
class MiddlewareChain {
public:
~MiddlewareChain();
void addMiddleware(Middleware::Function fn);
void addMiddleware(Middleware *middleware);
bool removeMiddleware(Middleware *middleware);
bool runChain(WebServer &server, Middleware::Callback finalizer);
private:
Middleware *_root = nullptr;
Middleware *_current = nullptr;
};
#endif

View file

@ -0,0 +1,73 @@
#include "Middleware.h"
MiddlewareChain::~MiddlewareChain() {
Middleware *current = _root;
while (current) {
Middleware *next = current->_next;
if (current->_freeOnRemoval) {
delete current;
}
current = next;
}
_root = nullptr;
}
void MiddlewareChain::addMiddleware(Middleware::Function fn) {
MiddlewareFunction *middleware = new MiddlewareFunction(fn);
middleware->_freeOnRemoval = true;
addMiddleware(middleware);
}
void MiddlewareChain::addMiddleware(Middleware *middleware) {
if (!_root) {
_root = middleware;
return;
}
Middleware *current = _root;
while (current->_next) {
current = current->_next;
}
current->_next = middleware;
}
bool MiddlewareChain::removeMiddleware(Middleware *middleware) {
if (!_root) {
return false;
}
if (_root == middleware) {
_root = _root->_next;
if (middleware->_freeOnRemoval) {
delete middleware;
}
return true;
}
Middleware *current = _root;
while (current->_next) {
if (current->_next == middleware) {
current->_next = current->_next->_next;
if (middleware->_freeOnRemoval) {
delete middleware;
}
return true;
}
current = current->_next;
}
return false;
}
bool MiddlewareChain::runChain(WebServer &server, Middleware::Callback finalizer) {
if (!_root) {
return finalizer();
}
_current = _root;
Middleware::Callback next;
next = [this, &server, &next, finalizer]() {
if (!_current) {
return finalizer();
}
Middleware *that = _current;
_current = _current->_next;
return that->run(server, next);
};
return next();
}