Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions python_multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,13 @@ def __init__(self, name: bytes | None) -> None: ...
def set_none(self) -> None: ...

class FileProtocol(_FormProtocol, Protocol):
def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: ...
def __init__(
self,
file_name: bytes | None,
field_name: bytes | None,
config: FileConfig,
content_type: bytes | None = None,
) -> None: ...

OnFieldCallback = Callable[[FieldProtocol], None]
OnFileCallback = Callable[[FileProtocol], None]
Expand Down Expand Up @@ -358,13 +364,20 @@ class File:
config: The configuration for this File. See above for valid configuration keys and their corresponding values.
""" # noqa: E501

def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None:
def __init__(
self,
file_name: bytes | None,
field_name: bytes | None = None,
config: FileConfig = {},
content_type: bytes | None = None,
) -> None:
# Save configuration, set other variables default.
self.logger = logging.getLogger(__name__)
self._config = config
self._in_memory = True
self._bytes_written = 0
self._fileobj: BytesIO | BufferedRandom = BytesIO()
self._content_type = content_type

# Save the provided field/file name.
self._field_name = field_name
Expand Down Expand Up @@ -392,6 +405,11 @@ def file_name(self) -> bytes | None:
"""The file name given in the upload request."""
return self._file_name

@property
def content_type(self) -> bytes | None:
"""The Content-Type given in the upload request."""
return self._content_type

@property
def actual_file_name(self) -> bytes | None:
"""The file name that this file is saved as. Will be None if it's not
Expand Down Expand Up @@ -570,7 +588,9 @@ def close(self) -> None:
self._fileobj.close()

def __repr__(self) -> str:
return "{}(file_name={!r}, field_name={!r})".format(self.__class__.__name__, self.file_name, self.field_name)
return "{}(file_name={!r}, field_name={!r}, content_type={!r})".format(
self.__class__.__name__, self.file_name, self.field_name, self.content_type
)


class BaseParser:
Expand Down Expand Up @@ -1695,7 +1715,12 @@ def on_headers_finished() -> None:
if file_name is None:
f_multi = FieldClass(field_name)
else:
f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config))
f_multi = FileClass(
file_name,
field_name,
config=cast("FileConfig", self.config),
content_type=headers.get(b"Content-Type", None),
)
is_file = True

# Parse the given Content-Transfer-Encoding to determine what
Expand Down
25 changes: 25 additions & 0 deletions tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,6 +1367,31 @@ def on_header_begin() -> None:
# for each header in the multipart message.
self.assertEqual(calls, 3)

def test_file_content_type_is_set(self) -> None:
"""
This test verifies that the content_type is set on File
https://github.com/Kludex/python-multipart/issues/207
"""

file = None

with open(os.path.join(http_tests_dir, "single_file.http"), "rb") as f:
test_data = f.read()

def on_file(f: FileProtocol) -> None:
nonlocal file
file = f

parser = FormParser("multipart/form-data", None, on_file, boundary=b"----WebKitFormBoundary5BZGOJCWtXGYC9HW")

# Create multipart parser and feed it
i = parser.write(test_data)
parser.finalize()

self.assertEqual(i, len(test_data))
self.assertIsNotNone(file)
self.assertEqual(cast(File, file).content_type, b"text/plain")


class TestHelperFunctions(unittest.TestCase):
def test_create_form_parser(self) -> None:
Expand Down