Speed up fetching device list changes in sync.

Currently we copy `users_who_share_room` needlessly about three times,
which is expensive when the set is large (which it can easily be).
This commit is contained in:
Erik Johnston 2020-05-05 17:07:59 +01:00
parent b2dba06079
commit f9073893af
3 changed files with 24 additions and 9 deletions

View file

@ -1143,10 +1143,14 @@ class SyncHandler(object):
user_id user_id
) )
tracked_users = set(users_who_share_room) # Always tell the user about their own devices. We check as the user
# ID is almost certainly already included (unless they're not in any
# rooms) and taking a copy of the set is relatively expensive.
if user_id not in users_who_share_room:
users_who_share_room = set(users_who_share_room)
users_who_share_room.add(user_id)
# Always tell the user about their own devices tracked_users = users_who_share_room
tracked_users.add(user_id)
# Step 1a, check for changes in devices of users we share a room with # Step 1a, check for changes in devices of users we share a room with
users_that_have_changed = await self.store.get_users_whose_devices_changed( users_that_have_changed = await self.store.get_users_whose_devices_changed(

View file

@ -541,8 +541,8 @@ class DeviceWorkerStore(SQLBaseStore):
# Get set of users who *may* have changed. Users not in the returned # Get set of users who *may* have changed. Users not in the returned
# list have definitely not changed. # list have definitely not changed.
to_check = list( to_check = self._device_list_stream_cache.get_entities_changed(
self._device_list_stream_cache.get_entities_changed(user_ids, from_key) user_ids, from_key
) )
if not to_check: if not to_check:

View file

@ -14,12 +14,13 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict, Iterable, List, Mapping, Optional, Set from typing import Dict, FrozenSet, List, Mapping, Optional, Set, Union
from six import integer_types from six import integer_types
from sortedcontainers import SortedDict from sortedcontainers import SortedDict
from synapse.types import Collection
from synapse.util import caches from synapse.util import caches
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -85,8 +86,8 @@ class StreamChangeCache:
return False return False
def get_entities_changed( def get_entities_changed(
self, entities: Iterable[EntityType], stream_pos: int self, entities: Collection[EntityType], stream_pos: int
) -> Set[EntityType]: ) -> Union[Set[EntityType], FrozenSet[EntityType]]:
""" """
Returns subset of entities that have had new things since the given Returns subset of entities that have had new things since the given
position. Entities unknown to the cache will be returned. If the position. Entities unknown to the cache will be returned. If the
@ -94,7 +95,17 @@ class StreamChangeCache:
""" """
changed_entities = self.get_all_entities_changed(stream_pos) changed_entities = self.get_all_entities_changed(stream_pos)
if changed_entities is not None: if changed_entities is not None:
# We now do an intersection, trying to do so in the most efficient
# way possible (some of these sets are *large*). First check if the
# given iterable is already a set that we can reuse, otherwise we
# create a set of the *smaller* of the two iterables and call
# `intersection(..)` on it (this can be twice as fast as the reverse).
if isinstance(entities, (set, frozenset)):
result = entities.intersection(changed_entities)
elif len(changed_entities) < len(entities):
result = set(changed_entities).intersection(entities) result = set(changed_entities).intersection(entities)
else:
result = set(entities).intersection(changed_entities)
self.metrics.inc_hits() self.metrics.inc_hits()
else: else:
result = set(entities) result = set(entities)