Adafruit_CircuitPython_HTTP.../adafruit_httpserver/request.py
2025-05-16 16:11:43 +00:00

479 lines
14 KiB
Python

# SPDX-FileCopyrightText: Copyright (c) 2022 Dan Halbert for Adafruit Industries, Michał Pokusa
#
# SPDX-License-Identifier: MIT
"""
`adafruit_httpserver.request`
====================================================
* Author(s): Dan Halbert, Michał Pokusa
"""
try:
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
if TYPE_CHECKING:
from .server import Server
except ImportError:
pass
import json
from .headers import Headers
from .interfaces import _IFieldStorage, _ISocket, _IXSSSafeFieldStorage
from .methods import DELETE, PATCH, POST, PUT
class QueryParams(_IXSSSafeFieldStorage):
    """
    Parses and stores the parameters of a URL query string (the part after ``?``).

    Parameters are separated by ``&``. A parameter without ``=`` is stored with an empty
    string value, empty segments are ignored, and a repeated name keeps all of its values.
    Values are stored exactly as they appear in the query string (no percent-decoding).

    Examples::

        query_params = QueryParams("foo=bar&baz=qux&baz=quux")
        # QueryParams({"foo": ["bar"], "baz": ["qux", "quux"]})

        query_params.get("foo") # "bar"
        query_params["foo"] # "bar"
        query_params.get("non-existent-key") # None
        query_params.get_list("baz") # ["qux", "quux"]
        "unknown-key" in query_params # False
        query_params.fields # ["foo", "baz"]
    """

    _storage: Dict[str, List[str]]

    def __init__(self, query_string: str) -> None:
        self._storage = {}

        for segment in query_string.split("&"):
            if "=" not in segment:
                # A bare name such as "flag" is kept with an empty value; the empty
                # segments produced by "&&" or a trailing "&" are skipped entirely.
                if segment:
                    self._add_field_value(segment, "")
                continue

            name, value = segment.split("=", 1)
            self._add_field_value(name, value)

    def _add_field_value(self, field_name: str, value: str) -> None:
        """Store ``value`` under ``field_name`` using the base-class implementation."""
        super()._add_field_value(field_name, value)

    def get(self, field_name: str, default: str = None, *, safe=True) -> Union[str, None]:
        """
        Return the value stored for ``field_name``, or ``default`` when it is absent.

        ``safe`` is forwarded to ``_IXSSSafeFieldStorage.get`` (by its name it controls
        XSS-safe escaping of the returned value — see that class for details).
        """
        return super().get(field_name, default, safe=safe)

    def get_list(self, field_name: str, *, safe=True) -> List[str]:
        """Return every value stored for ``field_name``; ``safe`` is forwarded to the base class."""
        return super().get_list(field_name, safe=safe)

    def __str__(self) -> str:
        """Rebuild a ``name=value&name=value`` string from the values returned by ``get_list``."""
        pairs = []
        for field_name in self.fields:
            pairs.extend(f"{field_name}={value}" for value in self.get_list(field_name))
        return "&".join(pairs)
class File:
    """
    A single file uploaded as part of a ``multipart/form-data`` POST request.

    Examples::

        file = request.form_data.files.get("uploaded_file")
        # <File filename='foo.txt', content_type='text/plain', size=14>

        file.content
        # "Hello, world!\\n"
    """

    filename: str
    """Name of the uploaded file, as sent by the client."""

    content_type: str
    """Content type declared for the uploaded file."""

    content: Union[str, bytes]
    """File content: ``str`` for text parts, ``bytes`` otherwise."""

    def __init__(self, filename: str, content_type: str, content: Union[str, bytes]) -> None:
        self.filename, self.content_type, self.content = filename, content_type, content

    @property
    def content_bytes(self) -> bytes:
        """
        File content, always as ``bytes`` (``str`` content is encoded as UTF-8).

        Prefer this over ``content`` when the raw bytes are needed, e.g. to save the upload.

        Example::

            file = request.form_data.files.get("uploaded_file")

            with open(file.filename, "wb") as f:
                f.write(file.content_bytes)
        """
        if isinstance(self.content, str):
            return self.content.encode("utf-8")
        return self.content

    @property
    def size(self) -> int:
        """
        Length of ``content``.

        This is a byte count for ``bytes`` content but a character count when ``content``
        is a ``str``.
        """
        return len(self.content)

    def __repr__(self) -> str:
        parts = (
            "filename=" + repr(self.filename),
            "content_type=" + repr(self.content_type),
            "size=" + repr(self.size),
        )
        joined = ", ".join(parts)
        return f"<{type(self).__name__} {joined}>"
class Files(_IFieldStorage):
    """
    Storage for the files uploaded in a ``multipart/form-data`` POST request.

    Keys are form field names; a single field can hold several :class:`File` objects.
    """

    _storage: Dict[str, List[File]]

    def __init__(self) -> None:
        # Filled in by FormData while it parses a multipart body.
        self._storage = dict()

    def _add_field_value(self, field_name: str, value: File) -> None:
        """Store ``value`` under ``field_name`` using the base-class implementation."""
        super()._add_field_value(field_name, value)

    def get(self, field_name: str, default: Any = None) -> Union[File, Any, None]:
        """Return the :class:`File` stored for ``field_name``, or ``default`` when absent."""
        return super().get(field_name, default)

    def get_list(self, field_name: str) -> List[File]:
        """Return every :class:`File` stored for ``field_name``."""
        return super().get_list(field_name)
class FormData(_IXSSSafeFieldStorage):
    """
    Class for parsing and storing form data from POST requests.

    Supports ``application/x-www-form-urlencoded``, ``multipart/form-data`` and ``text/plain``
    content types. Only the first ``Content-Length`` bytes of the body are parsed, and
    values are stored as received (no percent-decoding is performed).

    Examples::

        # request.headers has "Content-Type: application/x-www-form-urlencoded"
        # and a matching "Content-Length"
        form_data = FormData(b"foo=bar&baz=qux&baz=quux", request.headers)

        # or, with "Content-Type: text/plain"
        form_data = FormData(b"foo=bar\\r\\nbaz=qux\\r\\nbaz=quux", request.headers)

        # FormData({"foo": ["bar"], "baz": ["qux", "quux"]})

        form_data.get("foo") # "bar"
        form_data["foo"] # "bar"
        form_data.get("non-existent-key") # None
        form_data.get_list("baz") # ["qux", "quux"]
        "unknown-key" in form_data # False
        form_data.fields # ["foo", "baz"]
    """

    _storage: Dict[str, List[Union[str, bytes]]]
    files: Files

    @staticmethod
    def _check_is_supported_content_type(content_type: str) -> bool:
        """Return ``True`` if ``content_type`` is one of the form encodings parsed here."""
        return content_type in {
            "application/x-www-form-urlencoded",
            "multipart/form-data",
            "text/plain",
        }

    def __init__(self, data: bytes, headers: Headers, *, debug: bool = False) -> None:
        """
        :param data: Raw request body.
        :param headers: Request headers. ``Content-Type`` selects the parser and
            ``Content-Length`` limits how much of ``data`` is parsed (0 when absent).
        :param debug: When true, print a warning for an unsupported ``Content-Type``.
        """
        self._storage = {}
        self.files = Files()

        self.content_type = headers.get_directive("Content-Type")
        content_length = int(headers.get("Content-Length", 0))

        if debug and not self._check_is_supported_content_type(self.content_type):
            _debug_unsupported_form_content_type(self.content_type)

        if self.content_type == "application/x-www-form-urlencoded":
            self._parse_x_www_form_urlencoded(data[:content_length])
        elif self.content_type == "multipart/form-data":
            boundary = headers.get_parameter("Content-Type", "boundary")
            self._parse_multipart_form_data(data[:content_length], boundary)
        elif self.content_type == "text/plain":
            self._parse_text_plain(data[:content_length])

    def _add_key_value_pairs(self, pairs: List[str]) -> None:
        """
        Store each ``name=value`` string from ``pairs`` as a field.

        Empty strings are skipped, and a string without ``=`` is stored with an empty value.
        Only the first ``=`` separates name from value.
        """
        for pair in pairs:
            if not pair:
                continue
            if "=" in pair:
                field_name, value = pair.split("=", 1)
            else:
                field_name, value = pair, ""
            self._add_field_value(field_name, value)

    def _parse_x_www_form_urlencoded(self, data: bytes) -> None:
        """Parse ``name=value`` pairs separated by ``&`` (empty segments such as ``&&`` are ignored)."""
        self._add_key_value_pairs(data.decode("utf-8").split("&"))

    def _parse_multipart_form_data(self, data: bytes, boundary: str) -> None:
        """
        Parse a ``multipart/form-data`` body whose parts are delimited by ``--<boundary>``.

        Parts whose ``Content-Disposition`` has a ``filename`` are stored in ``self.files`` as
        :class:`File` objects; all others become regular fields. Parts whose content type is
        ``text/plain`` (the fallback passed to ``get_directive``) are decoded using their
        ``charset`` parameter (fallback ``utf-8``); other parts are kept as ``bytes``.
        Parts without a header/content separator are skipped.
        """
        if not boundary:
            # Without a boundary the parts cannot be located.
            return

        # The [1:-1] slice drops the preamble before the first delimiter and the "--" tail
        # that follows the closing delimiter.
        blocks = data.split(b"--" + boundary.encode())[1:-1]

        for block in blocks:
            if b"\r\n\r\n" not in block:
                # Malformed part: no empty line between its headers and its content.
                continue
            header_bytes, content_bytes = block.split(b"\r\n\r\n", 1)
            headers = Headers(header_bytes.decode("utf-8").strip())

            field_name = headers.get_parameter("Content-Disposition", "name")
            filename = headers.get_parameter("Content-Disposition", "filename")
            content_type = headers.get_directive("Content-Type", "text/plain")
            charset = headers.get_parameter("Content-Type", "charset", "utf-8")

            content = content_bytes[:-2]  # remove the \r\n that precedes the next delimiter
            # TODO: Other text content types (e.g. application/json) should be decoded as well.
            value = content.decode(charset) if content_type == "text/plain" else content

            if filename is not None:
                self.files._add_field_value(field_name, File(filename, content_type, value))
            else:
                self._add_field_value(field_name, value)

    def _parse_text_plain(self, data: bytes) -> None:
        """
        Parse a ``text/plain`` body made of ``name=value`` lines separated by ``\\r\\n``.

        Empty lines are skipped, so a trailing ``\\r\\n`` is optional, and a line without
        ``=`` is stored with an empty value.
        """
        self._add_key_value_pairs(data.decode("utf-8").split("\r\n"))

    def _add_field_value(self, field_name: str, value: Union[str, bytes]) -> None:
        super()._add_field_value(field_name, value)

    def get(
        self, field_name: str, default: Union[str, bytes] = None, *, safe=True
    ) -> Union[str, bytes, None]:
        return super().get(field_name, default, safe=safe)

    def get_list(self, field_name: str, *, safe=True) -> List[Union[str, bytes]]:
        return super().get_list(field_name, safe=safe)

    def __repr__(self) -> str:
        class_name = self.__class__.__name__
        return f"<{class_name} {repr(self._storage)}, files={repr(self.files._storage)}>"
class Request:
    """
    Incoming request, constructed from raw incoming bytes.
    It is passed as first argument to all route handlers.
    """

    server: "Server"
    """
    Server object that received the request.
    """

    connection: _ISocket
    """
    Socket object used to send and receive data on the connection.
    """

    client_address: Tuple[str, int]
    """
    Address and port bound to the socket on the other end of the connection.

    Example::

        request.client_address # ('192.168.137.1', 40684)
    """

    method: str
    """Request method e.g. "GET" or "POST"."""

    path: str
    """Path of the request, e.g. ``"/foo/bar"``."""

    query_params: QueryParams
    """
    Query/GET parameters in the request.

    Example::

        request = Request(..., raw_request=b"GET /?foo=bar&baz=qux HTTP/1.1...")

        request.query_params # QueryParams({"foo": ["bar"], "baz": ["qux"]})
        request.query_params["foo"] # "bar"
        request.query_params.get_list("baz") # ["qux"]
    """

    http_version: str
    """HTTP version, e.g. ``"HTTP/1.1"``."""

    headers: Headers
    """
    Headers from the request.
    """

    raw_request: bytes
    """
    Raw ``bytes`` that were received from the client.

    Should **not** be modified directly.
    """

    def __init__(
        self,
        server: "Server",
        connection: _ISocket,
        client_address: Tuple[str, int],
        raw_request: bytes = None,
    ) -> None:
        """
        Parse ``raw_request`` into method, path, query parameters, HTTP version and headers.

        :raises ValueError: if ``raw_request`` is ``None`` or its start line/headers cannot
            be parsed.
        """
        self.server = server
        self.connection = connection
        self.client_address = client_address
        self.raw_request = raw_request
        # Lazily computed by the ``form_data`` and ``cookies`` properties.
        self._form_data = None
        self._cookies = None

        if raw_request is None:
            raise ValueError("raw_request cannot be None")

        try:
            (
                self.method,
                self.path,
                self.query_params,
                self.http_version,
                self.headers,
            ) = self._parse_request_header(self._raw_header_bytes)
        except Exception as error:
            raise ValueError(f"Unparseable raw_request: {raw_request}") from error

    @property
    def body(self) -> bytes:
        """Body of the request, as bytes (empty when no header/body separator was received)."""
        return self._raw_body_bytes

    @body.setter
    def body(self, body: bytes) -> None:
        self.raw_request = self._raw_header_bytes + b"\r\n\r\n" + body
        # Any cached form data was parsed from the previous body.
        # NOTE(review): ``headers`` (including Content-Length) are not updated here, and
        # FormData only parses the first Content-Length bytes of the body.
        self._form_data = None

    @staticmethod
    def _parse_cookies(cookie_header: str) -> Dict[str, str]:
        """
        Parse a ``Cookie`` header value into a ``{name: value}`` dict.

        Pairs are separated by ``;``, double quotes around values are stripped, and a later
        duplicate name overwrites an earlier one. Segments without ``=`` (e.g. produced by a
        trailing ``;``) are ignored. Returns an empty dict when ``cookie_header`` is ``None``.
        """
        if cookie_header is None:
            return {}

        cookies = {}
        for cookie in cookie_header.split(";"):
            cookie = cookie.strip()
            if "=" not in cookie:
                continue
            name, value = cookie.split("=", 1)
            cookies[name] = value.strip('"')
        return cookies

    @property
    def cookies(self) -> Dict[str, str]:
        """
        Cookies sent with the request.

        Example::

            request.headers["Cookie"]
            # "foo=bar; baz=qux; foo=quux"

            request.cookies
            # {"foo": "quux", "baz": "qux"}
        """
        if self._cookies is None:
            self._cookies = self._parse_cookies(self.headers.get("Cookie"))
        return self._cookies

    @property
    def form_data(self) -> Union[FormData, None]:
        """
        POST data of the request, parsed on first access and cached.

        Returns ``None`` for requests whose method is not POST.

        Example::

            # application/x-www-form-urlencoded
            request = Request(...,
                raw_request=b\"\"\"...
                foo=bar&baz=qux\"\"\"
            )

            # or

            # multipart/form-data
            request = Request(...,
                raw_request=b\"\"\"...
                --boundary
                Content-Disposition: form-data; name="foo"

                bar
                --boundary
                Content-Disposition: form-data; name="baz"

                qux
                --boundary--\"\"\"
            )

            # or

            # text/plain
            request = Request(...,
                raw_request=b\"\"\"...
                foo=bar
                baz=qux
                \"\"\"
            )

            request.form_data # FormData({'foo': ['bar'], 'baz': ['qux']})
            request.form_data["foo"] # "bar"
            request.form_data.get_list("baz") # ["qux"]
        """
        if self._form_data is None and self.method == POST:
            self._form_data = FormData(self.body, self.headers, debug=self.server.debug)
        return self._form_data

    def json(self) -> Union[dict, None]:
        """
        Body of the request, decoded as JSON.

        Only available for POST, PUT, PATCH and DELETE requests; returns ``None`` for other
        methods or when the body is empty.
        """
        return (
            json.loads(self.body)
            if (self.body and self.method in {POST, PUT, PATCH, DELETE})
            else None
        )

    @property
    def _raw_header_bytes(self) -> bytes:
        """Bytes before the first empty line (start line plus headers)."""
        empty_line_index = self.raw_request.find(b"\r\n\r\n")
        if empty_line_index == -1:
            # No header/body separator received: everything is header.
            return self.raw_request
        return self.raw_request[:empty_line_index]

    @property
    def _raw_body_bytes(self) -> bytes:
        """Bytes after the first empty line; empty when there is no separator."""
        empty_line_index = self.raw_request.find(b"\r\n\r\n")
        if empty_line_index == -1:
            return b""
        return self.raw_request[empty_line_index + 4 :]

    @staticmethod
    def _parse_request_header(
        header_bytes: bytes,
    ) -> Tuple[str, str, QueryParams, str, Headers]:
        """
        Parse the start line and header lines of a request.

        Returns ``(method, path, query_params, http_version, headers)``. Raises
        ``ValueError`` when the start line does not split into exactly three parts.
        """
        lines = header_bytes.decode("utf-8").strip().split("\r\n", 1)
        start_line = lines[0]
        # A request may carry no header lines at all (e.g. a bare "GET / HTTP/1.0").
        headers_string = lines[1] if len(lines) > 1 else ""

        method, path, http_version = start_line.strip().split()

        # Everything after the first "?" is the query string, which may be absent.
        path = path if "?" in path else path + "?"
        path, query_string = path.split("?", 1)

        query_params = QueryParams(query_string)
        headers = Headers(headers_string)

        return method, path, query_params, http_version, headers

    def __repr__(self) -> str:
        path = self.path + (f"?{self.query_params}" if self.query_params else "")
        return f'<{self.__class__.__name__} "{self.method} {path}">'
def _debug_unsupported_form_content_type(content_type: str) -> None:
"""Warns when an unsupported form content type is used."""
print(
f"WARNING: Unsupported Content-Type: {content_type}. "
"Only `application/x-www-form-urlencoded`, `multipart/form-data` and `text/plain` are "
"supported."
)