DNSServer refactoring, switch to AsyncUDP (#7482)

* DNSServer: switch to AsyncUDP instead of WiFiUDP

AsyncUDP offers event driven approch for handling udp dns req's
WiFiUDP hooks to loop() for packet processing and making useless malloc's each run

* DNSServer code refactoring

get rid of intermediate mem buffers and extra data copies,
most of the data could be referenced or copied from the source packet
 - removed _buffer member
 - replaced DNSQuestion.QName from uint8_t[] to char*

added sanity checks for mem bounds
optimize label/packet length calculations
other code cleanup

* DNSServer drop dynamically allocated member structs

DNSHeader and DNSQuestion structs could be created on stack
no need to keep it as obj members

* DNSServer: labels min length checks, simplified labels parser

* DNSServer use default settings for catch-all setup

 - default constructor and start() method simply runs a catch-all DNS setup
 - avoid string comparison for domain reqs in catch-all mode
 - use IPaddress class for _resolvedIP (looking for IPv6 support in future)

* CaptivePortal example refactored

 - use webserver instead of simple tcp setver
 - use redirects to allows CaptivePortal detection pop-ups in modern systems

* DNSServer status getters added

add isUp() method - returns 'true' if server is up and UDP socket is listening for UDP req's
add isCaptive() method - returns 'true' if server runs in catch-all (captive portal mode)
some doxygen comments added
start() method now keeps existing IP address if any

---------

Co-authored-by: Lucas Saavedra Vaz <lucas.vaz@espressif.com>
Co-authored-by: Me No Dev <me-no-dev@users.noreply.github.com>
This commit is contained in:
vortigont 2023-12-18 21:47:04 +09:00 committed by GitHub
parent 44f83b0455
commit d91271019c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 253 additions and 200 deletions

View file

@ -1,52 +1,59 @@
/*
This example enables catch-all Captive portal for ESP32 Access-Point
It will allow modern devices/OSes to detect that WiFi connection is
limited and offer a user to access a banner web-page.
There is no need to find and open device's IP address/URL, i.e. http://192.168.4.1/
This works for Android, Ubuntu, FireFox, Windows, maybe others...
*/
#include <Arduino.h>
#include <WiFi.h>
#include <DNSServer.h>
#include <WebServer.h>
const byte DNS_PORT = 53;
IPAddress apIP(8,8,4,4); // The default android DNS
DNSServer dnsServer;
WiFiServer server(80);
WebServer server(80);
String responseHTML = ""
"<!DOCTYPE html><html><head><title>CaptivePortal</title></head><body>"
"<h1>Hello World!</h1><p>This is a captive portal example. All requests will "
"be redirected here.</p></body></html>";
static const char responsePortal[] = R"===(
<!DOCTYPE html><html><head><title>ESP32 CaptivePortal</title></head><body>
<h1>Hello World!</h1><p>This is a captive portal example page. All unknown http requests will
be redirected here.</p></body></html>
)===";
void setup() {
// index page handler
void handleRoot() {
server.send(200, "text/plain", "Hello from esp32!");
}
// this will redirect unknown http req's to our captive portal page
// based on this redirect various systems could detect that WiFi AP has a captive portal page
void handleNotFound() {
server.sendHeader("Location", "/portal");
server.send(302, "text/plain", "redirect to captive portal");
}
void setup() {
Serial.begin(115200);
WiFi.mode(WIFI_AP);
WiFi.softAP("ESP32-DNSServer");
WiFi.softAPConfig(apIP, apIP, IPAddress(255, 255, 255, 0));
// if DNSServer is started with "*" for domain name, it will reply with
// provided IP to all DNS request
dnsServer.start(DNS_PORT, "*", apIP);
// by default DNSServer is started serving any "*" domain name. It will reply
// AccessPoint's IP to all DNS request (this is requred for Captive Portal detection)
dnsServer.start();
// serve a simple root page
server.on("/", handleRoot);
// serve portal page
server.on("/portal",[](){server.send(200, "text/html", responsePortal);});
// all unknown pages are redirected to captive portal
server.onNotFound(handleNotFound);
server.begin();
}
void loop() {
dnsServer.processNextRequest();
WiFiClient client = server.available(); // listen for incoming clients
if (client) {
String currentLine = "";
while (client.connected()) {
if (client.available()) {
char c = client.read();
if (c == '\n') {
if (currentLine.length() == 0) {
client.println("HTTP/1.1 200 OK");
client.println("Content-type:text/html");
client.println();
client.print(responseHTML);
break;
} else {
currentLine = "";
}
} else if (c != '\r') {
currentLine += c;
}
}
}
client.stop();
}
server.handleClient();
delay(5); // give CPU some idle time
}

View file

@ -1,6 +1,8 @@
#include "DNSServer.h"
#include <lwip/def.h>
#include <Arduino.h>
#include <WiFi.h>
// #define DEBUG_ESP_DNS
#ifdef DEBUG_ESP_PORT
@ -9,45 +11,37 @@
#define DEBUG_OUTPUT Serial
#endif
DNSServer::DNSServer()
{
_ttl = htonl(DNS_DEFAULT_TTL);
_errorReplyCode = DNSReplyCode::NonExistentDomain;
_dnsHeader = (DNSHeader*) malloc( sizeof(DNSHeader) ) ;
_dnsQuestion = (DNSQuestion*) malloc( sizeof(DNSQuestion) ) ;
_buffer = NULL;
_currentPacketSize = 0;
_port = 0;
#define DNS_MIN_REQ_LEN 17 // minimal size for DNS request asking ROOT = DNS_HEADER_SIZE + 1 null byte for Name + 4 bytes type/class
DNSServer::DNSServer() : _port(DNS_DEFAULT_PORT), _ttl(htonl(DNS_DEFAULT_TTL)), _errorReplyCode(DNSReplyCode::NonExistentDomain){}
DNSServer::DNSServer(const String &domainName) : _port(DNS_DEFAULT_PORT), _ttl(htonl(DNS_DEFAULT_TTL)), _errorReplyCode(DNSReplyCode::NonExistentDomain), _domainName(domainName){};
bool DNSServer::start(){
if (_resolvedIP.operator uint32_t() == 0){ // no address is set, try to obtain AP interface's IP
if (WiFi.getMode() & WIFI_AP){
_resolvedIP = WiFi.softAPIP();
} else return false; // won't run if WiFi is not in AP mode
}
_udp.close();
_udp.onPacket([this](AsyncUDPPacket& pkt){ this->_handleUDP(pkt); });
return _udp.listen(_port);
}
DNSServer::~DNSServer()
{
if (_dnsHeader) {
free(_dnsHeader);
_dnsHeader = NULL;
}
if (_dnsQuestion) {
free(_dnsQuestion);
_dnsQuestion = NULL;
}
if (_buffer) {
free(_buffer);
_buffer = NULL;
}
}
bool DNSServer::start(const uint16_t &port, const String &domainName,
const IPAddress &resolvedIP)
{
bool DNSServer::start(uint16_t port, const String &domainName, const IPAddress &resolvedIP){
_port = port;
_buffer = NULL;
_domainName = domainName;
_resolvedIP[0] = resolvedIP[0];
_resolvedIP[1] = resolvedIP[1];
_resolvedIP[2] = resolvedIP[2];
_resolvedIP[3] = resolvedIP[3];
downcaseAndRemoveWwwPrefix(_domainName);
return _udp.begin(_port) == 1;
if (domainName != "*"){
_domainName = domainName;
downcaseAndRemoveWwwPrefix(_domainName);
} else
_domainName.clear();
_resolvedIP = resolvedIP;
_udp.close();
_udp.onPacket([this](AsyncUDPPacket& pkt){ this->_handleUDP(pkt); });
return _udp.listen(_port);
}
void DNSServer::setErrorReplyCode(const DNSReplyCode &replyCode)
@ -62,9 +56,7 @@ void DNSServer::setTTL(const uint32_t &ttl)
void DNSServer::stop()
{
_udp.stop();
free(_buffer);
_buffer = NULL;
_udp.close();
}
void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName)
@ -73,151 +65,125 @@ void DNSServer::downcaseAndRemoveWwwPrefix(String &domainName)
domainName.replace("www.", "");
}
void DNSServer::processNextRequest()
void DNSServer::_handleUDP(AsyncUDPPacket& pkt)
{
_currentPacketSize = _udp.parsePacket();
if (_currentPacketSize)
{
// Allocate buffer for the DNS query
if (_buffer != NULL)
free(_buffer);
_buffer = (unsigned char*)malloc(_currentPacketSize * sizeof(char));
if (_buffer == NULL)
return;
if (pkt.length() < DNS_MIN_REQ_LEN) return; // truncated packet or not a DNS req
// Put the packet received in the buffer and get DNS header (beginning of message)
// and the question
_udp.read(_buffer, _currentPacketSize);
memcpy( _dnsHeader, _buffer, DNS_HEADER_SIZE ) ;
if ( requestIncludesOnlyOneQuestion() )
// get DNS header (beginning of message)
DNSHeader dnsHeader;
DNSQuestion dnsQuestion;
memcpy( &dnsHeader, pkt.data(), DNS_HEADER_SIZE );
if (dnsHeader.QR != DNS_QR_QUERY) return; // ignore non-query mesages
if ( requestIncludesOnlyOneQuestion(dnsHeader) )
{
/*
// The QName has a variable length, maximum 255 bytes and is comprised of multiple labels.
// Each label contains a byte to describe its length and the label itself. The list of
// labels terminates with a zero-valued byte. In "github.com", we have two labels "github" & "com"
// Iterate through the labels and copy them as they come into a single buffer (for simplicity's sake)
_dnsQuestion->QNameLength = 0 ;
while ( _buffer[ DNS_HEADER_SIZE + _dnsQuestion->QNameLength ] != 0 )
{
memcpy( (void*) &_dnsQuestion->QName[_dnsQuestion->QNameLength], (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength], _buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength] + 1 ) ;
_dnsQuestion->QNameLength += _buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength] + 1 ;
}
_dnsQuestion->QName[_dnsQuestion->QNameLength] = 0 ;
_dnsQuestion->QNameLength++ ;
// Copy the QType and QClass
memcpy( &_dnsQuestion->QType, (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength], sizeof(_dnsQuestion->QType) ) ;
memcpy( &_dnsQuestion->QClass, (void*) &_buffer[DNS_HEADER_SIZE + _dnsQuestion->QNameLength + sizeof(_dnsQuestion->QType)], sizeof(_dnsQuestion->QClass) ) ;
*/
const char * enoflbls = strchr(reinterpret_cast<const char*>(pkt.data()) + DNS_HEADER_SIZE, 0); // find end_of_label marker
++enoflbls; // advance after null terminator
dnsQuestion.QName = pkt.data() + DNS_HEADER_SIZE; // we can reference labels from the request
dnsQuestion.QNameLength = enoflbls - (char*)pkt.data() - DNS_HEADER_SIZE;
/*
check if we aint going out of pkt bounds
proper dns req should have label terminator at least 4 bytes before end of packet
*/
if (dnsQuestion.QNameLength > pkt.length() - DNS_HEADER_SIZE - sizeof(dnsQuestion.QType) - sizeof(dnsQuestion.QClass)) return; // malformed packet
// Copy the QType and QClass
memcpy( &dnsQuestion.QType, enoflbls, sizeof(dnsQuestion.QType) );
memcpy( &dnsQuestion.QClass, enoflbls + sizeof(dnsQuestion.QType), sizeof(dnsQuestion.QClass) );
}
if (_dnsHeader->QR == DNS_QR_QUERY &&
_dnsHeader->OPCode == DNS_OPCODE_QUERY &&
requestIncludesOnlyOneQuestion() &&
(_domainName == "*" || getDomainNameWithoutWwwPrefix() == _domainName)
// will reply with IP only to "*" or if doman matches without www. subdomain
if (dnsHeader.OPCode == DNS_OPCODE_QUERY &&
requestIncludesOnlyOneQuestion(dnsHeader) &&
(_domainName.isEmpty() ||
getDomainNameWithoutWwwPrefix(static_cast<const unsigned char*>(dnsQuestion.QName), dnsQuestion.QNameLength) == _domainName)
)
{
replyWithIP();
}
else if (_dnsHeader->QR == DNS_QR_QUERY)
{
replyWithCustomCode();
replyWithIP(pkt, dnsHeader, dnsQuestion);
return;
}
free(_buffer);
_buffer = NULL;
}
// otherwise reply with custom code
replyWithCustomCode(pkt, dnsHeader);
}
bool DNSServer::requestIncludesOnlyOneQuestion()
bool DNSServer::requestIncludesOnlyOneQuestion(DNSHeader& dnsHeader)
{
return ntohs(_dnsHeader->QDCount) == 1 &&
_dnsHeader->ANCount == 0 &&
_dnsHeader->NSCount == 0 &&
_dnsHeader->ARCount == 0;
return ntohs(dnsHeader.QDCount) == 1 &&
dnsHeader.ANCount == 0 &&
dnsHeader.NSCount == 0 &&
dnsHeader.ARCount == 0;
}
String DNSServer::getDomainNameWithoutWwwPrefix()
String DNSServer::getDomainNameWithoutWwwPrefix(const unsigned char* start, size_t len)
{
// Error checking : if the buffer containing the DNS request is a null pointer, return an empty domain
String parsedDomainName = "";
if (_buffer == NULL)
return parsedDomainName;
// Set the start of the domain just after the header (12 bytes). If equal to null character, return an empty domain
unsigned char *start = _buffer + DNS_OFFSET_DOMAIN_NAME;
if (*start == 0)
{
return parsedDomainName;
}
String parsedDomainName(start, --len); // exclude trailing null byte from labels length, String constructor will add it anyway
int pos = 0;
while(true)
while(pos<len)
{
unsigned char labelLength = *(start + pos);
for(int i = 0; i < labelLength; i++)
{
pos++;
parsedDomainName += (char)*(start + pos);
}
pos++;
if (*(start + pos) == 0)
{
downcaseAndRemoveWwwPrefix(parsedDomainName);
return parsedDomainName;
}
else
{
parsedDomainName += ".";
}
parsedDomainName.setCharAt(pos, 0x2e); // replace label len byte with dot char "."
pos += *(start + pos);
++pos;
}
parsedDomainName.remove(0,1); // remove first "." char
downcaseAndRemoveWwwPrefix(parsedDomainName);
return parsedDomainName;
}
void DNSServer::replyWithIP()
void DNSServer::replyWithIP(AsyncUDPPacket& req, DNSHeader& dnsHeader, DNSQuestion& dnsQuestion)
{
_udp.beginPacket(_udp.remoteIP(), _udp.remotePort());
AsyncUDPMessage rpl;
// Change the type of message to a response and set the number of answers equal to
// the number of questions in the header
_dnsHeader->QR = DNS_QR_RESPONSE;
_dnsHeader->ANCount = _dnsHeader->QDCount;
_udp.write( (unsigned char*) _dnsHeader, DNS_HEADER_SIZE ) ;
dnsHeader.QR = DNS_QR_RESPONSE;
dnsHeader.ANCount = dnsHeader.QDCount;
rpl.write( (unsigned char*) &dnsHeader, DNS_HEADER_SIZE ) ;
// Write the question
_udp.write(_dnsQuestion->QName, _dnsQuestion->QNameLength) ;
_udp.write( (unsigned char*) &_dnsQuestion->QType, 2 ) ;
_udp.write( (unsigned char*) &_dnsQuestion->QClass, 2 ) ;
rpl.write(dnsQuestion.QName, dnsQuestion.QNameLength) ;
rpl.write( (uint8_t*) &dnsQuestion.QType, 2 ) ;
rpl.write( (uint8_t*) &dnsQuestion.QClass, 2 ) ;
// Write the answer
// Use DNS name compression : instead of repeating the name in this RNAME occurence,
// set the two MSB of the byte corresponding normally to the length to 1. The following
// 14 bits must be used to specify the offset of the domain name in the message
// (<255 here so the first byte has the 6 LSB at 0)
_udp.write((uint8_t) 0xC0);
_udp.write((uint8_t) DNS_OFFSET_DOMAIN_NAME);
rpl.write((uint8_t) 0xC0);
rpl.write((uint8_t) DNS_OFFSET_DOMAIN_NAME);
// DNS type A : host address, DNS class IN for INternet, returning an IPv4 address
uint16_t answerType = htons(DNS_TYPE_A), answerClass = htons(DNS_CLASS_IN), answerIPv4 = htons(DNS_RDLENGTH_IPV4) ;
_udp.write((unsigned char*) &answerType, 2 );
_udp.write((unsigned char*) &answerClass, 2 );
_udp.write((unsigned char*) &_ttl, 4); // DNS Time To Live
_udp.write((unsigned char*) &answerIPv4, 2 );
_udp.write(_resolvedIP, sizeof(_resolvedIP)); // The IP address to return
_udp.endPacket();
rpl.write((unsigned char*) &answerType, 2 );
rpl.write((unsigned char*) &answerClass, 2 );
rpl.write((unsigned char*) &_ttl, 4); // DNS Time To Live
rpl.write((unsigned char*) &answerIPv4, 2 );
uint32_t ip = _resolvedIP;
rpl.write(reinterpret_cast<uint8_t*>(&ip), sizeof(uint32_t)); // The IPv4 address to return
_udp.sendTo(rpl, req.remoteIP(), req.remotePort());
#ifdef DEBUG_ESP_DNS
DEBUG_OUTPUT.printf("DNS responds: %s for %s\n",
IPAddress(_resolvedIP).toString().c_str(), getDomainNameWithoutWwwPrefix().c_str() );
_resolvedIP.toString().c_str(), getDomainNameWithoutWwwPrefix(static_cast<const unsigned char*>(dnsQuestion.QName), dnsQuestion.QNameLength).c_str() );
#endif
}
void DNSServer::replyWithCustomCode()
void DNSServer::replyWithCustomCode(AsyncUDPPacket& req, DNSHeader& dnsHeader)
{
_dnsHeader->QR = DNS_QR_RESPONSE;
_dnsHeader->RCode = (unsigned char)_errorReplyCode;
_dnsHeader->QDCount = 0;
dnsHeader.QR = DNS_QR_RESPONSE;
dnsHeader.RCode = static_cast<uint16_t>(_errorReplyCode);
dnsHeader.QDCount = 0;
_udp.beginPacket(_udp.remoteIP(), _udp.remotePort());
_udp.write((unsigned char*)_dnsHeader, sizeof(DNSHeader));
_udp.endPacket();
AsyncUDPMessage rpl(sizeof(DNSHeader));
rpl.write(reinterpret_cast<const uint8_t*>(&dnsHeader), sizeof(DNSHeader));
_udp.sendTo(rpl, req.remoteIP(), req.remotePort());
}

View file

@ -1,15 +1,15 @@
#ifndef DNSServer_h
#define DNSServer_h
#include <WiFiUdp.h>
#pragma once
#include <AsyncUDP.h>
#define DNS_QR_QUERY 0
#define DNS_QR_RESPONSE 1
#define DNS_OPCODE_QUERY 0
#define DNS_DEFAULT_TTL 60 // Default Time To Live : time interval in seconds that the resource record should be cached before being discarded
#define DNS_OFFSET_DOMAIN_NAME 12 // Offset in bytes to reach the domain name in the DNS message
#define DNS_HEADER_SIZE 12
#define DNS_OFFSET_DOMAIN_NAME DNS_HEADER_SIZE // Offset in bytes to reach the domain name labels in the DNS message
#define DNS_DEFAULT_PORT 53
enum class DNSReplyCode
enum class DNSReplyCode:uint16_t
{
NoError = 0,
FormError = 1,
@ -59,14 +59,14 @@ struct DNSHeader
uint16_t Flags;
};
uint16_t QDCount; // number of question entries
uint16_t ANCount; // number of answer entries
uint16_t ANCount; // number of ANswer entries
uint16_t NSCount; // number of authority entries
uint16_t ARCount; // number of resource entries
uint16_t ARCount; // number of Additional Resource entries
};
struct DNSQuestion
{
uint8_t QName[256] ; //need 1 Byte for zero termination!
const uint8_t* QName;
uint16_t QNameLength ;
uint16_t QType ;
uint16_t QClass ;
@ -75,36 +75,116 @@ struct DNSQuestion
class DNSServer
{
public:
/**
* @brief Construct a new DNSServer object
* by default server is configured to run in "Captive-portal" mode
* it must be started with start() call to establish a listening socket
*
*/
DNSServer();
~DNSServer();
void processNextRequest();
/**
* @brief Construct a new DNSServer object
* builds DNS server with default parameters
* @param domainName - domain name to serve
*/
DNSServer(const String &domainName);
~DNSServer(){}; // default d-tor
// Copy semantics not implemented (won't run on same UDP port anyway)
DNSServer(const DNSServer&) = delete;
DNSServer& operator=(const DNSServer&) = delete;
/**
* @brief stub, left for compatibility with an old version
* does nothing actually
*
*/
void processNextRequest(){};
/**
* @brief Set the Error Reply Code for all req's not matching predifined domain
*
* @param replyCode
*/
void setErrorReplyCode(const DNSReplyCode &replyCode);
/**
* @brief set TTL for successfull replies
*
* @param ttl in seconds
*/
void setTTL(const uint32_t &ttl);
// Returns true if successful, false if there are no sockets available
bool start(const uint16_t &port,
/**
* @brief (re)Starts a server with current configuration or with default parameters
* if it's the first call.
* Defaults are:
* port: 53
* domainName: any
* ip: WiFi AP's IP address
*
* @return true on success
* @return false if IP or socket error
*/
bool start();
/**
* @brief (re)Starts a server with provided configuration
*
* @return true on success
* @return false if IP or socket error
*/
bool start(uint16_t port,
const String &domainName,
const IPAddress &resolvedIP);
// stops the DNS server
/**
* @brief stops the server and close UDP socket
*
*/
void stop();
/**
* @brief returns true if DNS server runs in captive-portal mode
* i.e. all requests are served with AP's ip address
*
* @return true if catch-all mode active
* @return false otherwise
*/
inline bool isCaptive() const { return _domainName.isEmpty(); };
/**
* @brief returns 'true' if server is up and UDP socket is listening for UDP req's
*
* @return true if server is up
* @return false otherwise
*/
inline bool isUp() { return _udp.connected(); };
private:
WiFiUDP _udp;
AsyncUDP _udp;
uint16_t _port;
String _domainName;
unsigned char _resolvedIP[4];
int _currentPacketSize;
unsigned char* _buffer;
DNSHeader* _dnsHeader;
uint32_t _ttl;
DNSReplyCode _errorReplyCode;
DNSQuestion* _dnsQuestion ;
String _domainName;
IPAddress _resolvedIP;
void downcaseAndRemoveWwwPrefix(String &domainName);
String getDomainNameWithoutWwwPrefix();
bool requestIncludesOnlyOneQuestion();
void replyWithIP();
void replyWithCustomCode();
/**
* @brief Get the Domain Name Without Www Prefix object
* scan labels in DNS packet and build a string of a domain name
* truncate any www. label if found
* @param start a pointer to the start of labels records in DNS packet
* @param len labels length
* @return String
*/
String getDomainNameWithoutWwwPrefix(const unsigned char* start, size_t len);
inline bool requestIncludesOnlyOneQuestion(DNSHeader& dnsHeader);
void replyWithIP(AsyncUDPPacket& req, DNSHeader& dnsHeader, DNSQuestion& dnsQuestion);
inline void replyWithCustomCode(AsyncUDPPacket& req, DNSHeader& dnsHeader);
void _handleUDP(AsyncUDPPacket& pkt);
};
#endif