Stream change cache

This commit is contained in:
Erik Johnston 2024-06-19 16:22:54 +01:00
parent a45f1be28c
commit ce7b1d3a21
2 changed files with 55 additions and 39 deletions

View file

@ -121,6 +121,12 @@ class ReplicationDataHandler:
)
self.store.device_lists_in_rooms_have_changed(all_room_ids, token)
# If we're sending federation we need to update the device lists
# outbound pokes stream change cache with updated hosts.
if self.send_handler and any(row.hosts_calculated for row in rows):
hosts = await self.store.get_destinations_for_device(token)
self.store.device_lists_outbound_pokes_have_changed(hosts, token)
self.store.process_replication_rows(stream_name, instance_name, token, rows)
# NOTE: this must be called after process_replication_rows to ensure any
# cache invalidations are first handled before any stream ID advances.

View file

@ -164,22 +164,24 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
prefilled_cache=user_signature_stream_prefill,
)
(
device_list_federation_prefill,
device_list_federation_list_id,
) = self.db_pool.get_cache_dict(
db_conn,
"device_lists_outbound_pokes",
entity_column="destination",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache",
device_list_federation_list_id,
prefilled_cache=device_list_federation_prefill,
)
self._device_list_federation_stream_cache = None
if hs.get_federation_sender() is not None:
(
device_list_federation_prefill,
device_list_federation_list_id,
) = self.db_pool.get_cache_dict(
db_conn,
"device_lists_outbound_pokes",
entity_column="destination",
stream_column="stream_id",
max_value=device_list_max,
limit=10000,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache",
device_list_federation_list_id,
prefilled_cache=device_list_federation_prefill,
)
if hs.config.worker.run_background_tasks:
self._clock.looping_call(
@ -221,11 +223,15 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
(row.user_id,)
)
else:
# self._device_list_federation_stream_cache.entity_has_changed(
# row.entity, token
# )
pass
def device_lists_outbound_pokes_have_changed(
self, destinations: StrCollection, token: int
) -> None:
assert self._device_list_federation_stream_cache is not None
for destination in destinations:
self._device_list_federation_stream_cache.entity_has_changed(
destination, token
)
def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int
@ -369,18 +375,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
if from_stream_id == now_stream_id:
return now_stream_id, []
# has_changed = self._device_list_federation_stream_cache.has_entity_changed(
# destination, int(from_stream_id)
# )
# if not has_changed:
# # debugging for https://github.com/matrix-org/synapse/issues/14251
# issue_8631_logger.debug(
# "%s: no change between %i and %i",
# destination,
# from_stream_id,
# now_stream_id,
# )
# return now_stream_id, []
if self._device_list_federation_stream_cache is None:
raise Exception("Func can only be used on federation senders")
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
# debugging for https://github.com/matrix-org/synapse/issues/14251
issue_8631_logger.debug(
"%s: no change between %i and %i",
destination,
from_stream_id,
now_stream_id,
)
return now_stream_id, []
updates = await self.db_pool.runInteraction(
"get_device_updates_by_remote",
@ -2125,12 +2134,13 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
stream_ids: List[int],
context: Optional[Dict[str, str]],
) -> None:
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)
if self._device_list_federation_stream_cache:
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host,
stream_ids[-1],
)
now = self._clock.time_msec()
stream_id_iterator = iter(stream_ids)