add a protocol to read/parse multipart responses

This commit is contained in:
H. Shay 2024-06-26 20:38:34 -07:00
parent 7c169f4d2c
commit 402f67dc67
4 changed files with 172 additions and 2 deletions

View file

@ -96,3 +96,6 @@ ignore_missing_imports = True
# https://github.com/twisted/treq/pull/366
[mypy-treq.*]
ignore_missing_imports = True
[mypy-multipart.*]
ignore_missing_imports = True

18
poetry.lock generated
View file

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
[[package]]
name = "annotated-types"
@ -2039,6 +2039,20 @@ files = [
[package.dependencies]
six = ">=1.5"
[[package]]
name = "python-multipart"
version = "0.0.9"
description = "A streaming multipart parser for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"},
{file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"},
]
[package.extras]
dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"]
[[package]]
name = "pytz"
version = "2022.7.1"
@ -3187,4 +3201,4 @@ user-search = ["pyicu"]
[metadata]
lock-version = "2.0"
python-versions = "^3.8.0"
content-hash = "107c8fb5c67360340854fbdba3c085fc5f9c7be24bcb592596a914eea621faea"
content-hash = "2971f8200039ef76c661b26ff9ef3b4487ed5b47a161352481dd73e23035f5a7"

View file

@ -248,6 +248,7 @@ Pympler = { version = "*", optional = true }
parameterized = { version = ">=0.7.4", optional = true }
idna = { version = ">=2.5", optional = true }
pyicu = { version = ">=2.10.2", optional = true }
python-multipart = "^0.0.9"
[tool.poetry.extras]
# NB: Packages that should be part of `pip install matrix-synapse[all]` need to be specified

View file

@ -35,6 +35,8 @@ from typing import (
Union,
)
import attr
import multipart
import treq
from canonicaljson import encode_canonical_json
from netaddr import AddrFormatError, IPAddress, IPSet
@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
self._maybe_fail()
@attr.s(auto_attribs=True, slots=True)
class MultipartResponse:
"""
A small class to hold parsed values of a multipart response.
"""
json: bytes = b"{}"
length: Optional[int] = None
content_type: Optional[bytes] = None
disposition: Optional[bytes] = None
url: Optional[bytes] = None
class _MultipartParserProtocol(protocol.Protocol):
"""
Protocol to read and parse a MSC3916 multipart/mixed response
"""
transport: Optional[ITCPTransport] = None
def __init__(
self,
stream: ByteWriteable,
deferred: defer.Deferred,
boundary: str,
max_length: Optional[int],
) -> None:
self.stream = stream
self.deferred = deferred
self.boundary = boundary
self.max_length = max_length
self.parser = None
self.multipart_response = MultipartResponse()
self.has_redirect = False
self.in_json = False
self.json_done = False
self.file_length = 0
self.total_length = 0
self.in_disposition = False
self.in_content_type = False
def dataReceived(self, incoming_data: bytes) -> None:
if self.deferred.called:
return
# we don't have a parser yet, instantiate it
if not self.parser:
def on_header_field(data: bytes, start: int, end: int) -> None:
if data[start:end] == b"Location":
self.has_redirect = True
if data[start:end] == b"Content-Disposition":
self.in_disposition = True
if data[start:end] == b"Content-Type":
self.in_content_type = True
def on_header_value(data: bytes, start: int, end: int) -> None:
# the first header should be content-type for application/json
if not self.in_json and not self.json_done:
assert data[start:end] == b"application/json"
self.in_json = True
elif self.has_redirect:
self.multipart_response.url = data[start:end]
elif self.in_content_type:
self.multipart_response.content_type = data[start:end]
self.in_content_type = False
elif self.in_disposition:
self.multipart_response.disposition = data[start:end]
self.in_disposition = False
def on_part_data(data: bytes, start: int, end: int) -> None:
# we've seen json header but haven't written the json data
if self.in_json and not self.json_done:
self.multipart_response.json = data[start:end]
self.json_done = True
# we have a redirect header rather than a file, and have already captured it
elif self.has_redirect:
return
# otherwise we are in the file part
else:
logger.info("Writing multipart file data to stream")
try:
self.stream.write(data[start:end])
except Exception as e:
logger.warning(
f"Exception encountered writing file data to stream: {e}"
)
self.deferred.errback()
self.file_length += end - start
callbacks = {
"on_header_field": on_header_field,
"on_header_value": on_header_value,
"on_part_data": on_part_data,
}
self.parser = multipart.MultipartParser(self.boundary, callbacks)
self.total_length += len(incoming_data)
if self.max_length is not None and self.total_length >= self.max_length:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
assert self.transport is not None
self.transport.abortConnection()
try:
self.parser.write(incoming_data) # type: ignore[attr-defined]
except Exception as e:
logger.warning(f"Exception writing to multipart parser: {e}")
self.deferred.errback()
return
def connectionLost(self, reason: Failure = connectionDone) -> None:
# If the maximum size was already exceeded, there's nothing to do.
if self.deferred.called:
return
if reason.check(ResponseDone):
self.multipart_response.length = self.file_length
self.deferred.callback(self.multipart_response)
else:
self.deferred.errback(reason)
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
@ -1091,6 +1217,32 @@ def read_body_with_max_size(
return d
def read_multipart_response(
response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
) -> "defer.Deferred[MultipartResponse]":
"""
Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
the stream passed in and returning a deferred resolving to a MultipartResponse
Args:
response: The HTTP response to read from.
stream: The file-object to write to.
boundary: the multipart/mixed boundary string
max_length: maximum allowable length of the response
"""
d: defer.Deferred[MultipartResponse] = defer.Deferred()
# If the Content-Length header gives a size larger than the maximum allowed
# size, do not bother downloading the body.
if max_length is not None and response.length != UNKNOWN_LENGTH:
if response.length > max_length:
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
return d
response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length))
return d
def encode_query_args(args: Optional[QueryParams]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.