Skip to content

Prevent parent directory access, custom Errors #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 24, 2023
52 changes: 52 additions & 0 deletions adafruit_httpserver/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# SPDX-FileCopyrightText: Copyright (c) 2022 Dan Halbert for Adafruit Industries
#
# SPDX-License-Identifier: MIT
"""
`adafruit_httpserver.exceptions`
====================================================
* Author(s): Michał Pokusa
"""


class InvalidPathError(Exception):
"""
Parent class for all path related errors.
"""


class ParentDirectoryReferenceError(InvalidPathError):
"""
Path contains ``..``, a reference to the parent directory.
"""

def __init__(self, path: str) -> None:
"""Creates a new ``ParentDirectoryReferenceError`` for the ``path``."""
super().__init__(f"Parent directory reference in path: {path}")


class BackslashInPathError(InvalidPathError):
"""
Backslash ``\\`` in path.
"""

def __init__(self, path: str) -> None:
"""Creates a new ``BackslashInPathError`` for the ``path``."""
super().__init__(f"Backslash in path: {path}")


class ResponseAlreadySentError(Exception):
"""
Another ``HTTPResponse`` has already been sent. There can only be one per ``HTTPRequest``.
"""


class FileNotExistsError(Exception):
"""
Raised when a file does not exist.
"""

def __init__(self, path: str) -> None:
"""
Creates a new ``FileNotExistsError`` for the file at ``path``.
"""
super().__init__(f"File does not exist: {path}")
84 changes: 67 additions & 17 deletions adafruit_httpserver/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"""

try:
from typing import Optional, Dict, Union, Tuple
from typing import Optional, Dict, Union, Tuple, Callable
from socket import socket
from socketpool import SocketPool
except ImportError:
Expand All @@ -17,12 +17,33 @@
import os
from errno import EAGAIN, ECONNRESET

from .exceptions import (
BackslashInPathError,
FileNotExistsError,
ParentDirectoryReferenceError,
ResponseAlreadySentError,
)
from .mime_type import MIMEType
from .request import HTTPRequest
from .status import HTTPStatus, CommonHTTPStatus
from .headers import HTTPHeaders


def _prevent_multiple_send_calls(function: Callable):
"""
Decorator that prevents calling ``send`` or ``send_file`` more than once.
"""

def wrapper(self: "HTTPResponse", *args, **kwargs):
if self._response_already_sent: # pylint: disable=protected-access
raise ResponseAlreadySentError

result = function(self, *args, **kwargs)
return result

return wrapper


class HTTPResponse:
"""
Response to a given `HTTPRequest`. Use in `HTTPServer.route` decorator functions.
Expand Down Expand Up @@ -73,8 +94,8 @@ def route_func(request):
"""
Defaults to ``text/plain`` if not set.

Can be explicitly provided in the constructor, in `send()` or
implicitly determined from filename in `send_file()`.
Can be explicitly provided in the constructor, in ``send()`` or
implicitly determined from filename in ``send_file()``.
Comment on lines +97 to +98
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation builder should be able to use single backticks and produce cross references to other methods/functions here. Did that fail?

Copy link
Contributor Author

@michalpokusa michalpokusa Apr 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had problems with single ticks as it often couldn't find the member in X in module Y, not entirely sure why

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a big deal; I've also found it doesn't work under strange circumstances.

Copy link
Contributor Author

@michalpokusa michalpokusa Apr 21, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dhalbert Are there any more changes that you would like me to include?


Common MIME types are defined in `adafruit_httpserver.mime_type.MIMEType`.
"""
Expand All @@ -94,7 +115,7 @@ def __init__( # pylint: disable=too-many-arguments
Sets `status`, ``headers`` and `http_version`
and optionally default ``content_type``.

To send the response, call `send` or `send_file`.
To send the response, call ``send`` or ``send_file``.
For chunked response use
``with HTTPRequest(request, content_type=..., chunked=True) as r:`` and `send_chunk`.
"""
Expand All @@ -115,7 +136,7 @@ def _send_headers(
) -> None:
"""
Sends headers.
Implicitly called by `send` and `send_file` and in
Implicitly called by ``send`` and ``send_file`` and in
``with HTTPResponse(request, chunked=True) as response:`` context manager.
"""
headers = self.headers.copy()
Expand All @@ -141,6 +162,7 @@ def _send_headers(
self.request.connection, response_message_header.encode("utf-8")
)

@_prevent_multiple_send_calls
def send(
self,
body: str = "",
Expand All @@ -152,8 +174,6 @@ def send(

Should be called **only once** per response.
"""
if self._response_already_sent:
raise RuntimeError("Response was already sent")

if getattr(body, "encode", None):
encoded_response_message_body = body.encode("utf-8")
Expand All @@ -167,12 +187,41 @@ def send(
self._send_bytes(self.request.connection, encoded_response_message_body)
self._response_already_sent = True

def send_file(
@staticmethod
def _check_file_path_is_valid(file_path: str) -> bool:
"""
Checks if ``file_path`` is valid.
If not raises error corresponding to the problem.
"""

# Check for backslashes
if "\\" in file_path: # pylint: disable=anomalous-backslash-in-string
raise BackslashInPathError(file_path)

# Check each component of the path for parent directory references
for part in file_path.split("/"):
if part == "..":
raise ParentDirectoryReferenceError(file_path)

@staticmethod
def _get_file_length(file_path: str) -> int:
"""
Tries to get the length of the file at ``file_path``.
Raises ``FileNotExistsError`` if file does not exist.
"""
try:
return os.stat(file_path)[6]
except OSError:
raise FileNotExistsError(file_path) # pylint: disable=raise-missing-from

@_prevent_multiple_send_calls
def send_file( # pylint: disable=too-many-arguments
self,
filename: str = "index.html",
root_path: str = "./",
buffer_size: int = 1024,
head_only: bool = False,
safe: bool = True,
) -> None:
"""
Send response with content of ``filename`` located in ``root_path``.
Expand All @@ -181,25 +230,26 @@ def send_file(

Should be called **only once** per response.
"""
if self._response_already_sent:
raise RuntimeError("Response was already sent")

if safe:
self._check_file_path_is_valid(filename)

if not root_path.endswith("/"):
root_path += "/"
try:
file_length = os.stat(root_path + filename)[6]
except OSError:
# If the file doesn't exist, return 404.
HTTPResponse(self.request, status=CommonHTTPStatus.NOT_FOUND_404).send()
return
if filename.startswith("/"):
filename = filename[1:]

full_file_path = root_path + filename

file_length = self._get_file_length(full_file_path)

self._send_headers(
content_type=MIMEType.from_file_name(filename),
content_length=file_length,
)

if not head_only:
with open(root_path + filename, "rb") as file:
with open(full_file_path, "rb") as file:
while bytes_read := file.read(buffer_size):
self._send_bytes(self.request.connection, bytes_read)
self._response_already_sent = True
Expand Down
84 changes: 47 additions & 37 deletions adafruit_httpserver/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from errno import EAGAIN, ECONNRESET, ETIMEDOUT

from .exceptions import FileNotExistsError, InvalidPathError
from .methods import HTTPMethod
from .request import HTTPRequest
from .response import HTTPResponse
Expand All @@ -26,18 +27,19 @@
class HTTPServer:
"""A basic socket-based HTTP server."""

def __init__(self, socket_source: Protocol) -> None:
def __init__(self, socket_source: Protocol, root_path: str) -> None:
"""Create a server, and get it ready to run.

:param socket: An object that is a source of sockets. This could be a `socketpool`
in CircuitPython or the `socket` module in CPython.
:param str root_path: Root directory to serve files from
"""
self._buffer = bytearray(1024)
self._timeout = 1
self.routes = _HTTPRoutes()
self._socket_source = socket_source
self._sock = None
self.root_path = "/"
self.root_path = root_path

def route(self, path: str, method: HTTPMethod = HTTPMethod.GET) -> Callable:
"""
Expand All @@ -63,32 +65,28 @@ def route_decorator(func: Callable) -> Callable:

return route_decorator

def serve_forever(self, host: str, port: int = 80, root_path: str = "") -> None:
def serve_forever(self, host: str, port: int = 80) -> None:
"""Wait for HTTP requests at the given host and port. Does not return.

:param str host: host name or IP address
:param int port: port
:param str root_path: root directory to serve files from
"""
self.start(host, port, root_path)
self.start(host, port)

while True:
try:
self.poll()
except OSError:
continue

def start(self, host: str, port: int = 80, root_path: str = "") -> None:
def start(self, host: str, port: int = 80) -> None:
"""
Start the HTTP server at the given host and port. Requires calling
poll() in a while loop to handle incoming requests.

:param str host: host name or IP address
:param int port: port
:param str root_path: root directory to serve files from
"""
self.root_path = root_path

self._sock = self._socket_source.socket(
self._socket_source.AF_INET, self._socket_source.SOCK_STREAM
)
Expand Down Expand Up @@ -158,38 +156,50 @@ def poll(self):
conn, received_body_bytes, content_length
)

# Find a handler for the route
handler = self.routes.find_handler(
_HTTPRoute(request.path, request.method)
)

# If a handler for route exists and is callable, call it.
if handler is not None and callable(handler):
handler(request)

# If no handler exists and request method is GET, try to serve a file.
elif handler is None and request.method in (
HTTPMethod.GET,
HTTPMethod.HEAD,
):
filename = "index.html" if request.path == "/" else request.path
HTTPResponse(request).send_file(
filename=filename,
root_path=self.root_path,
buffer_size=self.request_buffer_size,
head_only=(request.method == HTTPMethod.HEAD),
try:
# If a handler for route exists and is callable, call it.
if handler is not None and callable(handler):
handler(request)

# If no handler exists and request method is GET or HEAD, try to serve a file.
elif handler is None and request.method in (
HTTPMethod.GET,
HTTPMethod.HEAD,
):
filename = "index.html" if request.path == "/" else request.path
HTTPResponse(request).send_file(
filename=filename,
root_path=self.root_path,
buffer_size=self.request_buffer_size,
head_only=(request.method == HTTPMethod.HEAD),
)
else:
HTTPResponse(
request, status=CommonHTTPStatus.BAD_REQUEST_400
).send()

except InvalidPathError as error:
HTTPResponse(request, status=CommonHTTPStatus.FORBIDDEN_403).send(
str(error)
)
else:
HTTPResponse(
request, status=CommonHTTPStatus.BAD_REQUEST_400
).send()

except OSError as ex:
# handle EAGAIN and ECONNRESET
if ex.errno == EAGAIN:
# there is no data available right now, try again later.

except FileNotExistsError as error:
HTTPResponse(request, status=CommonHTTPStatus.NOT_FOUND_404).send(
str(error)
)

except OSError as error:
# Handle EAGAIN and ECONNRESET
if error.errno == EAGAIN:
# There is no data available right now, try again later.
return
if ex.errno == ECONNRESET:
# connection reset by peer, try again later.
if error.errno == ECONNRESET:
# Connection reset by peer, try again later.
return
raise

Expand All @@ -204,7 +214,7 @@ def request_buffer_size(self) -> int:

Example::

server = HTTPServer(pool)
server = HTTPServer(pool, "/static")
server.request_buffer_size = 2048

server.serve_forever(str(wifi.radio.ipv4_address))
Expand All @@ -226,7 +236,7 @@ def socket_timeout(self) -> int:

Example::

server = HTTPServer(pool)
server = HTTPServer(pool, "/static")
server.socket_timeout = 3

server.serve_forever(str(wifi.radio.ipv4_address))
Expand Down
3 changes: 3 additions & 0 deletions adafruit_httpserver/status.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ class CommonHTTPStatus(HTTPStatus): # pylint: disable=too-few-public-methods
BAD_REQUEST_400 = HTTPStatus(400, "Bad Request")
"""400 Bad Request"""

FORBIDDEN_403 = HTTPStatus(403, "Forbidden")
"""403 Forbidden"""

NOT_FOUND_404 = HTTPStatus(404, "Not Found")
"""404 Not Found"""

Expand Down
3 changes: 3 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,6 @@

.. automodule:: adafruit_httpserver.mime_type
:members:

.. automodule:: adafruit_httpserver.exceptions
:members:
Loading