This commit is contained in:
Erik Johnston 2024-06-14 16:07:32 +01:00
parent 2813522704
commit d84b438fc6
3 changed files with 37 additions and 30 deletions

View file

@ -114,7 +114,7 @@ class ReplicationDataHandler:
""" """
all_room_ids: Set[str] = set() all_room_ids: Set[str] = set()
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
if any(row.entity.startswith("@") and not row.is_signature for row in rows): if any(not row.is_signature and not row.hosts_calculated for row in rows):
prev_token = self.store.get_device_stream_token() prev_token = self.store.get_device_stream_token()
all_room_ids = await self.store.get_all_device_list_changes( all_room_ids = await self.store.get_all_device_list_changes(
prev_token, token prev_token, token
@ -149,7 +149,7 @@ class ReplicationDataHandler:
) )
await self._pusher_pool.on_new_receipts({row.user_id for row in rows}) await self._pusher_pool.on_new_receipts({row.user_id for row in rows})
elif stream_name == ToDeviceStream.NAME: elif stream_name == ToDeviceStream.NAME:
entities = [row.entity for row in rows if row.entity.startswith("@")] entities = [row.user_id for row in rows if not row.hosts_calculated]
if entities: if entities:
self.notifier.on_new_event( self.notifier.on_new_event(
StreamKeyType.TO_DEVICE, token, users=entities StreamKeyType.TO_DEVICE, token, users=entities
@ -433,11 +433,7 @@ class FederationSenderHandler:
# The entities are either user IDs (starting with '@') whose devices # The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about # have changed, or remote servers that we need to tell about
# changes. # changes.
hosts = { hosts = await self.store.get_destinations_for_device(token)
row.entity
for row in rows
if not row.entity.startswith("@") and not row.is_signature
}
await self.federation_sender.send_device_messages(hosts, immediate=False) await self.federation_sender.send_device_messages(hosts, immediate=False)
elif stream_name == ToDeviceStream.NAME: elif stream_name == ToDeviceStream.NAME:

View file

@ -207,22 +207,25 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
) -> None: ) -> None:
for row in rows: for row in rows:
if row.is_signature: if row.is_signature:
self._user_signature_stream_cache.entity_has_changed(row.entity, token) self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
continue continue
# The entities are either user IDs (starting with '@') whose devices # The entities are either user IDs (starting with '@') whose devices
# have changed, or remote servers that we need to tell about # have changed, or remote servers that we need to tell about
# changes. # changes.
if row.entity.startswith("@"): if not row.hosts_calculated:
self._device_list_stream_cache.entity_has_changed(row.entity, token) self._device_list_stream_cache.entity_has_changed(row.user_id, token)
self.get_cached_devices_for_user.invalidate((row.entity,)) self.get_cached_devices_for_user.invalidate((row.user_id,))
self._get_cached_user_device.invalidate((row.entity,)) self._get_cached_user_device.invalidate((row.user_id,))
self.get_device_list_last_stream_id_for_remote.invalidate((row.entity,)) self.get_device_list_last_stream_id_for_remote.invalidate(
(row.user_id,)
)
else: else:
self._device_list_federation_stream_cache.entity_has_changed( # self._device_list_federation_stream_cache.entity_has_changed(
row.entity, token # row.entity, token
) # )
pass
def device_lists_in_rooms_have_changed( def device_lists_in_rooms_have_changed(
self, room_ids: StrCollection, token: int self, room_ids: StrCollection, token: int
@ -364,18 +367,18 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
""" """
now_stream_id = self.get_device_stream_token() now_stream_id = self.get_device_stream_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed( # has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id) # destination, int(from_stream_id)
) # )
if not has_changed: # if not has_changed:
# debugging for https://github.com/matrix-org/synapse/issues/14251 # # debugging for https://github.com/matrix-org/synapse/issues/14251
issue_8631_logger.debug( # issue_8631_logger.debug(
"%s: no change between %i and %i", # "%s: no change between %i and %i",
destination, # destination,
from_stream_id, # from_stream_id,
now_stream_id, # now_stream_id,
) # )
return now_stream_id, [] # return now_stream_id, []
updates = await self.db_pool.runInteraction( updates = await self.db_pool.runInteraction(
"get_device_updates_by_remote", "get_device_updates_by_remote",
@ -1577,6 +1580,14 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
get_device_list_changes_in_room_txn, get_device_list_changes_in_room_txn,
) )
async def get_destinations_for_device(self, stream_id: int) -> StrCollection:
return await self.db_pool.simple_select_onecol(
table="device_lists_outbound_pokes",
keyvalues={"stream_id": stream_id},
retcol="destination",
desc="get_destinations_for_device",
)
class DeviceBackgroundUpdateStore(SQLBaseStore): class DeviceBackgroundUpdateStore(SQLBaseStore):
def __init__( def __init__(

View file

@ -123,9 +123,9 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
if stream_name == DeviceListsStream.NAME: if stream_name == DeviceListsStream.NAME:
for row in rows: for row in rows:
assert isinstance(row, DeviceListsStream.DeviceListsStreamRow) assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
if row.entity.startswith("@"): if not row.hosts_calculated:
self._get_e2e_device_keys_for_federation_query_inner.invalidate( self._get_e2e_device_keys_for_federation_query_inner.invalidate(
(row.entity,) (row.user_id,)
) )
super().process_replication_rows(stream_name, instance_name, token, rows) super().process_replication_rows(stream_name, instance_name, token, rows)