Make token serializing/deserializing async (#8427)

The idea is that in future tokens will encode a mapping of instance to position. However, we don't want to include the full instance name in the string representation, so instead we'll have a mapping between instance name and an immutable integer ID in the DB that we can use instead. We'll then do the lookup when we serialize/deserialize the token (we could alternatively pass around an `Instance` type that includes both the name and ID, but that turns out to be a lot more invasive).
This commit is contained in:
Erik Johnston 2020-09-30 20:29:19 +01:00 committed by GitHub
parent a0a1ba6973
commit 7941372ec8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 115 additions and 59 deletions

1
changelog.d/8427.misc Normal file
View file

@ -0,0 +1 @@
Make stream token serializing/deserializing async.

View file

@ -133,8 +133,8 @@ class EventStreamHandler(BaseHandler):
chunk = {
"chunk": chunks,
"start": tokens[0].to_string(),
"end": tokens[1].to_string(),
"start": await tokens[0].to_string(self.store),
"end": await tokens[1].to_string(self.store),
}
return chunk

View file

@ -203,8 +203,8 @@ class InitialSyncHandler(BaseHandler):
messages, time_now=time_now, as_client_event=as_client_event
)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
}
d["state"] = await self._event_serializer.serialize_events(
@ -249,7 +249,7 @@ class InitialSyncHandler(BaseHandler):
],
"account_data": account_data_events,
"receipts": receipt,
"end": now_token.to_string(),
"end": await now_token.to_string(self.store),
}
return ret
@ -348,8 +348,8 @@ class InitialSyncHandler(BaseHandler):
"chunk": (
await self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": (
await self._event_serializer.serialize_events(
@ -447,8 +447,8 @@ class InitialSyncHandler(BaseHandler):
"chunk": (
await self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": state,
"presence": presence,

View file

@ -413,8 +413,8 @@ class PaginationHandler:
if not events:
return {
"chunk": [],
"start": from_token.to_string(),
"end": next_token.to_string(),
"start": await from_token.to_string(self.store),
"end": await next_token.to_string(self.store),
}
state = None
@ -442,8 +442,8 @@ class PaginationHandler:
events, time_now, as_client_event=as_client_event
)
),
"start": from_token.to_string(),
"end": next_token.to_string(),
"start": await from_token.to_string(self.store),
"end": await next_token.to_string(self.store),
}
if state:

View file

@ -1077,11 +1077,13 @@ class RoomContextHandler:
# the token, which we replace.
token = StreamToken.START
results["start"] = token.copy_and_replace(
results["start"] = await token.copy_and_replace(
"room_key", results["start"]
).to_string()
).to_string(self.store)
results["end"] = token.copy_and_replace("room_key", results["end"]).to_string()
results["end"] = await token.copy_and_replace(
"room_key", results["end"]
).to_string(self.store)
return results

View file

@ -362,13 +362,13 @@ class SearchHandler(BaseHandler):
self.storage, user.to_string(), res["events_after"]
)
res["start"] = now_token.copy_and_replace(
res["start"] = await now_token.copy_and_replace(
"room_key", res["start"]
).to_string()
).to_string(self.store)
res["end"] = now_token.copy_and_replace(
res["end"] = await now_token.copy_and_replace(
"room_key", res["end"]
).to_string()
).to_string(self.store)
if include_profile:
senders = {

View file

@ -110,7 +110,7 @@ class PurgeHistoryRestServlet(RestServlet):
raise SynapseError(400, "Event is for wrong room.")
room_token = await self.store.get_topological_token_for_event(event_id)
token = str(room_token)
token = await room_token.to_string(self.store)
logger.info("[purge] purging up to token %s (event_id %s)", token, event_id)
elif "purge_up_to_ts" in body:

View file

@ -33,6 +33,7 @@ class EventStreamRestServlet(RestServlet):
super().__init__()
self.event_stream_handler = hs.get_event_stream_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
@ -44,7 +45,7 @@ class EventStreamRestServlet(RestServlet):
if b"room_id" in request.args:
room_id = request.args[b"room_id"][0].decode("ascii")
pagin_config = PaginationConfig.from_request(request)
pagin_config = await PaginationConfig.from_request(self.store, request)
timeout = EventStreamRestServlet.DEFAULT_LONGPOLL_TIME_MS
if b"timeout" in request.args:
try:

View file

@ -27,11 +27,12 @@ class InitialSyncRestServlet(RestServlet):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request)
as_client_event = b"raw" not in request.args
pagination_config = PaginationConfig.from_request(request)
pagination_config = await PaginationConfig.from_request(self.store, request)
include_archived = parse_boolean(request, "archived", default=False)
content = await self.initial_sync_handler.snapshot_all_rooms(
user_id=requester.user.to_string(),

View file

@ -451,6 +451,7 @@ class RoomMemberListRestServlet(RestServlet):
super().__init__()
self.message_handler = hs.get_message_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
# TODO support Pagination stream API (limit/tokens)
@ -465,7 +466,7 @@ class RoomMemberListRestServlet(RestServlet):
if at_token_string is None:
at_token = None
else:
at_token = StreamToken.from_string(at_token_string)
at_token = await StreamToken.from_string(self.store, at_token_string)
# let you filter down on particular memberships.
# XXX: this may not be the best shape for this API - we could pass in a filter
@ -521,10 +522,13 @@ class RoomMessageListRestServlet(RestServlet):
super().__init__()
self.pagination_handler = hs.get_pagination_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request, default_limit=10)
pagination_config = await PaginationConfig.from_request(
self.store, request, default_limit=10
)
as_client_event = b"raw" not in request.args
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
@ -580,10 +584,11 @@ class RoomInitialSyncRestServlet(RestServlet):
super().__init__()
self.initial_sync_handler = hs.get_initial_sync_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
pagination_config = PaginationConfig.from_request(request)
pagination_config = await PaginationConfig.from_request(self.store, request)
content = await self.initial_sync_handler.room_initial_sync(
room_id=room_id, requester=requester, pagin_config=pagination_config
)

View file

@ -180,6 +180,7 @@ class KeyChangesServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
async def on_GET(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
@ -191,7 +192,7 @@ class KeyChangesServlet(RestServlet):
# changes after the "to" as well as before.
set_tag("to", parse_string(request, "to"))
from_token = StreamToken.from_string(from_token_string)
from_token = await StreamToken.from_string(self.store, from_token_string)
user_id = requester.user.to_string()

View file

@ -77,6 +77,7 @@ class SyncRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
self.sync_handler = hs.get_sync_handler()
self.clock = hs.get_clock()
self.filtering = hs.get_filtering()
@ -151,10 +152,9 @@ class SyncRestServlet(RestServlet):
device_id=device_id,
)
since_token = None
if since is not None:
since_token = StreamToken.from_string(since)
else:
since_token = None
since_token = await StreamToken.from_string(self.store, since)
# send any outstanding server notices to the user.
await self._server_notices_sender.on_user_syncing(user.to_string())
@ -236,7 +236,7 @@ class SyncRestServlet(RestServlet):
"leave": sync_result.groups.leave,
},
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
"next_batch": sync_result.next_batch.to_string(),
"next_batch": await sync_result.next_batch.to_string(self.store),
}
@staticmethod
@ -413,7 +413,7 @@ class SyncRestServlet(RestServlet):
result = {
"timeline": {
"events": serialized_timeline,
"prev_batch": room.timeline.prev_batch.to_string(),
"prev_batch": await room.timeline.prev_batch.to_string(self.store),
"limited": room.timeline.limited,
},
"state": {"events": serialized_state},

View file

@ -42,17 +42,17 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
The set of state groups that are referenced by deleted events.
"""
parsed_token = await RoomStreamToken.parse(self, token)
return await self.db_pool.runInteraction(
"purge_history",
self._purge_history_txn,
room_id,
token,
parsed_token,
delete_local_events,
)
def _purge_history_txn(self, txn, room_id, token_str, delete_local_events):
token = RoomStreamToken.parse(token_str)
def _purge_history_txn(self, txn, room_id, token, delete_local_events):
# Tables that should be pruned:
# event_auth
# event_backward_extremities

View file

@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
@ -21,6 +20,7 @@ import attr
from synapse.api.errors import SynapseError
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.storage.databases.main import DataStore
from synapse.types import StreamToken
logger = logging.getLogger(__name__)
@ -39,8 +39,9 @@ class PaginationConfig:
limit = attr.ib(type=Optional[int])
@classmethod
def from_request(
async def from_request(
cls,
store: "DataStore",
request: SynapseRequest,
raise_invalid_params: bool = True,
default_limit: Optional[int] = None,
@ -54,13 +55,13 @@ class PaginationConfig:
if from_tok == "END":
from_tok = None # For backwards compat.
elif from_tok:
from_tok = StreamToken.from_string(from_tok)
from_tok = await StreamToken.from_string(store, from_tok)
except Exception:
raise SynapseError(400, "'from' parameter is invalid")
try:
if to_tok:
to_tok = StreamToken.from_string(to_tok)
to_tok = await StreamToken.from_string(store, to_tok)
except Exception:
raise SynapseError(400, "'to' parameter is invalid")

View file

@ -18,7 +18,17 @@ import re
import string
import sys
from collections import namedtuple
from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar
from typing import (
TYPE_CHECKING,
Any,
Dict,
Mapping,
MutableMapping,
Optional,
Tuple,
Type,
TypeVar,
)
import attr
from signedjson.key import decode_verify_key_bytes
@ -26,6 +36,9 @@ from unpaddedbase64 import decode_base64
from synapse.api.errors import Codes, SynapseError
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
# define a version of typing.Collection that works on python 3.5
if sys.version_info[:3] >= (3, 6, 0):
from typing import Collection
@ -393,7 +406,7 @@ class RoomStreamToken:
stream = attr.ib(type=int, validator=attr.validators.instance_of(int))
@classmethod
def parse(cls, string: str) -> "RoomStreamToken":
async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
try:
if string[0] == "s":
return cls(topological=None, stream=int(string[1:]))
@ -428,7 +441,7 @@ class RoomStreamToken:
def as_tuple(self) -> Tuple[Optional[int], int]:
return (self.topological, self.stream)
def __str__(self) -> str:
async def to_string(self, store: "DataStore") -> str:
if self.topological is not None:
return "t%d-%d" % (self.topological, self.stream)
else:
@ -453,18 +466,32 @@ class StreamToken:
START = None # type: StreamToken
@classmethod
def from_string(cls, string):
async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
try:
keys = string.split(cls._SEPARATOR)
while len(keys) < len(attr.fields(cls)):
# i.e. old token from before receipt_key
keys.append("0")
return cls(RoomStreamToken.parse(keys[0]), *(int(k) for k in keys[1:]))
return cls(
await RoomStreamToken.parse(store, keys[0]), *(int(k) for k in keys[1:])
)
except Exception:
raise SynapseError(400, "Invalid Token")
def to_string(self):
return self._SEPARATOR.join([str(k) for k in attr.astuple(self, recurse=False)])
async def to_string(self, store: "DataStore") -> str:
return self._SEPARATOR.join(
[
await self.room_key.to_string(store),
str(self.presence_key),
str(self.typing_key),
str(self.receipt_key),
str(self.account_data_key),
str(self.push_rules_key),
str(self.to_device_key),
str(self.device_list_key),
str(self.groups_key),
]
)
@property
def room_stream_id(self):
@ -493,7 +520,7 @@ class StreamToken:
return attr.evolve(self, **{key: new_value})
StreamToken.START = StreamToken.from_string("s0_0")
StreamToken.START = StreamToken(RoomStreamToken(None, 0), 0, 0, 0, 0, 0, 0, 0, 0)
@attr.s(slots=True, frozen=True)

View file

@ -902,16 +902,18 @@ class RoomMessageListTestCase(RoomBase):
# Send a first message in the room, which will be removed by the purge.
first_event_id = self.helper.send(self.room_id, "message 1")["event_id"]
first_token = str(
self.get_success(store.get_topological_token_for_event(first_event_id))
first_token = self.get_success(
store.get_topological_token_for_event(first_event_id)
)
first_token_str = self.get_success(first_token.to_string(store))
# Send a second message in the room, which won't be removed, and which we'll
# use as the marker to purge events before.
second_event_id = self.helper.send(self.room_id, "message 2")["event_id"]
second_token = str(
self.get_success(store.get_topological_token_for_event(second_event_id))
second_token = self.get_success(
store.get_topological_token_for_event(second_event_id)
)
second_token_str = self.get_success(second_token.to_string(store))
# Send a third event in the room to ensure we don't fall under any edge case
# due to our marker being the latest forward extremity in the room.
@ -921,7 +923,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
% (
self.room_id,
second_token_str,
json.dumps({"types": [EventTypes.Message]}),
),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@ -936,7 +942,7 @@ class RoomMessageListTestCase(RoomBase):
pagination_handler._purge_history(
purge_id=purge_id,
room_id=self.room_id,
token=second_token,
token=second_token_str,
delete_local_events=True,
)
)
@ -946,7 +952,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (self.room_id, second_token, json.dumps({"types": [EventTypes.Message]})),
% (
self.room_id,
second_token_str,
json.dumps({"types": [EventTypes.Message]}),
),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
@ -960,7 +970,11 @@ class RoomMessageListTestCase(RoomBase):
request, channel = self.make_request(
"GET",
"/rooms/%s/messages?access_token=x&from=%s&dir=b&filter=%s"
% (self.room_id, first_token, json.dumps({"types": [EventTypes.Message]})),
% (
self.room_id,
first_token_str,
json.dumps({"types": [EventTypes.Message]}),
),
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)

View file

@ -47,12 +47,15 @@ class PurgeTests(HomeserverTestCase):
storage = self.hs.get_storage()
# Get the topological token
event = str(
self.get_success(store.get_topological_token_for_event(last["event_id"]))
token = self.get_success(
store.get_topological_token_for_event(last["event_id"])
)
token_str = self.get_success(token.to_string(self.hs.get_datastore()))
# Purge everything before this topological token
self.get_success(storage.purge_events.purge_history(self.room_id, event, True))
self.get_success(
storage.purge_events.purge_history(self.room_id, token_str, True)
)
# 1-3 should fail and last will succeed, meaning that 1-3 are deleted
# and last is not.