Add missing type hints to synapse.http. (#11571)

This commit is contained in:
Patrick Cloke 2021-12-14 07:00:47 -05:00 committed by GitHub
parent ff6fd52160
commit 33abbc3278
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 76 additions and 51 deletions

1
changelog.d/11571.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to `synapse.http`.

View file

@ -161,6 +161,9 @@ disallow_untyped_defs = False
[mypy-synapse.handlers.*] [mypy-synapse.handlers.*]
disallow_untyped_defs = True disallow_untyped_defs = True
[mypy-synapse.http.server]
disallow_untyped_defs = True
[mypy-synapse.metrics.*] [mypy-synapse.metrics.*]
disallow_untyped_defs = True disallow_untyped_defs = True

View file

@ -25,7 +25,7 @@ from synapse.api.errors import SynapseError
class RequestTimedOutError(SynapseError): class RequestTimedOutError(SynapseError):
"""Exception representing timeout of an outbound request""" """Exception representing timeout of an outbound request"""
def __init__(self, msg): def __init__(self, msg: str):
super().__init__(504, msg) super().__init__(504, msg)
@ -33,7 +33,7 @@ ACCESS_TOKEN_RE = re.compile(r"(\?.*access(_|%5[Ff])token=)[^&]*(.*)$")
CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$") CLIENT_SECRET_RE = re.compile(r"(\?.*client(_|%5[Ff])secret=)[^&]*(.*)$")
def redact_uri(uri): def redact_uri(uri: str) -> str:
"""Strips sensitive information from the uri replaces with <redacted>""" """Strips sensitive information from the uri replaces with <redacted>"""
uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri) uri = ACCESS_TOKEN_RE.sub(r"\1<redacted>\3", uri)
return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri) return CLIENT_SECRET_RE.sub(r"\1<redacted>\3", uri)
@ -46,7 +46,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
https://twistedmatrix.com/trac/ticket/6528 https://twistedmatrix.com/trac/ticket/6528
""" """
def stopProducing(self): def stopProducing(self) -> None:
try: try:
FileBodyProducer.stopProducing(self) FileBodyProducer.stopProducing(self)
except task.TaskStopped: except task.TaskStopped:

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple
from twisted.web.server import Request from twisted.web.server import Request
@ -32,7 +32,11 @@ class AdditionalResource(DirectServeJsonResource):
and exception handling. and exception handling.
""" """
def __init__(self, hs: "HomeServer", handler): def __init__(
self,
hs: "HomeServer",
handler: Callable[[Request], Awaitable[Optional[Tuple[int, Any]]]],
):
"""Initialise AdditionalResource """Initialise AdditionalResource
The ``handler`` should return a deferred which completes when it has The ``handler`` should return a deferred which completes when it has
@ -47,7 +51,7 @@ class AdditionalResource(DirectServeJsonResource):
super().__init__() super().__init__()
self._handler = handler self._handler = handler
def _async_render(self, request: Request): async def _async_render(self, request: Request) -> Optional[Tuple[int, Any]]:
# Cheekily pass the result straight through, so we don't need to worry # Cheekily pass the result straight through, so we don't need to worry
# if its an awaitable or not. # if its an awaitable or not.
return self._handler(request) return await self._handler(request)

View file

@ -30,6 +30,7 @@ from typing import (
Iterable, Iterable,
Iterator, Iterator,
List, List,
NoReturn,
Optional, Optional,
Pattern, Pattern,
Tuple, Tuple,
@ -170,7 +171,9 @@ def return_html_error(
respond_with_html(request, code, body) respond_with_html(request, code, body)
def wrap_async_request_handler(h): def wrap_async_request_handler(
h: Callable[["_AsyncResource", SynapseRequest], Awaitable[None]]
) -> Callable[["_AsyncResource", SynapseRequest], "defer.Deferred[None]"]:
"""Wraps an async request handler so that it calls request.processing. """Wraps an async request handler so that it calls request.processing.
This helps ensure that work done by the request handler after the request is completed This helps ensure that work done by the request handler after the request is completed
@ -183,7 +186,9 @@ def wrap_async_request_handler(h):
logged until the deferred completes. logged until the deferred completes.
""" """
async def wrapped_async_request_handler(self, request): async def wrapped_async_request_handler(
self: "_AsyncResource", request: SynapseRequest
) -> None:
with request.processing(): with request.processing():
await h(self, request) await h(self, request)
@ -240,18 +245,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
context from the request the servlet is handling. context from the request the servlet is handling.
""" """
def __init__(self, extract_context=False): def __init__(self, extract_context: bool = False):
super().__init__() super().__init__()
self._extract_context = extract_context self._extract_context = extract_context
def render(self, request): def render(self, request: SynapseRequest) -> int:
"""This gets called by twisted every time someone sends us a request.""" """This gets called by twisted every time someone sends us a request."""
defer.ensureDeferred(self._async_render_wrapper(request)) defer.ensureDeferred(self._async_render_wrapper(request))
return NOT_DONE_YET return NOT_DONE_YET
@wrap_async_request_handler @wrap_async_request_handler
async def _async_render_wrapper(self, request: SynapseRequest): async def _async_render_wrapper(self, request: SynapseRequest) -> None:
"""This is a wrapper that delegates to `_async_render` and handles """This is a wrapper that delegates to `_async_render` and handles
exceptions, return values, metrics, etc. exceptions, return values, metrics, etc.
""" """
@ -271,7 +276,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
f = failure.Failure() f = failure.Failure()
self._send_error_response(f, request) self._send_error_response(f, request)
async def _async_render(self, request: Request): async def _async_render(self, request: SynapseRequest) -> Optional[Tuple[int, Any]]:
"""Delegates to `_async_render_<METHOD>` methods, or returns a 400 if """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if
no appropriate method exists. Can be overridden in sub classes for no appropriate method exists. Can be overridden in sub classes for
different routing. different routing.
@ -318,7 +323,7 @@ class DirectServeJsonResource(_AsyncResource):
formatting responses and errors as JSON. formatting responses and errors as JSON.
""" """
def __init__(self, canonical_json=False, extract_context=False): def __init__(self, canonical_json: bool = False, extract_context: bool = False):
super().__init__(extract_context) super().__init__(extract_context)
self.canonical_json = canonical_json self.canonical_json = canonical_json
@ -327,7 +332,7 @@ class DirectServeJsonResource(_AsyncResource):
request: SynapseRequest, request: SynapseRequest,
code: int, code: int,
response_object: Any, response_object: Any,
): ) -> None:
"""Implements _AsyncResource._send_response""" """Implements _AsyncResource._send_response"""
# TODO: Only enable CORS for the requests that need it. # TODO: Only enable CORS for the requests that need it.
respond_with_json( respond_with_json(
@ -368,34 +373,45 @@ class JsonResource(DirectServeJsonResource):
isLeaf = True isLeaf = True
def __init__(self, hs: "HomeServer", canonical_json=True, extract_context=False): def __init__(
self,
hs: "HomeServer",
canonical_json: bool = True,
extract_context: bool = False,
):
super().__init__(canonical_json, extract_context) super().__init__(canonical_json, extract_context)
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.path_regexs: Dict[bytes, List[_PathEntry]] = {} self.path_regexs: Dict[bytes, List[_PathEntry]] = {}
self.hs = hs self.hs = hs
def register_paths(self, method, path_patterns, callback, servlet_classname): def register_paths(
self,
method: str,
path_patterns: Iterable[Pattern],
callback: ServletCallback,
servlet_classname: str,
) -> None:
""" """
Registers a request handler against a regular expression. Later request URLs are Registers a request handler against a regular expression. Later request URLs are
checked against these regular expressions in order to identify an appropriate checked against these regular expressions in order to identify an appropriate
handler for that request. handler for that request.
Args: Args:
method (str): GET, POST etc method: GET, POST etc
path_patterns (Iterable[str]): A list of regular expressions to which path_patterns: A list of regular expressions to which the request
the request URLs are compared. URLs are compared.
callback (function): The handler for the request. Usually a Servlet callback: The handler for the request. Usually a Servlet
servlet_classname (str): The name of the handler to be used in prometheus servlet_classname: The name of the handler to be used in prometheus
and opentracing logs. and opentracing logs.
""" """
method = method.encode("utf-8") # method is bytes on py3 method_bytes = method.encode("utf-8")
for path_pattern in path_patterns: for path_pattern in path_patterns:
logger.debug("Registering for %s %s", method, path_pattern.pattern) logger.debug("Registering for %s %s", method, path_pattern.pattern)
self.path_regexs.setdefault(method, []).append( self.path_regexs.setdefault(method_bytes, []).append(
_PathEntry(path_pattern, callback, servlet_classname) _PathEntry(path_pattern, callback, servlet_classname)
) )
@ -427,7 +443,7 @@ class JsonResource(DirectServeJsonResource):
# Huh. No one wanted to handle that? Fiiiiiine. Send 400. # Huh. No one wanted to handle that? Fiiiiiine. Send 400.
return _unrecognised_request_handler, "unrecognised_request_handler", {} return _unrecognised_request_handler, "unrecognised_request_handler", {}
async def _async_render(self, request): async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request) callback, servlet_classname, group_dict = self._get_handler_for_request(request)
# Make sure we have an appropriate name for this handler in prometheus # Make sure we have an appropriate name for this handler in prometheus
@ -468,7 +484,7 @@ class DirectServeHtmlResource(_AsyncResource):
request: SynapseRequest, request: SynapseRequest,
code: int, code: int,
response_object: Any, response_object: Any,
): ) -> None:
"""Implements _AsyncResource._send_response""" """Implements _AsyncResource._send_response"""
# We expect to get bytes for us to write # We expect to get bytes for us to write
assert isinstance(response_object, bytes) assert isinstance(response_object, bytes)
@ -492,12 +508,12 @@ class StaticResource(File):
Differs from the File resource by adding clickjacking protection. Differs from the File resource by adding clickjacking protection.
""" """
def render_GET(self, request: Request): def render_GET(self, request: Request) -> bytes:
set_clickjacking_protection_headers(request) set_clickjacking_protection_headers(request)
return super().render_GET(request) return super().render_GET(request)
def _unrecognised_request_handler(request): def _unrecognised_request_handler(request: Request) -> NoReturn:
"""Request handler for unrecognised requests """Request handler for unrecognised requests
This is a request handler suitable for return from This is a request handler suitable for return from
@ -505,7 +521,7 @@ def _unrecognised_request_handler(request):
UnrecognizedRequestError. UnrecognizedRequestError.
Args: Args:
request (twisted.web.http.Request): request: Unused, but passed in to match the signature of ServletCallback.
""" """
raise UnrecognizedRequestError() raise UnrecognizedRequestError()
@ -513,14 +529,14 @@ def _unrecognised_request_handler(request):
class RootRedirect(resource.Resource): class RootRedirect(resource.Resource):
"""Redirects the root '/' path to another path.""" """Redirects the root '/' path to another path."""
def __init__(self, path): def __init__(self, path: str):
resource.Resource.__init__(self) resource.Resource.__init__(self)
self.url = path self.url = path
def render_GET(self, request): def render_GET(self, request: Request) -> bytes:
return redirectTo(self.url.encode("ascii"), request) return redirectTo(self.url.encode("ascii"), request)
def getChild(self, name, request): def getChild(self, name: str, request: Request) -> resource.Resource:
if len(name) == 0: if len(name) == 0:
return self # select ourselves as the child to render return self # select ourselves as the child to render
return resource.Resource.getChild(self, name, request) return resource.Resource.getChild(self, name, request)
@ -529,7 +545,7 @@ class RootRedirect(resource.Resource):
class OptionsResource(resource.Resource): class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children.""" """Responds to OPTION requests for itself and all children."""
def render_OPTIONS(self, request): def render_OPTIONS(self, request: Request) -> bytes:
request.setResponseCode(204) request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0") request.setHeader(b"Content-Length", b"0")
@ -537,7 +553,7 @@ class OptionsResource(resource.Resource):
return b"" return b""
def getChildWithDefault(self, path, request): def getChildWithDefault(self, path: str, request: Request) -> resource.Resource:
if request.method == b"OPTIONS": if request.method == b"OPTIONS":
return self # select ourselves as the child to render return self # select ourselves as the child to render
return resource.Resource.getChildWithDefault(self, path, request) return resource.Resource.getChildWithDefault(self, path, request)
@ -649,7 +665,7 @@ def respond_with_json(
json_object: Any, json_object: Any,
send_cors: bool = False, send_cors: bool = False,
canonical_json: bool = True, canonical_json: bool = True,
): ) -> Optional[int]:
"""Sends encoded JSON in response to the given request. """Sends encoded JSON in response to the given request.
Args: Args:
@ -696,7 +712,7 @@ def respond_with_json_bytes(
code: int, code: int,
json_bytes: bytes, json_bytes: bytes,
send_cors: bool = False, send_cors: bool = False,
): ) -> Optional[int]:
"""Sends encoded JSON in response to the given request. """Sends encoded JSON in response to the given request.
Args: Args:
@ -713,7 +729,7 @@ def respond_with_json_bytes(
logger.warning( logger.warning(
"Not sending response to request %s, already disconnected.", request "Not sending response to request %s, already disconnected.", request
) )
return return None
request.setResponseCode(code) request.setResponseCode(code)
request.setHeader(b"Content-Type", b"application/json") request.setHeader(b"Content-Type", b"application/json")
@ -731,7 +747,7 @@ async def _async_write_json_to_request_in_thread(
request: SynapseRequest, request: SynapseRequest,
json_encoder: Callable[[Any], bytes], json_encoder: Callable[[Any], bytes],
json_object: Any, json_object: Any,
): ) -> None:
"""Encodes the given JSON object on a thread and then writes it to the """Encodes the given JSON object on a thread and then writes it to the
request. request.
@ -773,7 +789,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator) _ByteProducer(request, bytes_generator)
def set_cors_headers(request: Request): def set_cors_headers(request: Request) -> None:
"""Set the CORS headers so that javascript running in a web browsers can """Set the CORS headers so that javascript running in a web browsers can
use this API use this API
@ -790,14 +806,14 @@ def set_cors_headers(request: Request):
) )
def respond_with_html(request: Request, code: int, html: str): def respond_with_html(request: Request, code: int, html: str) -> None:
""" """
Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes. Wraps `respond_with_html_bytes` by first encoding HTML from a str to UTF-8 bytes.
""" """
respond_with_html_bytes(request, code, html.encode("utf-8")) respond_with_html_bytes(request, code, html.encode("utf-8"))
def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes): def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes) -> None:
""" """
Sends HTML (encoded as UTF-8 bytes) as the response to the given request. Sends HTML (encoded as UTF-8 bytes) as the response to the given request.
@ -815,7 +831,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
logger.warning( logger.warning(
"Not sending response to request %s, already disconnected.", request "Not sending response to request %s, already disconnected.", request
) )
return return None
request.setResponseCode(code) request.setResponseCode(code)
request.setHeader(b"Content-Type", b"text/html; charset=utf-8") request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
@ -828,7 +844,7 @@ def respond_with_html_bytes(request: Request, code: int, html_bytes: bytes):
finish_request(request) finish_request(request)
def set_clickjacking_protection_headers(request: Request): def set_clickjacking_protection_headers(request: Request) -> None:
""" """
Set headers to guard against clickjacking of embedded content. Set headers to guard against clickjacking of embedded content.
@ -850,7 +866,7 @@ def respond_with_redirect(request: Request, url: bytes) -> None:
finish_request(request) finish_request(request)
def finish_request(request: Request): def finish_request(request: Request) -> None:
"""Finish writing the response to the request. """Finish writing the response to the request.
Twisted throws a RuntimeException if the connection closed before the Twisted throws a RuntimeException if the connection closed before the

View file

@ -31,6 +31,7 @@ from typing_extensions import Literal
from twisted.web.server import Request from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.util import json_decoder from synapse.util import json_decoder
@ -726,7 +727,7 @@ class RestServlet:
into the appropriate HTTP response. into the appropriate HTTP response.
""" """
def register(self, http_server): def register(self, http_server: HttpServer) -> None:
"""Register this servlet with the given HTTP server.""" """Register this servlet with the given HTTP server."""
patterns = getattr(self, "PATTERNS", None) patterns = getattr(self, "PATTERNS", None)
if patterns: if patterns:

View file

@ -14,7 +14,7 @@
import contextlib import contextlib
import logging import logging
import time import time
from typing import Generator, Optional, Tuple, Union from typing import Any, Generator, Optional, Tuple, Union
import attr import attr
from zope.interface import implementer from zope.interface import implementer
@ -66,9 +66,9 @@ class SynapseRequest(Request):
self, self,
channel: HTTPChannel, channel: HTTPChannel,
site: "SynapseSite", site: "SynapseSite",
*args, *args: Any,
max_request_body_size: int = 1024, max_request_body_size: int = 1024,
**kw, **kw: Any,
): ):
super().__init__(channel, *args, **kw) super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size self._max_request_body_size = max_request_body_size
@ -557,7 +557,7 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest request_class = XForwardedForRequest if proxied else SynapseRequest
def request_factory(channel, queued: bool) -> Request: def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class( return request_class(
channel, channel,
self, self,

View file

@ -13,7 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Optional
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from signedjson.sign import sign_json from signedjson.sign import sign_json
@ -99,7 +99,7 @@ class LocalKey(Resource):
json_object = sign_json(json_object, self.config.server.server_name, key) json_object = sign_json(json_object, self.config.server.server_name, key)
return json_object return json_object
def render_GET(self, request: Request) -> int: def render_GET(self, request: Request) -> Optional[int]:
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# Update the expiry time if less than half the interval remains. # Update the expiry time if less than half the interval remains.
if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: if time_now + self.config.key.key_refresh_interval / 2 > self.valid_until_ts: