diff --git a/changelog.d/10437.misc b/changelog.d/10437.misc new file mode 100644 index 0000000000..a557578499 --- /dev/null +++ b/changelog.d/10437.misc @@ -0,0 +1 @@ +Improve servlet type hints. diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 2974d4d0cc..5e059d6e09 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -984,7 +984,7 @@ class PublicRoomList(BaseFederationServlet): limit = parse_integer_from_args(query, "limit", 0) since_token = parse_string_from_args(query, "since", None) include_all_networks = parse_boolean_from_args( - query, "include_all_networks", False + query, "include_all_networks", default=False ) third_party_instance_id = parse_string_from_args( query, "third_party_instance_id", None @@ -1908,16 +1908,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): suggested_only = parse_boolean_from_args(query, "suggested_only", default=False) max_rooms_per_space = parse_integer_from_args(query, "max_rooms_per_space") - exclude_rooms = [] - if b"exclude_rooms" in query: - try: - exclude_rooms = [ - room_id.decode("ascii") for room_id in query[b"exclude_rooms"] - ] - except Exception: - raise SynapseError( - 400, "Bad query parameter for exclude_rooms", Codes.INVALID_PARAM - ) + exclude_rooms = parse_strings_from_args(query, "exclude_rooms", default=[]) return 200, await self.handler.federation_space_summary( origin, room_id, suggested_only, max_rooms_per_space, exclude_rooms diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index cf45b6623b..732a1e6aeb 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -14,47 +14,86 @@ """ This module contains base REST classes for constructing REST servlets. """ import logging -from typing import Dict, Iterable, List, Optional, overload +from typing import Iterable, List, Mapping, Optional, Sequence, overload from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError +from synapse.types import JsonDict from synapse.util import json_decoder logger = logging.getLogger(__name__) -def parse_integer(request, name, default=None, required=False): +@overload +def parse_integer(request: Request, name: str, default: int) -> int: + ... + + +@overload +def parse_integer(request: Request, name: str, *, required: Literal[True]) -> int: + ... + + +@overload +def parse_integer( + request: Request, name: str, default: Optional[int] = None, required: bool = False +) -> Optional[int]: + ... + + +def parse_integer( + request: Request, name: str, default: Optional[int] = None, required: bool = False +) -> Optional[int]: """Parse an integer parameter from the request string Args: request: the twisted HTTP request. - name (bytes/unicode): the name of the query parameter. - default (int|None): value to use if the parameter is absent, defaults - to None. - required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. Returns: - int|None: An int value or the default. + An int value or the default. Raises: SynapseError: if the parameter is absent and required, or if the parameter is present and not an integer. """ - return parse_integer_from_args(request.args, name, default, required) + args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore + return parse_integer_from_args(args, name, default, required) -def parse_integer_from_args(args, name, default=None, required=False): +def parse_integer_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[int] = None, + required: bool = False, +) -> Optional[int]: + """Parse an integer parameter from the request string - if not isinstance(name, bytes): - name = name.encode("ascii") + Args: + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. - if name in args: + Returns: + An int value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the + parameter is present and not an integer. + """ + name_bytes = name.encode("ascii") + + if name_bytes in args: try: - return int(args[name][0]) + return int(args[name_bytes][0]) except Exception: message = "Query parameter %r must be an integer" % (name,) raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) @@ -66,36 +105,102 @@ def parse_integer_from_args(args, name, default=None, required=False): return default -def parse_boolean(request, name, default=None, required=False): +@overload +def parse_boolean(request: Request, name: str, default: bool) -> bool: + ... + + +@overload +def parse_boolean(request: Request, name: str, *, required: Literal[True]) -> bool: + ... + + +@overload +def parse_boolean( + request: Request, name: str, default: Optional[bool] = None, required: bool = False +) -> Optional[bool]: + ... + + +def parse_boolean( + request: Request, name: str, default: Optional[bool] = None, required: bool = False +) -> Optional[bool]: """Parse a boolean parameter from the request query string Args: request: the twisted HTTP request. - name (bytes/unicode): the name of the query parameter. - default (bool|None): value to use if the parameter is absent, defaults - to None. - required (bool): whether to raise a 400 SynapseError if the - parameter is absent, defaults to False. + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. Returns: - bool|None: A bool value or the default. + A bool value or the default. Raises: SynapseError: if the parameter is absent and required, or if the parameter is present and not one of "true" or "false". """ - - return parse_boolean_from_args(request.args, name, default, required) + args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore + return parse_boolean_from_args(args, name, default, required) -def parse_boolean_from_args(args, name, default=None, required=False): +@overload +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: bool, +) -> bool: + ... - if not isinstance(name, bytes): - name = name.encode("ascii") - if name in args: +@overload +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + *, + required: Literal[True], +) -> bool: + ... + + +@overload +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[bool] = None, + required: bool = False, +) -> Optional[bool]: + ... + + +def parse_boolean_from_args( + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: Optional[bool] = None, + required: bool = False, +) -> Optional[bool]: + """Parse a boolean parameter from the request query string + + Args: + args: A mapping of request args as bytes to a list of bytes (e.g. request.args). + name: the name of the query parameter. + default: value to use if the parameter is absent, defaults to None. + required: whether to raise a 400 SynapseError if the parameter is absent, + defaults to False. + + Returns: + A bool value or the default. + + Raises: + SynapseError: if the parameter is absent and required, or if the + parameter is present and not one of "true" or "false". + """ + name_bytes = name.encode("ascii") + + if name_bytes in args: try: - return {b"true": True, b"false": False}[args[name][0]] + return {b"true": True, b"false": False}[args[name_bytes][0]] except Exception: message = ( "Boolean query parameter %r must be one of ['true', 'false']" @@ -111,7 +216,7 @@ def parse_boolean_from_args(args, name, default=None, required=False): @overload def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[bytes] = None, ) -> Optional[bytes]: @@ -120,7 +225,7 @@ def parse_bytes_from_args( @overload def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Literal[None] = None, *, @@ -131,7 +236,7 @@ def parse_bytes_from_args( @overload def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[bytes] = None, required: bool = False, @@ -140,7 +245,7 @@ def parse_bytes_from_args( def parse_bytes_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[bytes] = None, required: bool = False, @@ -241,7 +346,7 @@ def parse_string( parameter is present, must be one of a list of allowed values and is not one of those allowed values. """ - args: Dict[bytes, List[bytes]] = request.args # type: ignore + args: Mapping[bytes, Sequence[bytes]] = request.args # type: ignore return parse_string_from_args( args, name, @@ -275,9 +380,8 @@ def _parse_string_value( @overload def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[List[str]] = None, *, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", @@ -287,9 +391,20 @@ def parse_strings_from_args( @overload def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], + name: str, + default: List[str], + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> List[str]: + ... + + +@overload +def parse_strings_from_args( + args: Mapping[bytes, Sequence[bytes]], name: str, - default: Optional[List[str]] = None, *, required: Literal[True], allowed_values: Optional[Iterable[str]] = None, @@ -300,7 +415,7 @@ def parse_strings_from_args( @overload def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[List[str]] = None, *, @@ -312,7 +427,7 @@ def parse_strings_from_args( def parse_strings_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[List[str]] = None, required: bool = False, @@ -361,7 +476,7 @@ def parse_strings_from_args( @overload def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, *, @@ -373,7 +488,7 @@ def parse_string_from_args( @overload def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, *, @@ -386,7 +501,7 @@ def parse_string_from_args( @overload def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, required: bool = False, @@ -397,7 +512,7 @@ def parse_string_from_args( def parse_string_from_args( - args: Dict[bytes, List[bytes]], + args: Mapping[bytes, Sequence[bytes]], name: str, default: Optional[str] = None, required: bool = False, @@ -445,13 +560,14 @@ def parse_string_from_args( return strings[0] -def parse_json_value_from_request(request, allow_empty_body=False): +def parse_json_value_from_request( + request: Request, allow_empty_body: bool = False +) -> Optional[JsonDict]: """Parse a JSON value from the body of a twisted HTTP request. Args: request: the twisted HTTP request. - allow_empty_body (bool): if True, an empty body will be accepted and - turned into None + allow_empty_body: if True, an empty body will be accepted and turned into None Returns: The JSON value. @@ -460,7 +576,7 @@ def parse_json_value_from_request(request, allow_empty_body=False): SynapseError if the request body couldn't be decoded as JSON. """ try: - content_bytes = request.content.read() + content_bytes = request.content.read() # type: ignore except Exception: raise SynapseError(400, "Error reading JSON content.") @@ -476,13 +592,15 @@ def parse_json_value_from_request(request, allow_empty_body=False): return content -def parse_json_object_from_request(request, allow_empty_body=False): +def parse_json_object_from_request( + request: Request, allow_empty_body: bool = False +) -> JsonDict: """Parse a JSON object from the body of a twisted HTTP request. Args: request: the twisted HTTP request. - allow_empty_body (bool): if True, an empty body will be accepted and - turned into an empty dict. + allow_empty_body: if True, an empty body will be accepted and turned into + an empty dict. Raises: SynapseError if the request body couldn't be decoded as JSON or @@ -493,14 +611,14 @@ def parse_json_object_from_request(request, allow_empty_body=False): if allow_empty_body and content is None: return {} - if type(content) != dict: + if not isinstance(content, dict): message = "Content must be a JSON object." raise SynapseError(400, message, errcode=Codes.BAD_JSON) return content -def assert_params_in_dict(body, required): +def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: absent = [] for k in required: if k not in body: diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 5d309a534c..25ba52c624 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -754,7 +754,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): if server: raise e - limit = parse_integer(request, "limit", 0) + limit: Optional[int] = parse_integer(request, "limit", 0) since_token = parse_string(request, "since") if limit == 0: diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py index 0f9aa54ca9..889e0d3625 100644 --- a/synapse/storage/databases/main/stats.py +++ b/synapse/storage/databases/main/stats.py @@ -650,7 +650,7 @@ class StatsStore(StateDeltasStore): order_by: Optional[str] = UserSortOrder.USER_ID.value, direction: Optional[str] = "f", search_term: Optional[str] = None, - ) -> Tuple[List[JsonDict], Dict[str, int]]: + ) -> Tuple[List[JsonDict], int]: """Function to retrieve a paginated list of users and their uploaded local media (size and number). This will return a json list of users and the total number of users matching the filter criteria. diff --git a/tests/rest/admin/test_media.py b/tests/rest/admin/test_media.py index 6fee0f95b6..7198fd293f 100644 --- a/tests/rest/admin/test_media.py +++ b/tests/rest/admin/test_media.py @@ -261,7 +261,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.MISSING_PARAM, channel.json_body["errcode"]) self.assertEqual( - "Missing integer query parameter b'before_ts'", channel.json_body["error"] + "Missing integer query parameter 'before_ts'", channel.json_body["error"] ) def test_invalid_parameter(self): @@ -303,7 +303,7 @@ class DeleteMediaByDateSizeTestCase(unittest.HomeserverTestCase): self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"]) self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"]) self.assertEqual( - "Boolean query parameter b'keep_profiles' must be one of ['true', 'false']", + "Boolean query parameter 'keep_profiles' must be one of ['true', 'false']", channel.json_body["error"], )