From 2244d33a991538aa8a2fa3e5f23af4eb0960364d Mon Sep 17 00:00:00 2001 From: Harry Lees Date: Sat, 16 Aug 2025 17:07:32 +0100 Subject: [PATCH 1/2] Add `content_type` field to File class --- python_multipart/multipart.py | 25 ++++++++++++++++++++++--- tests/test_multipart.py | 25 +++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 3 deletions(-) diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index 6c84829..4c1a626 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -358,13 +358,20 @@ class File: config: The configuration for this File. See above for valid configuration keys and their corresponding values. """ # noqa: E501 - def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None: + def __init__( + self, + file_name: bytes | None, + field_name: bytes | None = None, + config: FileConfig = {}, + content_type: bytes | None = None, + ) -> None: # Save configuration, set other variables default. self.logger = logging.getLogger(__name__) self._config = config self._in_memory = True self._bytes_written = 0 self._fileobj: BytesIO | BufferedRandom = BytesIO() + self._content_type = content_type # Save the provided field/file name. self._field_name = field_name @@ -392,6 +399,11 @@ def file_name(self) -> bytes | None: """The file name given in the upload request.""" return self._file_name + @property + def content_type(self) -> bytes | None: + """The Content-Type given in the upload request.""" + return self._content_type + @property def actual_file_name(self) -> bytes | None: """The file name that this file is saved as. Will be None if it's not @@ -570,7 +582,9 @@ def close(self) -> None: self._fileobj.close() def __repr__(self) -> str: - return "{}(file_name={!r}, field_name={!r})".format(self.__class__.__name__, self.file_name, self.field_name) + return "{}(file_name={!r}, field_name={!r}, content_type={!r})".format( + self.__class__.__name__, self.file_name, self.field_name, self.content_type + ) class BaseParser: @@ -1695,7 +1709,12 @@ def on_headers_finished() -> None: if file_name is None: f_multi = FieldClass(field_name) else: - f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config)) + f_multi = FileClass( + file_name, + field_name, + config=cast("FileConfig", self.config), + content_type=headers.get(b"Content-Type", None), + ) is_file = True # Parse the given Content-Transfer-Encoding to determine what diff --git a/tests/test_multipart.py b/tests/test_multipart.py index ce92ff4..40238cc 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1367,6 +1367,31 @@ def on_header_begin() -> None: # for each header in the multipart message. self.assertEqual(calls, 3) + def test_file_content_type_is_set(self) -> None: + """ + This test verifies that the content_type is set on File + https://github.com/Kludex/python-multipart/issues/207 + """ + + file: FileProtocol | None = None + + with open(os.path.join(http_tests_dir, "single_file.http"), "rb") as f: + test_data = f.read() + + def on_file(f: FileProtocol) -> None: + nonlocal file + file = f + + parser = FormParser("multipart/form-data", None, on_file, boundary=b"----WebKitFormBoundary5BZGOJCWtXGYC9HW") + + # Create multipart parser and feed it + i = parser.write(test_data) + parser.finalize() + + self.assertEqual(i, len(test_data)) + self.assertIsNotNone(file) + self.assertEqual(file.content_type, b"text/plain") + class TestHelperFunctions(unittest.TestCase): def test_create_form_parser(self) -> None: From aea973c7d7ab62462cd9cd993f10be3f57a3f7c1 Mon Sep 17 00:00:00 2001 From: Harry Lees Date: Sat, 16 Aug 2025 17:38:18 +0100 Subject: [PATCH 2/2] Fix `scripts/check` checks --- python_multipart/multipart.py | 8 +++++++- tests/test_multipart.py | 4 ++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index 4c1a626..393c74f 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -73,7 +73,13 @@ def __init__(self, name: bytes | None) -> None: ... def set_none(self) -> None: ... class FileProtocol(_FormProtocol, Protocol): - def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: ... + def __init__( + self, + file_name: bytes | None, + field_name: bytes | None, + config: FileConfig, + content_type: bytes | None = None, + ) -> None: ... OnFieldCallback = Callable[[FieldProtocol], None] OnFileCallback = Callable[[FileProtocol], None] diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 40238cc..a3f038f 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1373,7 +1373,7 @@ def test_file_content_type_is_set(self) -> None: https://github.com/Kludex/python-multipart/issues/207 """ - file: FileProtocol | None = None + file = None with open(os.path.join(http_tests_dir, "single_file.http"), "rb") as f: test_data = f.read() @@ -1390,7 +1390,7 @@ def on_file(f: FileProtocol) -> None: self.assertEqual(i, len(test_data)) self.assertIsNotNone(file) - self.assertEqual(file.content_type, b"text/plain") + self.assertEqual(cast(File, file).content_type, b"text/plain") class TestHelperFunctions(unittest.TestCase):