# SPDX-FileCopyrightText: Copyright (c) 2022 Dan Halbert for Adafruit Industries
#
# SPDX-License-Identifier: MIT
"""
`adafruit_httpserver.server`
====================================================
* Author(s): Dan Halbert, Michał Pokusa
"""

try:
    from typing import Callable, Protocol, Union, List, Tuple
    from socket import socket
    from socketpool import SocketPool
except ImportError:
    pass

from errno import EAGAIN, ECONNRESET, ETIMEDOUT

from .authentication import Basic, Bearer, require_authentication
from .exceptions import (
    ServerStoppedError,
    AuthenticationError,
    FileNotExistsError,
    InvalidPathError,
    ServingFilesDisabledError,
)
from .methods import GET, HEAD
from .request import Request
from .response import Response
from .route import _Routes, _Route
from .status import BAD_REQUEST_400, UNAUTHORIZED_401, FORBIDDEN_403, NOT_FOUND_404


class Server:
    """A basic socket-based HTTP server."""

    def __init__(self, socket_source: Protocol, root_path: str = None) -> None:
        """Create a server, and get it ready to run.

        :param socket_source: An object that is a source of sockets. This could be a `socketpool`
          in CircuitPython or the `socket` module in CPython.
        :param str root_path: Root directory to serve files from. When ``None``, serving
          files from the filesystem is disabled.
        """
        self._auths = []
        self._buffer = bytearray(1024)
        self._timeout = 1
        self.routes = _Routes()
        self._socket_source = socket_source
        self._sock = None
        self.root_path = root_path
        self.stopped = False

    def route(self, path: str, methods: Union[str, List[str]] = GET) -> Callable:
        """
        Decorator used to add a route.

        :param str path: URL path
        :param str methods: HTTP method(s): ``"GET"``, ``"POST"``, ``["GET", "POST"]`` etc.

        Example::

            # Default method is GET
            @server.route("/example")
            def route_func(request):
                ...

            # It is necessary to specify other methods like POST, PUT, etc.
            @server.route("/example", POST)
            def route_func(request):
                ...

            # Multiple methods can be specified
            @server.route("/example", [GET, POST])
            def route_func(request):
                ...

            # URL parameters can be specified
            @server.route("/example/<my_parameter>", GET)
            def route_func(request, my_parameter):
                ...
        """
        if isinstance(methods, str):
            methods = [methods]

        def route_decorator(func: Callable) -> Callable:
            # Register the same handler once per HTTP method.
            for method in methods:
                self.routes.add(_Route(path, method), func)
            return func

        return route_decorator

    def serve_forever(self, host: str, port: int = 80) -> None:
        """
        Wait for HTTP requests at the given host and port.

        Ignores any exceptions raised while handling a request and continues to serve.
        Returns only when the server is stopped by calling ``.stop()``, or on
        ``KeyboardInterrupt`` (Ctrl-C), in which case the server is stopped first.

        :param str host: host name or IP address
        :param int port: port
        """
        self.start(host, port)

        while not self.stopped:
            try:
                self.poll()
            except KeyboardInterrupt:  # Exit on Ctrl-C e.g. during development
                # Release the listening socket so the port can be reused.
                self.stop()
                return
            except Exception:  # pylint: disable=broad-except
                # A failure while handling one request must not kill the server.
                continue

    def start(self, host: str, port: int = 80) -> None:
        """
        Start the HTTP server at the given host and port. Requires calling
        ``.poll()`` in a while loop to handle incoming requests.

        :param str host: host name or IP address
        :param int port: port
        :raises OSError: if the socket cannot be bound or put into listening mode;
          the socket is closed before the error is re-raised.
        """
        sock = self._socket_source.socket(
            self._socket_source.AF_INET, self._socket_source.SOCK_STREAM
        )
        try:
            sock.bind((host, port))
            sock.listen(10)
            # Non-blocking, so accept() in poll() returns immediately (EAGAIN)
            # when no client is waiting.
            sock.setblocking(False)
        except Exception:
            # Don't leak the socket (and its port) if binding/listening fails.
            sock.close()
            raise

        self._sock = sock
        self.stopped = False

    def stop(self) -> None:
        """
        Stops the server from listening for new connections and closes the socket.
        Current requests will be processed. Server can be started again by calling
        ``.start()`` or ``.serve_forever()``. Safe to call before ``.start()`` or more
        than once.
        """
        self.stopped = True
        if self._sock is not None:
            self._sock.close()
            self._sock = None

    def _receive_request(
        self,
        sock: Union["SocketPool.Socket", "socket.socket"],
        client_address: Tuple[str, int],
    ) -> Union[Request, None]:
        """Receive bytes from socket until the whole request is received.

        :return: the parsed `Request`, or ``None`` if no data was received.
        """

        # Receiving data until empty line
        header_bytes = self._receive_header_bytes(sock)

        # Return if no data received
        if not header_bytes:
            return None

        request = Request(self, sock, client_address, header_bytes)

        content_length = int(request.headers.get("Content-Length", 0))
        # Part of the body may already have arrived together with the headers.
        received_body_bytes = request.body

        # Receiving remaining body bytes
        request.body = self._receive_body_bytes(
            sock, received_body_bytes, content_length
        )

        return request

    def _receive_header_bytes(
        self, sock: Union["SocketPool.Socket", "socket.socket"]
    ) -> bytes:
        """Receive bytes until an empty line (end of headers) is received.

        Stops early, returning what was received so far, if the socket times out
        or the peer closes the connection.
        """
        received_bytes = bytes()
        while b"\r\n\r\n" not in received_bytes:
            try:
                length = sock.recv_into(self._buffer, len(self._buffer))
            except OSError as ex:
                if ex.errno == ETIMEDOUT:
                    break
                raise
            if length == 0:
                # Peer closed the connection; no more data will ever arrive,
                # so looping again would spin forever.
                break
            received_bytes += self._buffer[:length]
        return received_bytes

    def _receive_body_bytes(
        self,
        sock: Union["SocketPool.Socket", "socket.socket"],
        received_body_bytes: bytes,
        content_length: int,
    ) -> bytes:
        """Receive bytes until the given content length is received.

        Stops early if the socket times out or the peer closes the connection.
        The result is truncated to ``content_length`` bytes.
        """
        while len(received_body_bytes) < content_length:
            try:
                length = sock.recv_into(self._buffer, len(self._buffer))
            except OSError as ex:
                if ex.errno == ETIMEDOUT:
                    break
                raise
            if length == 0:
                # Client closed the connection before sending the full body.
                break
            received_body_bytes += self._buffer[:length]
        return received_body_bytes[:content_length]

    def _serve_file_from_filesystem(self, request: Request):
        """Send the file at ``request.path`` (``index.html`` for ``/``) from ``root_path``.

        For ``HEAD`` requests only the headers are sent.
        """
        filename = "index.html" if request.path == "/" else request.path
        root_path = self.root_path
        buffer_size = self.request_buffer_size
        head_only = request.method == HEAD

        with Response(request) as response:
            response.send_file(filename, root_path, buffer_size, head_only)

    def _handle_request(self, request: Request, handler: Union[Callable, None]):
        """Dispatch ``request`` to ``handler``, or fall back to serving a file.

        Authentication and file-serving errors are turned into 401/403/404 responses.
        """
        try:
            # Check server authentications if necessary
            if self._auths:
                require_authentication(request, self._auths)

            # Handler for route exists and is callable
            if callable(handler):
                handler(request)

            # Handler is not found...

            # ...no root_path, access to filesystem disabled, return 404.
            elif self.root_path is None:
                raise ServingFilesDisabledError

            # ...root_path is set, access to filesystem enabled...

            # ...request.method is GET or HEAD, try to serve a file from the filesystem.
            elif request.method in [GET, HEAD]:
                self._serve_file_from_filesystem(request)

            # ...any other method on an unrouted path is rejected with 400.
            else:
                Response(request, status=BAD_REQUEST_400).send()

        except AuthenticationError:
            headers = {"WWW-Authenticate": 'Basic charset="UTF-8"'}
            Response(request, status=UNAUTHORIZED_401, headers=headers).send()

        except InvalidPathError as error:
            Response(request, status=FORBIDDEN_403).send(str(error))

        except (FileNotExistsError, ServingFilesDisabledError) as error:
            Response(request, status=NOT_FOUND_404).send(str(error))

    def poll(self):
        """
        Call this method inside your main loop to get the server to check for new incoming client
        requests. When a request comes in, it will be handled by the handler function.

        :raises ServerStoppedError: if the server has been stopped.
        """
        if self.stopped:
            raise ServerStoppedError

        try:
            conn, client_address = self._sock.accept()
            # The connection is closed when this block exits.
            with conn:
                conn.settimeout(self._timeout)

                # Receive the whole request
                if (request := self._receive_request(conn, client_address)) is None:
                    return

                # Find a handler for the route
                handler = self.routes.find_handler(_Route(request.path, request.method))

                # Handle the request
                self._handle_request(request, handler)

        except OSError as error:
            # EAGAIN: no client is waiting right now; ECONNRESET: connection reset
            # by peer. In both cases just try again on the next poll.
            if error.errno in (EAGAIN, ECONNRESET):
                return
            raise

    def require_authentication(self, auths: List[Union[Basic, Bearer]]) -> None:
        """
        Requires authentication for all routes and files in ``root_path``.
        Any non-authenticated request will be rejected with a 401 status code.

        Example::

            server = Server(pool, "/static")
            server.require_authentication([Basic("user", "pass")])
        """
        self._auths = auths

    @property
    def request_buffer_size(self) -> int:
        """
        Size of the buffer used for each read from a client socket. Larger requests
        are read in several chunks of this size. The same value is passed as the
        buffer size to ``Response.send_file`` when serving files from ``root_path``.

        Default size is 1024 bytes.

        Example::

            server = Server(pool, "/static")
            server.request_buffer_size = 2048

            server.serve_forever(str(wifi.radio.ipv4_address))
        """
        return len(self._buffer)

    @request_buffer_size.setter
    def request_buffer_size(self, value: int) -> None:
        self._buffer = bytearray(value)

    @property
    def socket_timeout(self) -> int:
        """
        Timeout after which the socket will stop waiting for more incoming data.

        Must be set to positive integer or float. Default is 1 second.

        When a read times out (`OSError` with `errno.ETIMEDOUT`), the server stops
        waiting and processes the data received so far.

        Example::

            server = Server(pool, "/static")
            server.socket_timeout = 3

            server.serve_forever(str(wifi.radio.ipv4_address))
        """
        return self._timeout

    @socket_timeout.setter
    def socket_timeout(self, value: int) -> None:
        if isinstance(value, (int, float)) and value > 0:
            self._timeout = value
        else:
            raise ValueError("Server.socket_timeout must be a positive numeric value.")