add fucntions to persist events as a batch, encapsulate some logic in a helper function

This commit is contained in:
H. Shay 2022-09-13 10:11:11 -07:00
parent d950d99ab5
commit 99aa2136c2

View file

@ -61,6 +61,7 @@ from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import (
MutableStateMap,
PersistedEventPosition,
Requester,
RoomAlias,
StateMap,
@ -1316,6 +1317,124 @@ class EventCreationHandler:
400, "Cannot start threads from an event with a relation"
)
async def handle_create_room_events(
self,
requester: Requester,
events_and_ctx: List[Tuple[EventBase, EventContext]],
ratelimit: bool = True,
) -> EventBase:
"""
Process a batch of room creation events. For each event in the list it checks
the authorization and that the event can be serialized. Returns the last event in the
list once it has been persisted.
Args:
requester: the room creator
events_and_ctx: a set of events and their associated contexts to persist
ratelimit: whether to ratelimit this request
"""
for event, context in events_and_ctx:
try:
validate_event_for_room_version(event)
await self._event_auth_handler.check_auth_rules_from_context(
event, context
)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
# Ensure that we can round trip before trying to persist in db
try:
dump = json_encoder.encode(event.content)
json_decoder.decode(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
raise
# We now persist the events
try:
result = await self._persist_events_batch(
requester, events_and_ctx, ratelimit
)
except Exception as e:
logger.info(f"Encountered an error persisting events: {e}")
return result
async def _persist_events_batch(
self,
requestor: Requester,
events_and_ctx: List[Tuple[EventBase, EventContext]],
ratelimit: bool = True,
) -> EventBase:
"""
Processes the push actions and adds them to the push staging area before attempting to
persist the batch of events.
See handle_create_room_events for arguments
Returns the last event in the list if persisted successfully
"""
for event, context in events_and_ctx:
with opentracing.start_active_span("calculate_push_actions"):
await self._bulk_push_rule_evaluator.action_for_event_by_user(
event, context
)
try:
last_event = await self.persist_and_notify_batched_events(
requestor, events_and_ctx, ratelimit
)
except Exception:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
for event, _ in events_and_ctx:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
return last_event
async def persist_and_notify_batched_events(
self,
requester: Requester,
events_and_ctx: List[Tuple[EventBase, EventContext]],
ratelimit: bool = True,
) -> EventBase:
"""
Handles the actual persisting of a batch of events to the DB, and sends the appropriate
notifications when this is done.
Args:
requester: the room creator
events_and_ctx: list of events and their associated contexts to persist
ratelimit: whether to apply ratelimiting to this request
"""
if ratelimit:
await self.request_ratelimiter.ratelimit(requester)
for event, context in events_and_ctx:
await self._actions_by_event_type(event, context)
assert self._storage_controllers.persistence is not None
(
persisted_events,
max_stream_token,
) = await self._storage_controllers.persistence.persist_events(events_and_ctx)
stream_ordering = persisted_events[-1].internal_metadata.stream_ordering
assert stream_ordering is not None
pos = PersistedEventPosition(self._instance_name, stream_ordering)
async def _notify() -> None:
try:
await self.notifier.on_new_room_event(
persisted_events[-1], pos, max_stream_token
)
except Exception:
logger.exception(
"Error notifying about new room event %s",
event.event_id,
)
run_in_background(_notify)
return persisted_events[-1]
@measure_func("handle_new_client_event")
async def handle_new_client_event(
self,
@ -1650,6 +1769,55 @@ class EventCreationHandler:
requester, is_admin_redaction=is_admin_redaction
)
# run checks/actions on event based on type
await self._actions_by_event_type(event, context)
# Mark any `m.historical` messages as backfilled so they don't appear
# in `/sync` and have the proper decrementing `stream_ordering` as we import
backfilled = False
if event.internal_metadata.is_historical():
backfilled = True
# Note that this returns the event that was persisted, which may not be
# the same as we passed in if it was deduplicated due transaction IDs.
(
event,
event_pos,
max_stream_token,
) = await self._storage_controllers.persistence.persist_event(
event, context=context, backfilled=backfilled
)
if self._ephemeral_events_enabled:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
async def _notify() -> None:
try:
await self.notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
)
except Exception:
logger.exception(
"Error notifying about new room event %s",
event.event_id,
)
run_in_background(_notify)
if event.type == EventTypes.Message:
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
return event
async def _actions_by_event_type(
self, event: EventBase, context: EventContext
) -> None:
"""
Helper function to execute actions/checks based on the event type
"""
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
(
current_membership,
@ -1670,11 +1838,13 @@ class EventCreationHandler:
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
original_event = await self.store.get_event(original_event_id)
original_alias_event = await self.store.get_event(original_event_id)
if original_event:
original_alias = original_event.content.get("alias", None)
original_alt_aliases = original_event.content.get("alt_aliases", [])
if original_alias_event:
original_alias = original_alias_event.content.get("alias", None)
original_alt_aliases = original_alias_event.content.get(
"alt_aliases", []
)
# Check the alias is currently valid (if it has changed).
room_alias_str = event.content.get("alias", None)
@ -1852,46 +2022,6 @@ class EventCreationHandler:
errcode=Codes.INVALID_PARAM,
)
# Mark any `m.historical` messages as backfilled so they don't appear
# in `/sync` and have the proper decrementing `stream_ordering` as we import
backfilled = False
if event.internal_metadata.is_historical():
backfilled = True
# Note that this returns the event that was persisted, which may not be
# the same as we passed in if it was deduplicated due transaction IDs.
(
event,
event_pos,
max_stream_token,
) = await self._storage_controllers.persistence.persist_event(
event, context=context, backfilled=backfilled
)
if self._ephemeral_events_enabled:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
async def _notify() -> None:
try:
await self.notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
)
except Exception:
logger.exception(
"Error notifying about new room event %s",
event.event_id,
)
run_in_background(_notify)
if event.type == EventTypes.Message:
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
return event
async def _maybe_kick_guest_users(
self, event: EventBase, context: EventContext
) -> None: