Better event stream typing

This commit is contained in:
Erik Johnston 2024-01-13 11:18:48 +00:00
parent 69637f8bac
commit c836cb988e
3 changed files with 52 additions and 53 deletions

View file

@ -404,7 +404,7 @@ _DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig()
def serialize_event( def serialize_event(
e: Union[JsonDict, EventBase], e: EventBase,
time_now_ms: int, time_now_ms: int,
*, *,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
@ -420,10 +420,6 @@ def serialize_event(
The serialized event dictionary. The serialized event dictionary.
""" """
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
time_now_ms = int(time_now_ms) time_now_ms = int(time_now_ms)
# Should this strip out None's? # Should this strip out None's?
@ -531,7 +527,7 @@ class EventClientSerializer:
async def serialize_event( async def serialize_event(
self, self,
event: Union[JsonDict, EventBase], event: EventBase,
time_now: int, time_now: int,
*, *,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
@ -549,10 +545,6 @@ class EventClientSerializer:
Returns: Returns:
The serialized event The serialized event
""" """
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
return event
serialized_event = serialize_event(event, time_now, config=config) serialized_event = serialize_event(event, time_now, config=config)
new_unsigned = {} new_unsigned = {}
@ -656,7 +648,7 @@ class EventClientSerializer:
async def serialize_events( async def serialize_events(
self, self,
events: Iterable[Union[JsonDict, EventBase]], events: Iterable[EventBase],
time_now: int, time_now: int,
*, *,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG, config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,

View file

@ -20,7 +20,7 @@
import logging import logging
import random import random
from typing import TYPE_CHECKING, Iterable, List, Optional from typing import TYPE_CHECKING, Iterable, List, Optional, cast
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
from synapse.api.errors import AuthError, SynapseError from synapse.api.errors import AuthError, SynapseError
@ -29,7 +29,7 @@ from synapse.events.utils import SerializeEventConfig
from synapse.handlers.presence import format_user_presence_state from synapse.handlers.presence import format_user_presence_state
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.streams.config import PaginationConfig from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, UserID from synapse.types import JsonDict, Requester, StreamKeyType, UserID
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
if TYPE_CHECKING: if TYPE_CHECKING:
@ -93,49 +93,54 @@ class EventStreamHandler:
is_guest=requester.is_guest, is_guest=requester.is_guest,
explicit_room_id=room_id, explicit_room_id=room_id,
) )
events = stream_result.events events_by_source = stream_result.events_by_source
time_now = self.clock.time_msec() time_now = self.clock.time_msec()
# When the user joins a new room, or another user joins a currently # When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users. # joined room, we need to send down presence for those users.
to_add: List[JsonDict] = [] to_return: List[JsonDict] = []
for event in events: for keyname, source_events in events_by_source.items():
if not isinstance(event, EventBase): if keyname != StreamKeyType.ROOM:
e = cast(List[JsonDict], source_events)
to_return.extend(e)
continue continue
if event.type == EventTypes.Member:
if event.membership != Membership.JOIN: events = cast(List[EventBase], source_events)
continue
# Send down presence. serialized_events = await self._event_serializer.serialize_events(
if event.state_key == requester.user.to_string(): events,
# Send down presence for everyone in the room. time_now,
users: Iterable[str] = await self.store.get_users_in_room( config=SerializeEventConfig(
event.room_id as_client_event=as_client_event, requester=requester
),
)
to_return.extend(serialized_events)
for event in events:
if event.type == EventTypes.Member:
if event.membership != Membership.JOIN:
continue
# Send down presence.
if event.state_key == requester.user.to_string():
# Send down presence for everyone in the room.
users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
)
else:
users = [event.state_key]
states = await presence_handler.get_states(users)
to_return.extend(
{
"type": EduTypes.PRESENCE,
"content": format_user_presence_state(state, time_now),
}
for state in states
) )
else:
users = [event.state_key]
states = await presence_handler.get_states(users)
to_add.extend(
{
"type": EduTypes.PRESENCE,
"content": format_user_presence_state(state, time_now),
}
for state in states
)
events.extend(to_add)
chunks = await self._event_serializer.serialize_events(
events,
time_now,
config=SerializeEventConfig(
as_client_event=as_client_event, requester=requester
),
)
chunk = { chunk = {
"chunk": chunks, "chunk": to_return,
"start": await stream_result.start_token.to_string(self.store), "start": await stream_result.start_token.to_string(self.store),
"end": await stream_result.end_token.to_string(self.store), "end": await stream_result.end_token.to_string(self.store),
} }

View file

@ -198,12 +198,12 @@ class _NotifierUserStream:
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class EventStreamResult: class EventStreamResult:
events: List[Union[JsonDict, EventBase]] events_by_source: Dict[StreamKeyType, List[Union[JsonDict, EventBase]]]
start_token: StreamToken start_token: StreamToken
end_token: StreamToken end_token: StreamToken
def __bool__(self) -> bool: def __bool__(self) -> bool:
return bool(self.events) return any(bool(e) for e in self.events_by_source.values())
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
@ -694,12 +694,12 @@ class Notifier:
before_token: StreamToken, after_token: StreamToken before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult: ) -> EventStreamResult:
if after_token == before_token: if after_token == before_token:
return EventStreamResult([], from_token, from_token) return EventStreamResult({}, from_token, from_token)
# The events fetched from each source are a JsonDict, EventBase, or # The events fetched from each source are a JsonDict, EventBase, or
# UserPresenceState, but see below for UserPresenceState being # UserPresenceState, but see below for UserPresenceState being
# converted to JsonDict. # converted to JsonDict.
events: List[Union[JsonDict, EventBase]] = [] events_by_source: Dict[StreamKeyType, List[Union[JsonDict, EventBase]]] = {}
end_token = from_token end_token = from_token
for keyname, source in self.event_sources.sources.get_sources(): for keyname, source in self.event_sources.sources.get_sources():
@ -734,10 +734,12 @@ class Notifier:
for event in new_events for event in new_events
] ]
events.extend(new_events) if new_events:
events_by_source.setdefault(keyname, []).extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key) end_token = end_token.copy_and_replace(keyname, new_key)
return EventStreamResult(events, from_token, end_token) return EventStreamResult(events_by_source, from_token, end_token)
user_id_for_stream = user.to_string() user_id_for_stream = user.to_string()
if is_peeking: if is_peeking: