Move get_state_at() to area we can share from

This commit is contained in:
Eric Eastwood 2024-06-10 14:24:47 -05:00
parent 61f86e0d39
commit 578b44af4c
2 changed files with 92 additions and 90 deletions

View file

@ -981,89 +981,6 @@ class SyncHandler:
bundled_aggregations=bundled_aggregations,
)
async def get_state_after_event(
self,
event_id: str,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the room state after the given event
Args:
event_id: event of interest
state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the event and `state_filter` is not satisfied by partial state.
Defaults to `True`.
"""
state_ids = await self._state_storage_controller.get_state_ids_for_event(
event_id,
state_filter=state_filter or StateFilter.all(),
await_full_state=await_full_state,
)
# using get_metadata_for_events here (instead of get_event) sidesteps an issue
# with redactions: if `event_id` is a redaction event, and we don't have the
# original (possibly because it got purged), get_event will refuse to return
# the redaction event, which isn't terribly helpful here.
#
# (To be fair, in that case we could assume it's *not* a state event, and
# therefore we don't need to worry about it. But still, it seems cleaner just
# to pull the metadata.)
m = (await self.store.get_metadata_for_events([event_id]))[event_id]
if m.state_key is not None and m.rejection_reason is None:
state_ids = dict(state_ids)
state_ids[(m.event_type, m.state_key)] = event_id
return state_ids
async def get_state_at(
self,
room_id: str,
stream_position: StreamToken,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
"""Get the room state at a particular stream position
Args:
room_id: room for which to get state
stream_position: point at which to get state
state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the last event in the room before `stream_position` and
`state_filter` is not satisfied by partial state. Defaults to `True`.
"""
# FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time
# of the stream token if there were multiple forward extremities at the time.
last_event_id = await self.store.get_last_event_in_room_before_stream_ordering(
room_id,
end_token=stream_position.room_key,
)
if last_event_id:
state = await self.get_state_after_event(
last_event_id,
state_filter=state_filter or StateFilter.all(),
await_full_state=await_full_state,
)
else:
# no events in this room - so presumably no state
state = {}
# (erikj) This should be rarely hit, but we've had some reports that
# we get more state down gappy syncs than we should, so let's add
# some logging.
logger.info(
"Failed to find any events in room %s at %s",
room_id,
stream_position.room_key,
)
return state
async def compute_summary(
self,
room_id: str,
@ -1437,7 +1354,7 @@ class SyncHandler:
await_full_state = True
lazy_load_members = False
state_at_timeline_end = await self.get_state_at(
state_at_timeline_end = await self._state_storage_controller.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
@ -1565,7 +1482,7 @@ class SyncHandler:
else:
# We can get here if the user has ignored the senders of all
# the recent events.
state_at_timeline_start = await self.get_state_at(
state_at_timeline_start = await self._state_storage_controller.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
@ -1587,14 +1504,14 @@ class SyncHandler:
# about them).
state_filter = StateFilter.all()
state_at_previous_sync = await self.get_state_at(
state_at_previous_sync = await self._state_storage_controller.get_state_at(
room_id,
stream_position=since_token,
state_filter=state_filter,
await_full_state=await_full_state,
)
state_at_timeline_end = await self.get_state_at(
state_at_timeline_end = await self._state_storage_controller.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
@ -2593,7 +2510,7 @@ class SyncHandler:
continue
if room_id in sync_result_builder.joined_room_ids or has_join:
old_state_ids = await self.get_state_at(
old_state_ids = await self._state_storage_controller.get_state_at(
room_id,
since_token,
state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
@ -2623,7 +2540,7 @@ class SyncHandler:
newly_left_rooms.append(room_id)
else:
if not old_state_ids:
old_state_ids = await self.get_state_at(
old_state_ids = await self._state_storage_controller.get_state_at(
room_id,
since_token,
state_filter=StateFilter.from_types(

View file

@ -45,7 +45,7 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialStateEventsTracker,
)
from synapse.synapse_rust.acl import ServerAclEvaluator
from synapse.types import MutableStateMap, StateMap, get_domain_from_id
from synapse.types import MutableStateMap, StreamToken, StateMap, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
@ -372,6 +372,91 @@ class StateStorageController:
)
return state_map[event_id]
async def get_state_after_event(
self,
event_id: str,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the room state after the given event
Args:
event_id: event of interest
state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the event and `state_filter` is not satisfied by partial state.
Defaults to `True`.
"""
state_ids = await self.get_state_ids_for_event(
event_id,
state_filter=state_filter or StateFilter.all(),
await_full_state=await_full_state,
)
# using get_metadata_for_events here (instead of get_event) sidesteps an issue
# with redactions: if `event_id` is a redaction event, and we don't have the
# original (possibly because it got purged), get_event will refuse to return
# the redaction event, which isn't terribly helpful here.
#
# (To be fair, in that case we could assume it's *not* a state event, and
# therefore we don't need to worry about it. But still, it seems cleaner just
# to pull the metadata.)
m = (await self.stores.main.get_metadata_for_events([event_id]))[event_id]
if m.state_key is not None and m.rejection_reason is None:
state_ids = dict(state_ids)
state_ids[(m.event_type, m.state_key)] = event_id
return state_ids
async def get_state_at(
self,
room_id: str,
stream_position: StreamToken,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]:
"""Get the room state at a particular stream position
Args:
room_id: room for which to get state
stream_position: point at which to get state
state_filter: The state filter used to fetch state from the database.
await_full_state: if `True`, will block if we do not yet have complete state
at the last event in the room before `stream_position` and
`state_filter` is not satisfied by partial state. Defaults to `True`.
"""
# FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time
# of the stream token if there were multiple forward extremities at the time.
last_event_id = (
await self.stores.main.get_last_event_in_room_before_stream_ordering(
room_id,
end_token=stream_position.room_key,
)
)
if last_event_id:
state = await self.get_state_after_event(
last_event_id,
state_filter=state_filter or StateFilter.all(),
await_full_state=await_full_state,
)
else:
# no events in this room - so presumably no state
state = {}
# (erikj) This should be rarely hit, but we've had some reports that
# we get more state down gappy syncs than we should, so let's add
# some logging.
logger.info(
"Failed to find any events in room %s at %s",
room_id,
stream_position.room_key,
)
return state
@trace
@tag_args
async def get_state_for_groups(