Better event stream typing

This commit is contained in:
Erik Johnston 2024-01-13 11:18:48 +00:00
parent 69637f8bac
commit c836cb988e
3 changed files with 52 additions and 53 deletions

View file

@ -404,7 +404,7 @@ _DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig()
def serialize_event(
e: Union[JsonDict, EventBase],
e: EventBase,
time_now_ms: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
@ -420,10 +420,6 @@ def serialize_event(
The serialized event dictionary.
"""
# FIXME(erikj): To handle the case of presence events and the like
if not isinstance(e, EventBase):
return e
time_now_ms = int(time_now_ms)
# Should this strip out None's?
@ -531,7 +527,7 @@ class EventClientSerializer:
async def serialize_event(
self,
event: Union[JsonDict, EventBase],
event: EventBase,
time_now: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
@ -549,10 +545,6 @@ class EventClientSerializer:
Returns:
The serialized event
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
return event
serialized_event = serialize_event(event, time_now, config=config)
new_unsigned = {}
@ -656,7 +648,7 @@ class EventClientSerializer:
async def serialize_events(
self,
events: Iterable[Union[JsonDict, EventBase]],
events: Iterable[EventBase],
time_now: int,
*,
config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,

View file

@ -20,7 +20,7 @@
import logging
import random
from typing import TYPE_CHECKING, Iterable, List, Optional
from typing import TYPE_CHECKING, Iterable, List, Optional, cast
from synapse.api.constants import EduTypes, EventTypes, Membership, PresenceState
from synapse.api.errors import AuthError, SynapseError
@ -29,7 +29,7 @@ from synapse.events.utils import SerializeEventConfig
from synapse.handlers.presence import format_user_presence_state
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, UserID
from synapse.types import JsonDict, Requester, StreamKeyType, UserID
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
@ -93,49 +93,54 @@ class EventStreamHandler:
is_guest=requester.is_guest,
explicit_room_id=room_id,
)
events = stream_result.events
events_by_source = stream_result.events_by_source
time_now = self.clock.time_msec()
# When the user joins a new room, or another user joins a currently
# joined room, we need to send down presence for those users.
to_add: List[JsonDict] = []
for event in events:
if not isinstance(event, EventBase):
to_return: List[JsonDict] = []
for keyname, source_events in events_by_source.items():
if keyname != StreamKeyType.ROOM:
e = cast(List[JsonDict], source_events)
to_return.extend(e)
continue
if event.type == EventTypes.Member:
if event.membership != Membership.JOIN:
continue
# Send down presence.
if event.state_key == requester.user.to_string():
# Send down presence for everyone in the room.
users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
events = cast(List[EventBase], source_events)
serialized_events = await self._event_serializer.serialize_events(
events,
time_now,
config=SerializeEventConfig(
as_client_event=as_client_event, requester=requester
),
)
to_return.extend(serialized_events)
for event in events:
if event.type == EventTypes.Member:
if event.membership != Membership.JOIN:
continue
# Send down presence.
if event.state_key == requester.user.to_string():
# Send down presence for everyone in the room.
users: Iterable[str] = await self.store.get_users_in_room(
event.room_id
)
else:
users = [event.state_key]
states = await presence_handler.get_states(users)
to_return.extend(
{
"type": EduTypes.PRESENCE,
"content": format_user_presence_state(state, time_now),
}
for state in states
)
else:
users = [event.state_key]
states = await presence_handler.get_states(users)
to_add.extend(
{
"type": EduTypes.PRESENCE,
"content": format_user_presence_state(state, time_now),
}
for state in states
)
events.extend(to_add)
chunks = await self._event_serializer.serialize_events(
events,
time_now,
config=SerializeEventConfig(
as_client_event=as_client_event, requester=requester
),
)
chunk = {
"chunk": chunks,
"chunk": to_return,
"start": await stream_result.start_token.to_string(self.store),
"end": await stream_result.end_token.to_string(self.store),
}

View file

@ -198,12 +198,12 @@ class _NotifierUserStream:
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventStreamResult:
events: List[Union[JsonDict, EventBase]]
events_by_source: Dict[StreamKeyType, List[Union[JsonDict, EventBase]]]
start_token: StreamToken
end_token: StreamToken
def __bool__(self) -> bool:
return bool(self.events)
return any(bool(e) for e in self.events_by_source.values())
@attr.s(slots=True, frozen=True, auto_attribs=True)
@ -694,12 +694,12 @@ class Notifier:
before_token: StreamToken, after_token: StreamToken
) -> EventStreamResult:
if after_token == before_token:
return EventStreamResult([], from_token, from_token)
return EventStreamResult({}, from_token, from_token)
# The events fetched from each source are a JsonDict, EventBase, or
# UserPresenceState, but see below for UserPresenceState being
# converted to JsonDict.
events: List[Union[JsonDict, EventBase]] = []
events_by_source: Dict[StreamKeyType, List[Union[JsonDict, EventBase]]] = {}
end_token = from_token
for keyname, source in self.event_sources.sources.get_sources():
@ -734,10 +734,12 @@ class Notifier:
for event in new_events
]
events.extend(new_events)
if new_events:
events_by_source.setdefault(keyname, []).extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
return EventStreamResult(events, from_token, end_token)
return EventStreamResult(events_by_source, from_token, end_token)
user_id_for_stream = user.to_string()
if is_peeking: