From c836cb988ec61b12b7cccc92776e1c473c589151 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Sat, 13 Jan 2024 11:18:48 +0000 Subject: [PATCH] Better event stream typing --- synapse/events/utils.py | 14 ++----- synapse/handlers/events.py | 77 ++++++++++++++++++++------------------ synapse/notifier.py | 14 ++++--- 3 files changed, 52 insertions(+), 53 deletions(-) diff --git a/synapse/events/utils.py b/synapse/events/utils.py index cb7ebc31e7..1ccb63c7be 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -404,7 +404,7 @@ _DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig() def serialize_event( - e: Union[JsonDict, EventBase], + e: EventBase, time_now_ms: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, @@ -420,10 +420,6 @@ def serialize_event( The serialized event dictionary. """ - # FIXME(erikj): To handle the case of presence events and the like - if not isinstance(e, EventBase): - return e - time_now_ms = int(time_now_ms) # Should this strip out None's? @@ -531,7 +527,7 @@ class EventClientSerializer: async def serialize_event( self, - event: Union[JsonDict, EventBase], + event: EventBase, time_now: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, @@ -549,10 +545,6 @@ class EventClientSerializer: Returns: The serialized event """ - # To handle the case of presence events and the like - if not isinstance(event, EventBase): - return event - serialized_event = serialize_event(event, time_now, config=config) new_unsigned = {} @@ -656,7 +648,7 @@ class EventClientSerializer: async def serialize_events( self, - events: Iterable[Union[JsonDict, EventBase]], + events: Iterable[EventBase], time_now: int, *, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index 36404d9c78..aa4d3f2e9e 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -20,7 +20,7 @@ import logging import random -from typing import TYPE_CHECKING, Iterable, List, Optional +from typing import TYPE_CHECKING, Iterable, List, Optional, cast from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.errors import AuthError, SynapseError @@ -29,7 +29,7 @@ from synapse.events.utils import SerializeEventConfig from synapse.handlers.presence import format_user_presence_state from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, UserID +from synapse.types import JsonDict, Requester, StreamKeyType, UserID from synapse.visibility import filter_events_for_client if TYPE_CHECKING: @@ -93,49 +93,54 @@ class EventStreamHandler: is_guest=requester.is_guest, explicit_room_id=room_id, ) - events = stream_result.events + events_by_source = stream_result.events_by_source time_now = self.clock.time_msec() # When the user joins a new room, or another user joins a currently # joined room, we need to send down presence for those users. - to_add: List[JsonDict] = [] - for event in events: - if not isinstance(event, EventBase): + to_return: List[JsonDict] = [] + for keyname, source_events in events_by_source.items(): + if keyname != StreamKeyType.ROOM: + e = cast(List[JsonDict], source_events) + to_return.extend(e) continue - if event.type == EventTypes.Member: - if event.membership != Membership.JOIN: - continue - # Send down presence. - if event.state_key == requester.user.to_string(): - # Send down presence for everyone in the room. - users: Iterable[str] = await self.store.get_users_in_room( - event.room_id + + events = cast(List[EventBase], source_events) + + serialized_events = await self._event_serializer.serialize_events( + events, + time_now, + config=SerializeEventConfig( + as_client_event=as_client_event, requester=requester + ), + ) + to_return.extend(serialized_events) + + for event in events: + if event.type == EventTypes.Member: + if event.membership != Membership.JOIN: + continue + # Send down presence. + if event.state_key == requester.user.to_string(): + # Send down presence for everyone in the room. + users: Iterable[str] = await self.store.get_users_in_room( + event.room_id + ) + else: + users = [event.state_key] + + states = await presence_handler.get_states(users) + to_return.extend( + { + "type": EduTypes.PRESENCE, + "content": format_user_presence_state(state, time_now), + } + for state in states ) - else: - users = [event.state_key] - - states = await presence_handler.get_states(users) - to_add.extend( - { - "type": EduTypes.PRESENCE, - "content": format_user_presence_state(state, time_now), - } - for state in states - ) - - events.extend(to_add) - - chunks = await self._event_serializer.serialize_events( - events, - time_now, - config=SerializeEventConfig( - as_client_event=as_client_event, requester=requester - ), - ) chunk = { - "chunk": chunks, + "chunk": to_return, "start": await stream_result.start_token.to_string(self.store), "end": await stream_result.end_token.to_string(self.store), } diff --git a/synapse/notifier.py b/synapse/notifier.py index dec47add7e..4213fd7c6f 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -198,12 +198,12 @@ class _NotifierUserStream: @attr.s(slots=True, frozen=True, auto_attribs=True) class EventStreamResult: - events: List[Union[JsonDict, EventBase]] + events_by_source: Dict[StreamKeyType, List[Union[JsonDict, EventBase]]] start_token: StreamToken end_token: StreamToken def __bool__(self) -> bool: - return bool(self.events) + return any(bool(e) for e in self.events_by_source.values()) @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -694,12 +694,12 @@ class Notifier: before_token: StreamToken, after_token: StreamToken ) -> EventStreamResult: if after_token == before_token: - return EventStreamResult([], from_token, from_token) + return EventStreamResult({}, from_token, from_token) # The events fetched from each source are a JsonDict, EventBase, or # UserPresenceState, but see below for UserPresenceState being # converted to JsonDict. - events: List[Union[JsonDict, EventBase]] = [] + events_by_source: Dict[StreamKeyType, List[Union[JsonDict, EventBase]]] = {} end_token = from_token for keyname, source in self.event_sources.sources.get_sources(): @@ -734,10 +734,12 @@ class Notifier: for event in new_events ] - events.extend(new_events) + if new_events: + events_by_source.setdefault(keyname, []).extend(new_events) + end_token = end_token.copy_and_replace(keyname, new_key) - return EventStreamResult(events, from_token, end_token) + return EventStreamResult(events_by_source, from_token, end_token) user_id_for_stream = user.to_string() if is_peeking: