mirror of
https://github.com/element-hq/synapse
synced 2024-10-02 08:02:41 +00:00
add a protocol to read/parse multipart responses
This commit is contained in:
parent
7c169f4d2c
commit
402f67dc67
4 changed files with 172 additions and 2 deletions
3
mypy.ini
3
mypy.ini
|
@ -96,3 +96,6 @@ ignore_missing_imports = True
|
|||
# https://github.com/twisted/treq/pull/366
|
||||
[mypy-treq.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-multipart.*]
|
||||
ignore_missing_imports = True
|
||||
|
|
18
poetry.lock
generated
18
poetry.lock
generated
|
@ -1,4 +1,4 @@
|
|||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "annotated-types"
|
||||
|
@ -2039,6 +2039,20 @@ files = [
|
|||
[package.dependencies]
|
||||
six = ">=1.5"
|
||||
|
||||
[[package]]
|
||||
name = "python-multipart"
|
||||
version = "0.0.9"
|
||||
description = "A streaming multipart parser for Python"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"},
|
||||
{file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"]
|
||||
|
||||
[[package]]
|
||||
name = "pytz"
|
||||
version = "2022.7.1"
|
||||
|
@ -3187,4 +3201,4 @@ user-search = ["pyicu"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.8.0"
|
||||
content-hash = "107c8fb5c67360340854fbdba3c085fc5f9c7be24bcb592596a914eea621faea"
|
||||
content-hash = "2971f8200039ef76c661b26ff9ef3b4487ed5b47a161352481dd73e23035f5a7"
|
||||
|
|
|
@ -248,6 +248,7 @@ Pympler = { version = "*", optional = true }
|
|||
parameterized = { version = ">=0.7.4", optional = true }
|
||||
idna = { version = ">=2.5", optional = true }
|
||||
pyicu = { version = ">=2.10.2", optional = true }
|
||||
python-multipart = "^0.0.9"
|
||||
|
||||
[tool.poetry.extras]
|
||||
# NB: Packages that should be part of `pip install matrix-synapse[all]` need to be specified
|
||||
|
|
|
@ -35,6 +35,8 @@ from typing import (
|
|||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
import multipart
|
||||
import treq
|
||||
from canonicaljson import encode_canonical_json
|
||||
from netaddr import AddrFormatError, IPAddress, IPSet
|
||||
|
@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
|||
self._maybe_fail()
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True)
|
||||
class MultipartResponse:
|
||||
"""
|
||||
A small class to hold parsed values of a multipart response.
|
||||
"""
|
||||
|
||||
json: bytes = b"{}"
|
||||
length: Optional[int] = None
|
||||
content_type: Optional[bytes] = None
|
||||
disposition: Optional[bytes] = None
|
||||
url: Optional[bytes] = None
|
||||
|
||||
|
||||
class _MultipartParserProtocol(protocol.Protocol):
|
||||
"""
|
||||
Protocol to read and parse a MSC3916 multipart/mixed response
|
||||
"""
|
||||
|
||||
transport: Optional[ITCPTransport] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stream: ByteWriteable,
|
||||
deferred: defer.Deferred,
|
||||
boundary: str,
|
||||
max_length: Optional[int],
|
||||
) -> None:
|
||||
self.stream = stream
|
||||
self.deferred = deferred
|
||||
self.boundary = boundary
|
||||
self.max_length = max_length
|
||||
self.parser = None
|
||||
self.multipart_response = MultipartResponse()
|
||||
self.has_redirect = False
|
||||
self.in_json = False
|
||||
self.json_done = False
|
||||
self.file_length = 0
|
||||
self.total_length = 0
|
||||
self.in_disposition = False
|
||||
self.in_content_type = False
|
||||
|
||||
def dataReceived(self, incoming_data: bytes) -> None:
|
||||
if self.deferred.called:
|
||||
return
|
||||
|
||||
# we don't have a parser yet, instantiate it
|
||||
if not self.parser:
|
||||
|
||||
def on_header_field(data: bytes, start: int, end: int) -> None:
|
||||
if data[start:end] == b"Location":
|
||||
self.has_redirect = True
|
||||
if data[start:end] == b"Content-Disposition":
|
||||
self.in_disposition = True
|
||||
if data[start:end] == b"Content-Type":
|
||||
self.in_content_type = True
|
||||
|
||||
def on_header_value(data: bytes, start: int, end: int) -> None:
|
||||
# the first header should be content-type for application/json
|
||||
if not self.in_json and not self.json_done:
|
||||
assert data[start:end] == b"application/json"
|
||||
self.in_json = True
|
||||
elif self.has_redirect:
|
||||
self.multipart_response.url = data[start:end]
|
||||
elif self.in_content_type:
|
||||
self.multipart_response.content_type = data[start:end]
|
||||
self.in_content_type = False
|
||||
elif self.in_disposition:
|
||||
self.multipart_response.disposition = data[start:end]
|
||||
self.in_disposition = False
|
||||
|
||||
def on_part_data(data: bytes, start: int, end: int) -> None:
|
||||
# we've seen json header but haven't written the json data
|
||||
if self.in_json and not self.json_done:
|
||||
self.multipart_response.json = data[start:end]
|
||||
self.json_done = True
|
||||
# we have a redirect header rather than a file, and have already captured it
|
||||
elif self.has_redirect:
|
||||
return
|
||||
# otherwise we are in the file part
|
||||
else:
|
||||
logger.info("Writing multipart file data to stream")
|
||||
try:
|
||||
self.stream.write(data[start:end])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Exception encountered writing file data to stream: {e}"
|
||||
)
|
||||
self.deferred.errback()
|
||||
self.file_length += end - start
|
||||
|
||||
callbacks = {
|
||||
"on_header_field": on_header_field,
|
||||
"on_header_value": on_header_value,
|
||||
"on_part_data": on_part_data,
|
||||
}
|
||||
self.parser = multipart.MultipartParser(self.boundary, callbacks)
|
||||
|
||||
self.total_length += len(incoming_data)
|
||||
if self.max_length is not None and self.total_length >= self.max_length:
|
||||
self.deferred.errback(BodyExceededMaxSize())
|
||||
# Close the connection (forcefully) since all the data will get
|
||||
# discarded anyway.
|
||||
assert self.transport is not None
|
||||
self.transport.abortConnection()
|
||||
|
||||
try:
|
||||
self.parser.write(incoming_data) # type: ignore[attr-defined]
|
||||
except Exception as e:
|
||||
logger.warning(f"Exception writing to multipart parser: {e}")
|
||||
self.deferred.errback()
|
||||
return
|
||||
|
||||
def connectionLost(self, reason: Failure = connectionDone) -> None:
|
||||
# If the maximum size was already exceeded, there's nothing to do.
|
||||
if self.deferred.called:
|
||||
return
|
||||
|
||||
if reason.check(ResponseDone):
|
||||
self.multipart_response.length = self.file_length
|
||||
self.deferred.callback(self.multipart_response)
|
||||
else:
|
||||
self.deferred.errback(reason)
|
||||
|
||||
|
||||
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
||||
|
||||
|
@ -1091,6 +1217,32 @@ def read_body_with_max_size(
|
|||
return d
|
||||
|
||||
|
||||
def read_multipart_response(
|
||||
response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
|
||||
) -> "defer.Deferred[MultipartResponse]":
|
||||
"""
|
||||
Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
|
||||
the stream passed in and returning a deferred resolving to a MultipartResponse
|
||||
|
||||
Args:
|
||||
response: The HTTP response to read from.
|
||||
stream: The file-object to write to.
|
||||
boundary: the multipart/mixed boundary string
|
||||
max_length: maximum allowable length of the response
|
||||
"""
|
||||
d: defer.Deferred[MultipartResponse] = defer.Deferred()
|
||||
|
||||
# If the Content-Length header gives a size larger than the maximum allowed
|
||||
# size, do not bother downloading the body.
|
||||
if max_length is not None and response.length != UNKNOWN_LENGTH:
|
||||
if response.length > max_length:
|
||||
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
|
||||
return d
|
||||
|
||||
response.deliverBody(_MultipartParserProtocol(stream, d, boundary, max_length))
|
||||
return d
|
||||
|
||||
|
||||
def encode_query_args(args: Optional[QueryParams]) -> bytes:
|
||||
"""
|
||||
Encodes a map of query arguments to bytes which can be appended to a URL.
|
||||
|
|
Loading…
Reference in a new issue