Persist CreateRoom events to DB in a batch (#13800)

This commit is contained in:
Shay 2022-09-28 03:11:48 -07:00 committed by GitHub
parent a2cf66a94d
commit 8ab16a92ed
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 580 additions and 356 deletions

1
changelog.d/13800.misc Normal file
View file

@ -0,0 +1 @@
Speed up creation of DM rooms.

View file

@ -56,11 +56,13 @@ from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import (
MutableStateMap,
PersistedEventPosition,
Requester,
RoomAlias,
StateMap,
@ -493,6 +495,7 @@ class EventCreationHandler:
self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
self.send_events = ReplicationSendEventsRestServlet.make_client(hs)
self.request_ratelimiter = hs.get_request_ratelimiter()
@ -1016,8 +1019,7 @@ class EventCreationHandler:
ev = await self.handle_new_client_event(
requester=requester,
event=event,
context=context,
events_and_context=[(event, context)],
ratelimit=ratelimit,
ignore_shadow_ban=ignore_shadow_ban,
)
@ -1293,13 +1295,13 @@ class EventCreationHandler:
async def handle_new_client_event(
self,
requester: Requester,
event: EventBase,
context: EventContext,
events_and_context: List[Tuple[EventBase, EventContext]],
ratelimit: bool = True,
extra_users: Optional[List[UserID]] = None,
ignore_shadow_ban: bool = False,
) -> EventBase:
"""Processes a new event.
"""Processes new events. Please note that if batch persisting events, an error in
handling any one of these events will result in all of the events being dropped.
This includes deduplicating, checking auth, persisting,
notifying users, sending to remote servers, etc.
@ -1309,8 +1311,7 @@ class EventCreationHandler:
Args:
requester
event
context
events_and_context: A list of one or more tuples of event, context to be persisted
ratelimit
extra_users: Any extra users to notify about event
@ -1328,6 +1329,7 @@ class EventCreationHandler:
"""
extra_users = extra_users or []
for event, context in events_and_context:
# we don't apply shadow-banning to membership events here. Invites are blocked
# higher up the stack, and we allow shadow-banned users to send join and leave
# events as normal.
@ -1375,15 +1377,15 @@ class EventCreationHandler:
# We now persist the event (and update the cache in parallel, since we
# don't want to block on it).
event, context = events_and_context[0]
try:
result, _ = await make_deferred_yieldable(
gather_results(
(
run_in_background(
self._persist_event,
self._persist_events,
requester=requester,
event=event,
context=context,
events_and_context=events_and_context,
ratelimit=ratelimit,
extra_users=extra_users,
),
@ -1407,22 +1409,23 @@ class EventCreationHandler:
return result
async def _persist_event(
async def _persist_events(
self,
requester: Requester,
event: EventBase,
context: EventContext,
events_and_context: List[Tuple[EventBase, EventContext]],
ratelimit: bool = True,
extra_users: Optional[List[UserID]] = None,
) -> EventBase:
"""Actually persists the event. Should only be called by
"""Actually persists new events. Should only be called by
`handle_new_client_event`, and see its docstring for documentation of
the arguments.
the arguments. Please note that if batch persisting events, an error in
handling any one of these events will result in all of the events being dropped.
PartialStateConflictError: if attempting to persist a partial state event in
a room that has been un-partial stated.
"""
for event, context in events_and_context:
# Skip push notification actions for historical messages
# because we don't want to notify people about old history back in time.
# The historical messages also do not have the proper `context.current_state_ids`
@ -1436,16 +1439,17 @@ class EventCreationHandler:
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
first_event, _ = events_and_context[0]
writer_instance = self._events_shard_config.get_instance(
first_event.room_id
)
if writer_instance != self._instance_name:
try:
result = await self.send_event(
result = await self.send_events(
instance_name=writer_instance,
event_id=event.event_id,
events_and_context=events_and_context,
store=self.store,
requester=requester,
event=event,
context=context,
ratelimit=ratelimit,
extra_users=extra_users,
)
@ -1455,6 +1459,11 @@ class EventCreationHandler:
raise
stream_id = result["stream_id"]
event_id = result["event_id"]
# If we batch persisted events we return the last persisted event, otherwise
# we return the one event that was persisted
event, _ = events_and_context[-1]
if event_id != event.event_id:
# If we get a different event back then it means that its
# been de-duplicated, so we replace the given event with the
@ -1467,12 +1476,16 @@ class EventCreationHandler:
event.internal_metadata.stream_ordering = stream_id
return event
event = await self.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
event = await self.persist_and_notify_client_events(
requester,
events_and_context,
ratelimit=ratelimit,
extra_users=extra_users,
)
return event
except Exception:
for event, _ in events_and_context:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
await self.store.remove_push_actions_from_staging(event.event_id)
@ -1569,23 +1582,26 @@ class EventCreationHandler:
Codes.BAD_ALIAS,
)
async def persist_and_notify_client_event(
async def persist_and_notify_client_events(
self,
requester: Requester,
event: EventBase,
context: EventContext,
events_and_context: List[Tuple[EventBase, EventContext]],
ratelimit: bool = True,
extra_users: Optional[List[UserID]] = None,
) -> EventBase:
"""Called when we have fully built the event, have already
calculated the push actions for the event, and checked auth.
"""Called when we have fully built the events, have already
calculated the push actions for the events, and checked auth.
This should only be run on the instance in charge of persisting events.
Please note that if batch persisting events, an error in
handling any one of these events will result in all of the events being dropped.
Returns:
The persisted event. This may be different than the given event if
it was de-duplicated (e.g. because we had already persisted an
event with the same transaction ID.)
The persisted event, if one event is passed in, or the last event in the
list in the case of batch persisting. If only one event was persisted, the
returned event may be different than the given event if it was de-duplicated
(e.g. because we had already persisted an event with the same transaction ID.)
Raises:
PartialStateConflictError: if attempting to persist a partial state event in
@ -1593,7 +1609,7 @@ class EventCreationHandler:
"""
extra_users = extra_users or []
assert self._storage_controllers.persistence is not None
for event, context in events_and_context:
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id
)
@ -1623,6 +1639,7 @@ class EventCreationHandler:
requester, is_admin_redaction=is_admin_redaction
)
# run checks/actions on event based on type
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
(
current_membership,
@ -1631,7 +1648,9 @@ class EventCreationHandler:
event.state_key, event.room_id
)
if current_membership != Membership.JOIN:
self._notifier.notify_user_joined_room(event.event_id, event.room_id)
self._notifier.notify_user_joined_room(
event.event_id, event.room_id
)
await self._maybe_kick_guest_users(event, context)
@ -1643,11 +1662,13 @@ class EventCreationHandler:
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
original_event = await self.store.get_event(original_event_id)
original_alias_event = await self.store.get_event(original_event_id)
if original_event:
original_alias = original_event.content.get("alias", None)
original_alt_aliases = original_event.content.get("alt_aliases", [])
if original_alias_event:
original_alias = original_alias_event.content.get("alias", None)
original_alt_aliases = original_alias_event.content.get(
"alt_aliases", []
)
# Check the alias is currently valid (if it has changed).
room_alias_str = event.content.get("alias", None)
@ -1661,7 +1682,9 @@ class EventCreationHandler:
alt_aliases = event.content.get("alt_aliases", [])
if not isinstance(alt_aliases, (list, tuple)):
raise SynapseError(
400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM
400,
"The alt_aliases property must be a list.",
Codes.INVALID_PARAM,
)
# If the old version of alt_aliases is of an unknown form,
@ -1732,10 +1755,14 @@ class EventCreationHandler:
raise AuthError(403, "Redacting create events is not permitted")
if original_event.room_id != event.room_id:
raise SynapseError(400, "Cannot redact event from a different room")
raise SynapseError(
400, "Cannot redact event from a different room"
)
if original_event.type == EventTypes.ServerACL:
raise AuthError(403, "Redacting server ACL events is not permitted")
raise AuthError(
403, "Redacting server ACL events is not permitted"
)
# Add a little safety stop-gap to prevent people from trying to
# redact MSC2716 related events when they're in a room version
@ -1768,7 +1795,9 @@ class EventCreationHandler:
event, prev_state_ids, for_verification=True
)
auth_events_map = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
auth_events = {
(e.type, e.state_key): e for e in auth_events_map.values()
}
if event_auth.check_redaction(
room_version_obj, event, auth_events=auth_events
@ -1777,10 +1806,14 @@ class EventCreationHandler:
# checks on the original event. Let's start by checking the original
# event exists.
if not original_event:
raise NotFoundError("Could not find event %s" % (event.redacts,))
raise NotFoundError(
"Could not find event %s" % (event.redacts,)
)
if event.user_id != original_event.user_id:
raise AuthError(403, "You don't have permission to redact events")
raise AuthError(
403, "You don't have permission to redact events"
)
# all the checks are done.
event.internal_metadata.recheck_redaction = False
@ -1831,24 +1864,27 @@ class EventCreationHandler:
if event.internal_metadata.is_historical():
backfilled = True
# Note that this returns the event that was persisted, which may not be
# the same as we passed in if it was deduplicated due transaction IDs.
assert self._storage_controllers.persistence is not None
(
event,
event_pos,
persisted_events,
max_stream_token,
) = await self._storage_controllers.persistence.persist_event(
event, context=context, backfilled=backfilled
) = await self._storage_controllers.persistence.persist_events(
events_and_context, backfilled=backfilled
)
for event in persisted_events:
if self._ephemeral_events_enabled:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
stream_ordering = event.internal_metadata.stream_ordering
assert stream_ordering is not None
pos = PersistedEventPosition(self._instance_name, stream_ordering)
async def _notify() -> None:
try:
await self.notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
event, pos, max_stream_token, extra_users=extra_users
)
except Exception:
logger.exception(
@ -1863,7 +1899,7 @@ class EventCreationHandler:
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
return event
return persisted_events[-1]
async def _maybe_kick_guest_users(
self, event: EventBase, context: EventContext
@ -1952,8 +1988,7 @@ class EventCreationHandler:
# shadow-banned user.
await self.handle_new_client_event(
requester,
event,
context,
events_and_context=[(event, context)],
ratelimit=False,
ignore_shadow_ban=True,
)

View file

@ -301,8 +301,7 @@ class RoomCreationHandler:
# now send the tombstone
await self.event_creation_handler.handle_new_client_event(
requester=requester,
event=tombstone_event,
context=tombstone_context,
events_and_context=[(tombstone_event, tombstone_context)],
)
state_filter = StateFilter.from_types(
@ -1057,8 +1056,10 @@ class RoomCreationHandler:
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
depth = 1
# the last event sent/persisted to the db
last_sent_event_id: Optional[str] = None
# the most recently created event
prev_event: List[str] = []
# a map of event types, state keys -> event_ids. We collect these mappings this as events are
@ -1112,8 +1113,7 @@ class RoomCreationHandler:
ev = await self.event_creation_handler.handle_new_client_event(
requester=creator,
event=event,
context=context,
events_and_context=[(event, context)],
ratelimit=False,
ignore_shadow_ban=True,
)
@ -1152,7 +1152,6 @@ class RoomCreationHandler:
prev_event_ids=[last_sent_event_id],
depth=depth,
)
last_sent_event_id = member_event_id
prev_event = [member_event_id]
# update the depth and state map here as the membership event has been created
@ -1168,7 +1167,7 @@ class RoomCreationHandler:
EventTypes.PowerLevels, pl_content, False
)
current_state_group = power_context._state_group
last_sent_stream_id = await send(power_event, power_context, creator)
await send(power_event, power_context, creator)
else:
power_level_content: JsonDict = {
"users": {creator_id: 100},
@ -1217,7 +1216,7 @@ class RoomCreationHandler:
False,
)
current_state_group = pl_context._state_group
last_sent_stream_id = await send(pl_event, pl_context, creator)
await send(pl_event, pl_context, creator)
events_to_send = []
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
@ -1271,9 +1270,11 @@ class RoomCreationHandler:
)
events_to_send.append((encryption_event, encryption_context))
for event, context in events_to_send:
last_sent_stream_id = await send(event, context, creator)
return last_sent_stream_id, last_sent_event_id, depth
last_event = await self.event_creation_handler.handle_new_client_event(
creator, events_to_send, ignore_shadow_ban=True
)
assert last_event.internal_metadata.stream_ordering is not None
return last_event.internal_metadata.stream_ordering, last_event.event_id, depth
def _generate_room_id(self) -> str:
"""Generates a random room ID.

View file

@ -379,8 +379,7 @@ class RoomBatchHandler:
await self.create_requester_for_user_id_from_app_service(
event.sender, app_service_requester.app_service
),
event=event,
context=context,
events_and_context=[(event, context)],
)
return event_ids

View file

@ -432,8 +432,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
with opentracing.start_active_span("handle_new_client_event"):
result_event = await self.event_creation_handler.handle_new_client_event(
requester,
event,
context,
events_and_context=[(event, context)],
extra_users=[target],
ratelimit=ratelimit,
)
@ -1252,7 +1251,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
raise SynapseError(403, "This room has been blocked on this server")
event = await self.event_creation_handler.handle_new_client_event(
requester, event, context, extra_users=[target_user], ratelimit=ratelimit
requester,
events_and_context=[(event, context)],
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event_id = prev_state_ids.get(
@ -1860,8 +1862,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
result_event = await self.event_creation_handler.handle_new_client_event(
requester,
event,
context,
events_and_context=[(event, context)],
extra_users=[UserID.from_string(target_user)],
)
# we know it was persisted, so must have a stream ordering

View file

@ -25,6 +25,7 @@ from synapse.replication.http import (
push,
register,
send_event,
send_events,
state,
streams,
)
@ -43,6 +44,7 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs: "HomeServer") -> None:
send_event.register_servlets(hs, self)
send_events.register_servlets(hs, self)
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)

View file

@ -141,8 +141,8 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
"Got event to send with ID: %s into room: %s", event.event_id, event.room_id
)
event = await self.event_creation_handler.persist_and_notify_client_event(
requester, event, context, ratelimit=ratelimit, extra_users=extra_users
event = await self.event_creation_handler.persist_and_notify_client_events(
requester, [(event, context)], ratelimit=ratelimit, extra_users=extra_users
)
return (

View file

@ -0,0 +1,171 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Tuple
from twisted.web.server import Request
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
class ReplicationSendEventsRestServlet(ReplicationEndpoint):
    """Handles batches of newly created events on workers, including persisting and
    notifying.

    The API looks like:

        POST /_synapse/replication/send_events/:txn_id

        {
            "events": [{
                "event": { .. serialized event .. },
                "room_version": .., // "1", "2", "3", etc: the version of the room
                                    // containing the event
                "event_format_version": .., // 1,2,3 etc: the event format version
                "internal_metadata": { .. serialized internal_metadata .. },
                "outlier": true|false,
                "rejected_reason": .., // The event.rejected_reason field
                "context": { .. serialized event context .. },
                "requester": { .. serialized requester .. },
                "ratelimit": true,
            }]
        }

        200 OK

        { "stream_id": 12345, "event_id": "$abcdef..." }

    Responds with a 409 when a `PartialStateConflictError` is raised due to an event
    context that needs to be recomputed due to the un-partial stating of a room.
    """

    NAME = "send_events"
    PATH_ARGS = ()

    def __init__(self, hs: "HomeServer"):
        super().__init__(hs)
        self.event_creation_handler = hs.get_event_creation_handler()
        self.store = hs.get_datastores().main
        self._storage_controllers = hs.get_storage_controllers()
        self.clock = hs.get_clock()

    @staticmethod
    async def _serialize_payload(  # type: ignore[override]
        events_and_context: List[Tuple[EventBase, EventContext]],
        store: "DataStore",
        requester: Requester,
        ratelimit: bool,
        extra_users: List[UserID],
    ) -> JsonDict:
        """Serialize a batch of events (with their contexts) for replication.

        Note that the same requester/ratelimit/extra_users values are copied
        into every serialized entry; `_handle_request` relies on this.

        Args:
            events_and_context: the events and their contexts to be persisted
            store: used to serialize each event's context
            requester: the user sending the events
            ratelimit: whether to ratelimit when persisting the events
            extra_users: any extra users to notify about the events
        """
        serialized_events = []

        for event, context in events_and_context:
            serialized_context = await context.serialize(event, store)
            serialized_event = {
                "event": event.get_pdu_json(),
                "room_version": event.room_version.identifier,
                "event_format_version": event.format_version,
                "internal_metadata": event.internal_metadata.get_dict(),
                "outlier": event.internal_metadata.is_outlier(),
                "rejected_reason": event.rejected_reason,
                "context": serialized_context,
                "requester": requester.serialize(),
                "ratelimit": ratelimit,
                "extra_users": [u.to_string() for u in extra_users],
            }
            serialized_events.append(serialized_event)

        payload = {"events": serialized_events}

        return payload

    async def _handle_request(  # type: ignore[override]
        self, request: Request
    ) -> Tuple[int, JsonDict]:
        """Deserialize the batched events, persist them, and return the stream
        position and event ID of the last event persisted.
        """
        with Measure(self.clock, "repl_send_events_parse"):
            payload = parse_json_object_from_request(request)
            events_and_context = []
            events = payload["events"]

            for event_payload in events:
                event_dict = event_payload["event"]
                room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]]
                internal_metadata = event_payload["internal_metadata"]
                rejected_reason = event_payload["rejected_reason"]

                event = make_event_from_dict(
                    event_dict, room_ver, internal_metadata, rejected_reason
                )
                event.internal_metadata.outlier = event_payload["outlier"]

                requester = Requester.deserialize(
                    self.store, event_payload["requester"]
                )
                context = EventContext.deserialize(
                    self._storage_controllers, event_payload["context"]
                )

                ratelimit = event_payload["ratelimit"]
                events_and_context.append((event, context))

                extra_users = [
                    UserID.from_string(u) for u in event_payload["extra_users"]
                ]

        # NOTE: every payload entry carries the same requester/ratelimit/extra_users
        # (see `_serialize_payload`), so the values left over from the final loop
        # iteration above are the ones to use for the whole batch. Read the last
        # event explicitly rather than relying on the loop variable leaking.
        last_parsed_event, _ = events_and_context[-1]
        logger.info(
            "Got batch of events to send, last ID of batch is: %s, sending into room: %s",
            last_parsed_event.event_id,
            last_parsed_event.room_id,
        )

        last_event = (
            await self.event_creation_handler.persist_and_notify_client_events(
                requester, events_and_context, ratelimit, extra_users
            )
        )

        return (
            200,
            {
                "stream_id": last_event.internal_metadata.stream_ordering,
                "event_id": last_event.event_id,
            },
        )
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
    """Register the batched send_events replication endpoint with the server."""
    servlet = ReplicationSendEventsRestServlet(hs)
    servlet.register(http_server)

View file

@ -105,7 +105,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
event1, context = self._create_duplicate_event(txn_id)
ret_event1 = self.get_success(
self.handler.handle_new_client_event(self.requester, event1, context)
self.handler.handle_new_client_event(
self.requester,
events_and_context=[(event1, context)],
)
)
stream_id1 = ret_event1.internal_metadata.stream_ordering
@ -118,7 +121,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertNotEqual(event1.event_id, event2.event_id)
ret_event2 = self.get_success(
self.handler.handle_new_client_event(self.requester, event2, context)
self.handler.handle_new_client_event(
self.requester,
events_and_context=[(event2, context)],
)
)
stream_id2 = ret_event2.internal_metadata.stream_ordering

View file

@ -497,7 +497,9 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
)
)
self.get_success(
event_creation_handler.handle_new_client_event(requester, event, context)
event_creation_handler.handle_new_client_event(
requester, events_and_context=[(event, context)]
)
)
# Register a second user, which won't be be in the room (or even have an invite)

View file

@ -531,7 +531,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
self.get_success(
event_handler.handle_new_client_event(self.requester, event, context)
event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
)
)
state1 = set(self.get_success(context.get_current_state_ids()).values())
@ -549,7 +551,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
)
)
self.get_success(
event_handler.handle_new_client_event(self.requester, event, context)
event_handler.handle_new_client_event(
self.requester, events_and_context=[(event, context)]
)
)
state2 = set(self.get_success(context.get_current_state_ids()).values())

View file

@ -734,7 +734,9 @@ class HomeserverTestCase(TestCase):
event.internal_metadata.soft_failed = True
self.get_success(
event_creator.handle_new_client_event(requester, event, context)
event_creator.handle_new_client_event(
requester, events_and_context=[(event, context)]
)
)
return event.event_id