mirror of
https://github.com/element-hq/synapse
synced 2024-07-04 08:43:29 +00:00
add a protocol to read/parse multipart responses
This commit is contained in:
parent
7c169f4d2c
commit
402f67dc67
3
mypy.ini
3
mypy.ini
|
@ -96,3 +96,6 @@ ignore_missing_imports = True
|
||||||
# https://github.com/twisted/treq/pull/366
|
# https://github.com/twisted/treq/pull/366
|
||||||
[mypy-treq.*]
|
[mypy-treq.*]
|
||||||
ignore_missing_imports = True
|
ignore_missing_imports = True
|
||||||
|
|
||||||
|
[mypy-multipart.*]
|
||||||
|
ignore_missing_imports = True
|
||||||
|
|
18
poetry.lock
generated
18
poetry.lock
generated
|
@ -1,4 +1,4 @@
|
||||||
# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
|
# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "annotated-types"
|
name = "annotated-types"
|
||||||
|
@ -2039,6 +2039,20 @@ files = [
|
||||||
[package.dependencies]
|
[package.dependencies]
|
||||||
six = ">=1.5"
|
six = ">=1.5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "python-multipart"
|
||||||
|
version = "0.0.9"
|
||||||
|
description = "A streaming multipart parser for Python"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"},
|
||||||
|
{file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pytz"
|
name = "pytz"
|
||||||
version = "2022.7.1"
|
version = "2022.7.1"
|
||||||
|
@ -3187,4 +3201,4 @@ user-search = ["pyicu"]
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.8.0"
|
python-versions = "^3.8.0"
|
||||||
content-hash = "107c8fb5c67360340854fbdba3c085fc5f9c7be24bcb592596a914eea621faea"
|
content-hash = "2971f8200039ef76c661b26ff9ef3b4487ed5b47a161352481dd73e23035f5a7"
|
||||||
|
|
|
@ -248,6 +248,7 @@ Pympler = { version = "*", optional = true }
|
||||||
parameterized = { version = ">=0.7.4", optional = true }
|
parameterized = { version = ">=0.7.4", optional = true }
|
||||||
idna = { version = ">=2.5", optional = true }
|
idna = { version = ">=2.5", optional = true }
|
||||||
pyicu = { version = ">=2.10.2", optional = true }
|
pyicu = { version = ">=2.10.2", optional = true }
|
||||||
|
python-multipart = "^0.0.9"
|
||||||
|
|
||||||
[tool.poetry.extras]
|
[tool.poetry.extras]
|
||||||
# NB: Packages that should be part of `pip install matrix-synapse[all]` need to be specified
|
# NB: Packages that should be part of `pip install matrix-synapse[all]` need to be specified
|
||||||
|
|
|
@ -35,6 +35,8 @@ from typing import (
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import attr
|
||||||
|
import multipart
|
||||||
import treq
|
import treq
|
||||||
from canonicaljson import encode_canonical_json
|
from canonicaljson import encode_canonical_json
|
||||||
from netaddr import AddrFormatError, IPAddress, IPSet
|
from netaddr import AddrFormatError, IPAddress, IPSet
|
||||||
|
@ -1006,6 +1008,130 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||||
self._maybe_fail()
|
self._maybe_fail()
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(auto_attribs=True, slots=True)
class MultipartResponse:
    """
    A small class to hold parsed values of a multipart response.
    """

    # raw bytes of the JSON part (defaults to an empty JSON object)
    json: bytes = b"{}"
    # number of file bytes written to the caller-provided stream, if any
    length: Optional[int] = None
    # value of the file part's Content-Type header, if present
    content_type: Optional[bytes] = None
    # value of the file part's Content-Disposition header, if present
    disposition: Optional[bytes] = None
    # value of a Location header (redirect) instead of a file part, if present
    url: Optional[bytes] = None
|
|
||||||
|
|
||||||
|
class _MultipartParserProtocol(protocol.Protocol):
    """
    Protocol to read and parse a MSC3916 multipart/mixed response.

    Feeds each received chunk through a streaming `multipart.MultipartParser`,
    capturing the leading JSON part and any Location/Content-Type/
    Content-Disposition headers into a `MultipartResponse`, and writing the
    file part (if any) to the supplied stream. The deferred fires with the
    `MultipartResponse` on clean completion, or errbacks on failure or when
    `max_length` is exceeded.
    """

    transport: Optional[ITCPTransport] = None

    def __init__(
        self,
        stream: ByteWriteable,
        deferred: defer.Deferred,
        boundary: str,
        max_length: Optional[int],
    ) -> None:
        self.stream = stream
        self.deferred = deferred
        self.boundary = boundary
        self.max_length = max_length
        # lazily created on first dataReceived call
        self.parser = None
        self.multipart_response = MultipartResponse()
        # parsing state flags, driven by the header callbacks below
        self.has_redirect = False
        self.in_json = False
        self.json_done = False
        self.file_length = 0
        self.total_length = 0
        self.in_disposition = False
        self.in_content_type = False

    def dataReceived(self, incoming_data: bytes) -> None:
        # The deferred has already fired (error or size cap) - drop the data.
        if self.deferred.called:
            return

        # we don't have a parser yet, instantiate it
        if not self.parser:

            def on_header_field(data: bytes, start: int, end: int) -> None:
                # remember which header we are in so on_header_value knows
                # where to store the value
                if data[start:end] == b"Location":
                    self.has_redirect = True
                if data[start:end] == b"Content-Disposition":
                    self.in_disposition = True
                if data[start:end] == b"Content-Type":
                    self.in_content_type = True

            def on_header_value(data: bytes, start: int, end: int) -> None:
                # the first header should be content-type for application/json
                if not self.in_json and not self.json_done:
                    assert data[start:end] == b"application/json"
                    self.in_json = True
                elif self.has_redirect:
                    self.multipart_response.url = data[start:end]
                elif self.in_content_type:
                    self.multipart_response.content_type = data[start:end]
                    self.in_content_type = False
                elif self.in_disposition:
                    self.multipart_response.disposition = data[start:end]
                    self.in_disposition = False

            def on_part_data(data: bytes, start: int, end: int) -> None:
                # we've seen json header but haven't written the json data
                if self.in_json and not self.json_done:
                    self.multipart_response.json = data[start:end]
                    self.json_done = True
                # we have a redirect header rather than a file, and have already captured it
                elif self.has_redirect:
                    return
                # otherwise we are in the file part
                else:
                    logger.info("Writing multipart file data to stream")
                    try:
                        self.stream.write(data[start:end])
                    except Exception as e:
                        logger.warning(
                            f"Exception encountered writing file data to stream: {e}"
                        )
                        self.deferred.errback()
                        # FIX: don't count bytes that failed to be written
                        return
                    self.file_length += end - start

            callbacks = {
                "on_header_field": on_header_field,
                "on_header_value": on_header_value,
                "on_part_data": on_part_data,
            }
            self.parser = multipart.MultipartParser(self.boundary, callbacks)

        self.total_length += len(incoming_data)
        if self.max_length is not None and self.total_length >= self.max_length:
            self.deferred.errback(BodyExceededMaxSize())
            # Close the connection (forcefully) since all the data will get
            # discarded anyway.
            assert self.transport is not None
            self.transport.abortConnection()
            # FIX: previously this fell through and fed the oversized chunk to
            # the parser, which could fire callbacks (and write file data) after
            # the deferred had already errbacked. Stop here instead.
            return

        try:
            self.parser.write(incoming_data)  # type: ignore[attr-defined]
        except Exception as e:
            logger.warning(f"Exception writing to multipart parser: {e}")
            self.deferred.errback()
            return

    def connectionLost(self, reason: Failure = connectionDone) -> None:
        # If the maximum size was already exceeded, there's nothing to do.
        if self.deferred.called:
            return

        if reason.check(ResponseDone):
            self.multipart_response.length = self.file_length
            self.deferred.callback(self.multipart_response)
        else:
            self.deferred.errback(reason)
||||||
|
|
||||||
|
|
||||||
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
|
||||||
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
|
||||||
|
|
||||||
|
@ -1091,6 +1217,32 @@ def read_body_with_max_size(
|
||||||
return d
|
return d
|
||||||
|
|
||||||
|
|
||||||
|
def read_multipart_response(
    response: IResponse, stream: ByteWriteable, boundary: str, max_length: Optional[int]
) -> "defer.Deferred[MultipartResponse]":
    """
    Reads a MSC3916 multipart/mixed response and parses it, reading the file part (if it contains one) into
    the stream passed in and returning a deferred resolving to a MultipartResponse

    Args:
        response: The HTTP response to read from.
        stream: The file-object to write to.
        boundary: the multipart/mixed boundary string
        max_length: maximum allowable length of the response
    """
    result: defer.Deferred[MultipartResponse] = defer.Deferred()

    # Short-circuit: when the advertised Content-Length already exceeds the
    # cap, don't bother downloading the body - just drain and fail.
    too_large = (
        max_length is not None
        and response.length != UNKNOWN_LENGTH
        and response.length > max_length
    )
    if too_large:
        response.deliverBody(_DiscardBodyWithMaxSizeProtocol(result))
    else:
        response.deliverBody(
            _MultipartParserProtocol(stream, result, boundary, max_length)
        )
    return result
|
||||||
|
|
||||||
|
|
||||||
def encode_query_args(args: Optional[QueryParams]) -> bytes:
|
def encode_query_args(args: Optional[QueryParams]) -> bytes:
|
||||||
"""
|
"""
|
||||||
Encodes a map of query arguments to bytes which can be appended to a URL.
|
Encodes a map of query arguments to bytes which can be appended to a URL.
|
||||||
|
|
Loading…
Reference in a new issue