Add helper to parse an enum from query args & use it. (#14956)

The `parse_enum` helper pulls an enum value from the query string
(by delegating down to the parse_string helper with values generated
from the enum).

This is used to pull out "f" and "b" in most places and then we thread
the resulting Direction enum throughout more code.
This commit is contained in:
Patrick Cloke 2023-02-01 16:35:24 -05:00 committed by GitHub
parent 230a831c73
commit 1182ae5063
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 176 additions and 96 deletions

1
changelog.d/14956.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints.

View file

@ -37,7 +37,7 @@ from typing import (
import attr import attr
from prometheus_client import Counter from prometheus_client import Counter
from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
CodeMessageException, CodeMessageException,
Codes, Codes,
@ -1680,7 +1680,12 @@ class FederationClient(FederationBase):
return result return result
async def timestamp_to_event( async def timestamp_to_event(
self, *, destinations: List[str], room_id: str, timestamp: int, direction: str self,
*,
destinations: List[str],
room_id: str,
timestamp: int,
direction: Direction,
) -> Optional["TimestampToEventResponse"]: ) -> Optional["TimestampToEventResponse"]:
""" """
Calls each remote federating server from `destinations` asking for their closest Calls each remote federating server from `destinations` asking for their closest
@ -1693,7 +1698,7 @@ class FederationClient(FederationBase):
room_id: Room to fetch the event from room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event. the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event. or backward from the given timestamp to find the closest event.
Returns: Returns:
@ -1738,7 +1743,7 @@ class FederationClient(FederationBase):
return None return None
async def _timestamp_to_event_from_destination( async def _timestamp_to_event_from_destination(
self, destination: str, room_id: str, timestamp: int, direction: str self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> "TimestampToEventResponse": ) -> "TimestampToEventResponse":
""" """
Calls a remote federating server at `destination` asking for their Calls a remote federating server at `destination` asking for their
@ -1751,7 +1756,7 @@ class FederationClient(FederationBase):
room_id: Room to fetch the event from room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event. the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event. or backward from the given timestamp to find the closest event.
Returns: Returns:

View file

@ -34,7 +34,13 @@ from prometheus_client import Counter, Gauge, Histogram
from twisted.internet.abstract import isIPAddress from twisted.internet.abstract import isIPAddress
from twisted.python import failure from twisted.python import failure
from synapse.api.constants import EduTypes, EventContentFields, EventTypes, Membership from synapse.api.constants import (
Direction,
EduTypes,
EventContentFields,
EventTypes,
Membership,
)
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -218,7 +224,7 @@ class FederationServer(FederationBase):
return 200, res return 200, res
async def on_timestamp_to_event_request( async def on_timestamp_to_event_request(
self, origin: str, room_id: str, timestamp: int, direction: str self, origin: str, room_id: str, timestamp: int, direction: Direction
) -> Tuple[int, Dict[str, Any]]: ) -> Tuple[int, Dict[str, Any]]:
"""When we receive a federated `/timestamp_to_event` request, """When we receive a federated `/timestamp_to_event` request,
handle all of the logic for validating and fetching the event. handle all of the logic for validating and fetching the event.
@ -228,7 +234,7 @@ class FederationServer(FederationBase):
room_id: Room to fetch the event from room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event. the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event. or backward from the given timestamp to find the closest event.
Returns: Returns:

View file

@ -32,7 +32,7 @@ from typing import (
import attr import attr
import ijson import ijson
from synapse.api.constants import Membership from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.room_versions import RoomVersion from synapse.api.room_versions import RoomVersion
from synapse.api.urls import ( from synapse.api.urls import (
@ -169,7 +169,7 @@ class TransportLayerClient:
) )
async def timestamp_to_event( async def timestamp_to_event(
self, destination: str, room_id: str, timestamp: int, direction: str self, destination: str, room_id: str, timestamp: int, direction: Direction
) -> Union[JsonDict, List]: ) -> Union[JsonDict, List]:
""" """
Calls a remote federating server at `destination` asking for their Calls a remote federating server at `destination` asking for their
@ -180,7 +180,7 @@ class TransportLayerClient:
room_id: Room to fetch the event from room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event. the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event. or backward from the given timestamp to find the closest event.
Returns: Returns:
@ -194,7 +194,7 @@ class TransportLayerClient:
room_id, room_id,
) )
args = {"ts": [str(timestamp)], "dir": [direction]} args = {"ts": [str(timestamp)], "dir": [direction.value]}
remote_response = await self.client.get_json( remote_response = await self.client.get_json(
destination, path=path, args=args, try_trailing_slash_on_400=True destination, path=path, args=args, try_trailing_slash_on_400=True

View file

@ -26,7 +26,7 @@ from typing import (
from typing_extensions import Literal from typing_extensions import Literal
from synapse.api.constants import EduTypes from synapse.api.constants import Direction, EduTypes
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX from synapse.api.urls import FEDERATION_UNSTABLE_PREFIX, FEDERATION_V2_PREFIX
@ -234,9 +234,10 @@ class FederationTimestampLookupServlet(BaseFederationServerServlet):
room_id: str, room_id: str,
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
timestamp = parse_integer_from_args(query, "ts", required=True) timestamp = parse_integer_from_args(query, "ts", required=True)
direction = parse_string_from_args( direction_str = parse_string_from_args(
query, "dir", default="f", allowed_values=["f", "b"], required=True query, "dir", allowed_values=["f", "b"], required=True
) )
direction = Direction(direction_str)
return await self.handler.on_timestamp_to_event_request( return await self.handler.on_timestamp_to_event_request(
origin, room_id, timestamp, direction origin, room_id, timestamp, direction

View file

@ -314,7 +314,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
def get_current_key(self, direction: str = "f") -> int: def get_current_key(self) -> int:
return self.store.get_max_account_data_stream_id() return self.store.get_max_account_data_stream_id()
async def get_new_events( async def get_new_events(

View file

@ -315,5 +315,5 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
return events, to_key return events, to_key
def get_current_key(self, direction: str = "f") -> int: def get_current_key(self) -> int:
return self.store.get_max_receipt_stream_id() return self.store.get_max_receipt_stream_id()

View file

@ -27,6 +27,7 @@ from typing_extensions import TypedDict
import synapse.events.snapshot import synapse.events.snapshot
from synapse.api.constants import ( from synapse.api.constants import (
Direction,
EventContentFields, EventContentFields,
EventTypes, EventTypes,
GuestAccess, GuestAccess,
@ -1487,7 +1488,7 @@ class TimestampLookupHandler:
requester: Requester, requester: Requester,
room_id: str, room_id: str,
timestamp: int, timestamp: int,
direction: str, direction: Direction,
) -> Tuple[str, int]: ) -> Tuple[str, int]:
"""Find the closest event to the given timestamp in the given direction. """Find the closest event to the given timestamp in the given direction.
If we can't find an event locally or the event we have locally is next to a gap, If we can't find an event locally or the event we have locally is next to a gap,
@ -1498,7 +1499,7 @@ class TimestampLookupHandler:
room_id: Room to fetch the event from room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event. the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event. or backward from the given timestamp to find the closest event.
Returns: Returns:
@ -1533,13 +1534,13 @@ class TimestampLookupHandler:
local_event_id, allow_none=False, allow_rejected=False local_event_id, allow_none=False, allow_rejected=False
) )
if direction == "f": if direction == Direction.FORWARDS:
# We only need to check for a backward gap if we're looking forwards # We only need to check for a backward gap if we're looking forwards
# to ensure there is nothing in between. # to ensure there is nothing in between.
is_event_next_to_backward_gap = ( is_event_next_to_backward_gap = (
await self.store.is_event_next_to_backward_gap(local_event) await self.store.is_event_next_to_backward_gap(local_event)
) )
elif direction == "b": elif direction == Direction.BACKWARDS:
# We only need to check for a forward gap if we're looking backwards # We only need to check for a forward gap if we're looking backwards
# to ensure there is nothing in between # to ensure there is nothing in between
is_event_next_to_forward_gap = ( is_event_next_to_forward_gap = (

View file

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
""" This module contains base REST classes for constructing REST servlets. """ """ This module contains base REST classes for constructing REST servlets. """
import enum
import logging import logging
from http import HTTPStatus from http import HTTPStatus
from typing import ( from typing import (
@ -362,6 +363,7 @@ def parse_string(
request: Request, request: Request,
name: str, name: str,
*, *,
default: Optional[str] = None,
required: bool = False, required: bool = False,
allowed_values: Optional[Iterable[str]] = None, allowed_values: Optional[Iterable[str]] = None,
encoding: str = "ascii", encoding: str = "ascii",
@ -413,6 +415,74 @@ def parse_string(
) )
EnumT = TypeVar("EnumT", bound=enum.Enum)
@overload
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
default: EnumT,
) -> EnumT:
...
@overload
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
*,
required: Literal[True],
) -> EnumT:
...
def parse_enum(
request: Request,
name: str,
E: Type[EnumT],
default: Optional[EnumT] = None,
required: bool = False,
) -> Optional[EnumT]:
"""
Parse an enum parameter from the request query string.
Note that the enum *must only have string values*.
Args:
request: the twisted HTTP request.
name: the name of the query parameter.
E: the enum which represents valid values
default: enum value to use if the parameter is absent, defaults to None.
required: whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
Returns:
An enum value.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
"""
# Assert the enum values are strings.
assert all(
isinstance(e.value, str) for e in E
), "parse_enum only works with string values"
str_value = parse_string(
request,
name,
default=default.value if default is not None else None,
required=required,
allowed_values=[e.value for e in E],
)
if str_value is None:
return None
return E(str_value)
def _parse_string_value( def _parse_string_value(
value: bytes, value: bytes,
allowed_values: Optional[Iterable[str]], allowed_values: Optional[Iterable[str]],

View file

@ -16,8 +16,9 @@ import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict from synapse.types import JsonDict
@ -60,7 +61,7 @@ class EventReportsRestServlet(RestServlet):
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0)
limit = parse_integer(request, "limit", default=100) limit = parse_integer(request, "limit", default=100)
direction = parse_string(request, "dir", default="b") direction = parse_enum(request, "dir", Direction, Direction.BACKWARDS)
user_id = parse_string(request, "user_id") user_id = parse_string(request, "user_id")
room_id = parse_string(request, "room_id") room_id = parse_string(request, "room_id")
@ -78,13 +79,6 @@ class EventReportsRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
if direction not in ("f", "b"):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown direction: %s" % (direction,),
errcode=Codes.INVALID_PARAM,
)
event_reports, total = await self.store.get_event_reports_paginate( event_reports, total = await self.store.get_event_reports_paginate(
start, limit, direction, user_id, room_id start, limit, direction, user_id, room_id
) )

View file

@ -15,9 +15,10 @@ import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.federation.transport.server import Authenticator from synapse.federation.transport.server import Authenticator
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.storage.databases.main.transactions import DestinationSortOrder from synapse.storage.databases.main.transactions import DestinationSortOrder
@ -79,7 +80,7 @@ class ListDestinationsRestServlet(RestServlet):
allowed_values=[dest.value for dest in DestinationSortOrder], allowed_values=[dest.value for dest in DestinationSortOrder],
) )
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
destinations, total = await self._store.get_destinations_paginate( destinations, total = await self._store.get_destinations_paginate(
start, limit, destination, order_by, direction start, limit, destination, order_by, direction
@ -192,7 +193,7 @@ class DestinationMembershipRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
rooms, total = await self._store.get_destination_rooms_paginate( rooms, total = await self._store.get_destination_rooms_paginate(
destination, start, limit, direction destination, start, limit, direction

View file

@ -17,9 +17,16 @@ import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import Direction
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string from synapse.http.servlet import (
RestServlet,
parse_boolean,
parse_enum,
parse_integer,
parse_string,
)
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import ( from synapse.rest.admin._base import (
admin_patterns, admin_patterns,
@ -389,7 +396,7 @@ class UserMediaRestServlet(RestServlet):
# to newest media is on top for backward compatibility. # to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args: if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value order_by = MediaSortOrder.CREATED_TS.value
direction = "b" direction = Direction.BACKWARDS
else: else:
order_by = parse_string( order_by = parse_string(
request, request,
@ -397,8 +404,8 @@ class UserMediaRestServlet(RestServlet):
default=MediaSortOrder.CREATED_TS.value, default=MediaSortOrder.CREATED_TS.value,
allowed_values=[sort_order.value for sort_order in MediaSortOrder], allowed_values=[sort_order.value for sort_order in MediaSortOrder],
) )
direction = parse_string( direction = parse_enum(
request, "dir", default="f", allowed_values=("f", "b") request, "dir", Direction, default=Direction.FORWARDS
) )
media, total = await self.store.get_local_media_by_user_paginate( media, total = await self.store.get_local_media_by_user_paginate(
@ -447,7 +454,7 @@ class UserMediaRestServlet(RestServlet):
# to newest media is on top for backward compatibility. # to newest media is on top for backward compatibility.
if b"order_by" not in request.args and b"dir" not in request.args: if b"order_by" not in request.args and b"dir" not in request.args:
order_by = MediaSortOrder.CREATED_TS.value order_by = MediaSortOrder.CREATED_TS.value
direction = "b" direction = Direction.BACKWARDS
else: else:
order_by = parse_string( order_by = parse_string(
request, request,
@ -455,8 +462,8 @@ class UserMediaRestServlet(RestServlet):
default=MediaSortOrder.CREATED_TS.value, default=MediaSortOrder.CREATED_TS.value,
allowed_values=[sort_order.value for sort_order in MediaSortOrder], allowed_values=[sort_order.value for sort_order in MediaSortOrder],
) )
direction = parse_string( direction = parse_enum(
request, "dir", default="f", allowed_values=("f", "b") request, "dir", Direction, default=Direction.FORWARDS
) )
media, _ = await self.store.get_local_media_by_user_paginate( media, _ = await self.store.get_local_media_by_user_paginate(

View file

@ -16,13 +16,14 @@ from http import HTTPStatus
from typing import TYPE_CHECKING, List, Optional, Tuple, cast from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from urllib import parse as urlparse from urllib import parse as urlparse
from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.constants import Direction, EventTypes, JoinRules, Membership
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.api.filtering import Filter from synapse.api.filtering import Filter
from synapse.http.servlet import ( from synapse.http.servlet import (
ResolveRoomIdMixin, ResolveRoomIdMixin,
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_enum,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
@ -224,15 +225,8 @@ class ListRoomRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
direction = parse_string(request, "dir", default="f") direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
if direction not in ("f", "b"): reverse_order = True if direction == Direction.BACKWARDS else False
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown direction: %s" % (direction,),
errcode=Codes.INVALID_PARAM,
)
reverse_order = True if direction == "b" else False
# Return list of rooms according to parameters # Return list of rooms according to parameters
rooms, total_rooms = await self.store.get_rooms_paginate( rooms, total_rooms = await self.store.get_rooms_paginate(
@ -949,7 +943,7 @@ class RoomTimestampToEventRestServlet(RestServlet):
await assert_user_is_admin(self._auth, requester) await assert_user_is_admin(self._auth, requester)
timestamp = parse_integer(request, "ts", required=True) timestamp = parse_integer(request, "ts", required=True)
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
( (
event_id, event_id,

View file

@ -16,8 +16,9 @@ import logging
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.constants import Direction
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.storage.databases.main.stats import UserSortOrder from synapse.storage.databases.main.stats import UserSortOrder
@ -102,13 +103,7 @@ class UserMediaStatisticsRestServlet(RestServlet):
errcode=Codes.INVALID_PARAM, errcode=Codes.INVALID_PARAM,
) )
direction = parse_string(request, "dir", default="f") direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
if direction not in ("f", "b"):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Unknown direction: %s" % (direction,),
errcode=Codes.INVALID_PARAM,
)
users_media, total = await self.store.get_users_media_usage_paginate( users_media, total = await self.store.get_users_media_usage_paginate(
start, limit, from_ts, until_ts, order_by, direction, search_term start, limit, from_ts, until_ts, order_by, direction, search_term

View file

@ -18,12 +18,13 @@ import secrets
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes from synapse.api.constants import Direction, UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_boolean, parse_boolean,
parse_enum,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
@ -120,7 +121,7 @@ class UsersRestServletV2(RestServlet):
), ),
) )
direction = parse_string(request, "dir", default="f", allowed_values=("f", "b")) direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
users, total = await self.store.get_users_paginate( users, total = await self.store.get_users_paginate(
start, start,

View file

@ -16,6 +16,7 @@ import logging
import re import re
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import Direction
from synapse.handlers.relations import ThreadsListInclude from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
@ -59,7 +60,7 @@ class RelationPaginationServlet(RestServlet):
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = await PaginationConfig.from_request( pagination_config = await PaginationConfig.from_request(
self._store, request, default_limit=5, default_dir="b" self._store, request, default_limit=5, default_dir=Direction.BACKWARDS
) )
# The unstable version of this API returns an extra field for client # The unstable version of this API returns an extra field for client

View file

@ -26,7 +26,7 @@ from prometheus_client.core import Histogram
from twisted.web.server import Request from twisted.web.server import Request
from synapse import event_auth from synapse import event_auth
from synapse.api.constants import EventTypes, Membership from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
Codes, Codes,
@ -44,6 +44,7 @@ from synapse.http.servlet import (
RestServlet, RestServlet,
assert_params_in_dict, assert_params_in_dict,
parse_boolean, parse_boolean,
parse_enum,
parse_integer, parse_integer,
parse_json_object_from_request, parse_json_object_from_request,
parse_string, parse_string,
@ -1297,7 +1298,7 @@ class TimestampLookupRestServlet(RestServlet):
await self._auth.check_user_in_room_or_world_readable(room_id, requester) await self._auth.check_user_in_room_or_world_readable(room_id, requester)
timestamp = parse_integer(request, "ts", required=True) timestamp = parse_integer(request, "ts", required=True)
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) direction = parse_enum(request, "dir", Direction, default=Direction.FORWARDS)
( (
event_id, event_id,

View file

@ -17,6 +17,7 @@
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Tuple, cast from typing import TYPE_CHECKING, List, Optional, Tuple, cast
from synapse.api.constants import Direction
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -167,7 +168,7 @@ class DataStore(
guests: bool = True, guests: bool = True,
deactivated: bool = False, deactivated: bool = False,
order_by: str = UserSortOrder.NAME.value, order_by: str = UserSortOrder.NAME.value,
direction: str = "f", direction: Direction = Direction.FORWARDS,
approved: bool = True, approved: bool = True,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from """Function to retrieve a paginated list of users from
@ -197,7 +198,7 @@ class DataStore(
# Set ordering # Set ordering
order_by_column = UserSortOrder(order_by).value order_by_column = UserSortOrder(order_by).value
if direction == "b": if direction == Direction.BACKWARDS:
order = "DESC" order = "DESC"
else: else:
order = "ASC" order = "ASC"

View file

@ -38,7 +38,7 @@ from typing_extensions import Literal
from twisted.internet import defer from twisted.internet import defer
from synapse.api.constants import EventTypes from synapse.api.constants import Direction, EventTypes
from synapse.api.errors import NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.room_versions import ( from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS, KNOWN_ROOM_VERSIONS,
@ -2240,7 +2240,7 @@ class EventsWorkerStore(SQLBaseStore):
) )
async def get_event_id_for_timestamp( async def get_event_id_for_timestamp(
self, room_id: str, timestamp: int, direction: str self, room_id: str, timestamp: int, direction: Direction
) -> Optional[str]: ) -> Optional[str]:
"""Find the closest event to the given timestamp in the given direction. """Find the closest event to the given timestamp in the given direction.
@ -2248,14 +2248,14 @@ class EventsWorkerStore(SQLBaseStore):
room_id: Room to fetch the event from room_id: Room to fetch the event from
timestamp: The point in time (inclusive) we should navigate from in timestamp: The point in time (inclusive) we should navigate from in
the given direction to find the closest event. the given direction to find the closest event.
direction: ["f"|"b"] to indicate whether we should navigate forward direction: indicates whether we should navigate forward
or backward from the given timestamp to find the closest event. or backward from the given timestamp to find the closest event.
Returns: Returns:
The closest event_id otherwise None if we can't find any event in The closest event_id otherwise None if we can't find any event in
the given direction. the given direction.
""" """
if direction == "b": if direction == Direction.BACKWARDS:
# Find closest event *before* a given timestamp. We use descending # Find closest event *before* a given timestamp. We use descending
# (which gives values largest to smallest) because we want the # (which gives values largest to smallest) because we want the
# largest possible timestamp *before* the given timestamp. # largest possible timestamp *before* the given timestamp.
@ -2307,9 +2307,6 @@ class EventsWorkerStore(SQLBaseStore):
return None return None
if direction not in ("f", "b"):
raise ValueError("Unknown direction: %s" % (direction,))
return await self.db_pool.runInteraction( return await self.db_pool.runInteraction(
"get_event_id_for_timestamp_txn", "get_event_id_for_timestamp_txn",
get_event_id_for_timestamp_txn, get_event_id_for_timestamp_txn,

View file

@ -26,6 +26,7 @@ from typing import (
cast, cast,
) )
from synapse.api.constants import Direction
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -176,7 +177,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
limit: int, limit: int,
user_id: str, user_id: str,
order_by: str = MediaSortOrder.CREATED_TS.value, order_by: str = MediaSortOrder.CREATED_TS.value,
direction: str = "f", direction: Direction = Direction.FORWARDS,
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media """Get a paginated list of metadata for a local piece of media
which an user_id has uploaded which an user_id has uploaded
@ -199,7 +200,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
# Set ordering # Set ordering
order_by_column = MediaSortOrder(order_by).value order_by_column = MediaSortOrder(order_by).value
if direction == "b": if direction == Direction.BACKWARDS:
order = "DESC" order = "DESC"
else: else:
order = "ASC" order = "ASC"

View file

@ -35,6 +35,7 @@ from typing import (
import attr import attr
from synapse.api.constants import ( from synapse.api.constants import (
Direction,
EventContentFields, EventContentFields,
EventTypes, EventTypes,
JoinRules, JoinRules,
@ -2204,7 +2205,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
self, self,
start: int, start: int,
limit: int, limit: int,
direction: str = "b", direction: Direction = Direction.BACKWARDS,
user_id: Optional[str] = None, user_id: Optional[str] = None,
room_id: Optional[str] = None, room_id: Optional[str] = None,
) -> Tuple[List[Dict[str, Any]], int]: ) -> Tuple[List[Dict[str, Any]], int]:
@ -2213,8 +2214,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
Args: Args:
start: event offset to begin the query from start: event offset to begin the query from
limit: number of rows to retrieve limit: number of rows to retrieve
direction: Whether to fetch the most recent first (`"b"`) or the direction: Whether to fetch the most recent first (backwards) or the
oldest first (`"f"`) oldest first (forwards)
user_id: search for user_id. Ignored if user_id is None user_id: search for user_id. Ignored if user_id is None
room_id: search for room_id. Ignored if room_id is None room_id: search for room_id. Ignored if room_id is None
Returns: Returns:
@ -2236,7 +2237,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
filters.append("er.room_id LIKE ?") filters.append("er.room_id LIKE ?")
args.extend(["%" + room_id + "%"]) args.extend(["%" + room_id + "%"])
if direction == "b": if direction == Direction.BACKWARDS:
order = "DESC" order = "DESC"
else: else:
order = "ASC" order = "ASC"

View file

@ -22,7 +22,7 @@ from typing_extensions import Counter
from twisted.internet.defer import DeferredLock from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.constants import Direction, EventContentFields, EventTypes, Membership
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage.database import ( from synapse.storage.database import (
DatabasePool, DatabasePool,
@ -663,7 +663,7 @@ class StatsStore(StateDeltasStore):
from_ts: Optional[int] = None, from_ts: Optional[int] = None,
until_ts: Optional[int] = None, until_ts: Optional[int] = None,
order_by: Optional[str] = UserSortOrder.USER_ID.value, order_by: Optional[str] = UserSortOrder.USER_ID.value,
direction: Optional[str] = "f", direction: Direction = Direction.FORWARDS,
search_term: Optional[str] = None, search_term: Optional[str] = None,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users and their uploaded local media """Function to retrieve a paginated list of users and their uploaded local media
@ -714,7 +714,7 @@ class StatsStore(StateDeltasStore):
500, "Incorrect value for order_by provided: %s" % order_by 500, "Incorrect value for order_by provided: %s" % order_by
) )
if direction == "b": if direction == Direction.BACKWARDS:
order = "DESC" order = "DESC"
else: else:
order = "ASC" order = "ASC"

View file

@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, cast
import attr import attr
from canonicaljson import encode_canonical_json from canonicaljson import encode_canonical_json
from synapse.api.constants import Direction
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import db_to_json from synapse.storage._base import db_to_json
from synapse.storage.database import ( from synapse.storage.database import (
@ -496,7 +497,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
limit: int, limit: int,
destination: Optional[str] = None, destination: Optional[str] = None,
order_by: str = DestinationSortOrder.DESTINATION.value, order_by: str = DestinationSortOrder.DESTINATION.value,
direction: str = "f", direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of destinations. """Function to retrieve a paginated list of destinations.
This will return a json list of destinations and the This will return a json list of destinations and the
@ -518,7 +519,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
order_by_column = DestinationSortOrder(order_by).value order_by_column = DestinationSortOrder(order_by).value
if direction == "b": if direction == Direction.BACKWARDS:
order = "DESC" order = "DESC"
else: else:
order = "ASC" order = "ASC"
@ -550,7 +551,11 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
) )
async def get_destination_rooms_paginate( async def get_destination_rooms_paginate(
self, destination: str, start: int, limit: int, direction: str = "f" self,
destination: str,
start: int,
limit: int,
direction: Direction = Direction.FORWARDS,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of destination's rooms. """Function to retrieve a paginated list of destination's rooms.
This will return a json list of rooms and the This will return a json list of rooms and the
@ -569,7 +574,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[List[JsonDict], int]: ) -> Tuple[List[JsonDict], int]:
if direction == "b": if direction == Direction.BACKWARDS:
order = "DESC" order = "DESC"
else: else:
order = "ASC" order = "ASC"

View file

@ -18,7 +18,7 @@ import attr
from synapse.api.constants import Direction from synapse.api.constants import Direction
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string from synapse.http.servlet import parse_enum, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.storage.databases.main import DataStore from synapse.storage.databases.main import DataStore
from synapse.types import StreamToken from synapse.types import StreamToken
@ -44,15 +44,9 @@ class PaginationConfig:
store: "DataStore", store: "DataStore",
request: SynapseRequest, request: SynapseRequest,
default_limit: int, default_limit: int,
default_dir: str = "f", default_dir: Direction = Direction.FORWARDS,
) -> "PaginationConfig": ) -> "PaginationConfig":
direction_str = parse_string( direction = parse_enum(request, "dir", Direction, default=default_dir)
request,
"dir",
default=default_dir,
allowed_values=[Direction.FORWARDS.value, Direction.BACKWARDS.value],
)
direction = Direction(direction_str)
from_tok_str = parse_string(request, "from") from_tok_str = parse_string(request, "from")
to_tok_str = parse_string(request, "to") to_tok_str = parse_string(request, "to")

View file

@ -280,7 +280,10 @@ class EventReportsTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, channel.code, msg=channel.json_body) self.assertEqual(400, channel.code, msg=channel.json_body)
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"]) self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
self.assertEqual("Unknown direction: bar", channel.json_body["error"]) self.assertEqual(
"Query parameter 'dir' must be one of ['b', 'f']",
channel.json_body["error"],
)
def test_limit_is_negative(self) -> None: def test_limit_is_negative(self) -> None:
""" """