Add optimisation to StreamChangeCache (#17130)

When there have been lots of changes compared with the number of
entities, we can do a fast(er) path.

Locally I ran some benchmarking, and the comparison seems to give the
best determination of which method we use.
This commit is contained in:
Erik Johnston 2024-05-06 12:56:52 +01:00 committed by GitHub
parent 7c9ac01eb5
commit 3e6ee8ff88
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 34 additions and 4 deletions

1
changelog.d/17130.misc Normal file
View file

@ -0,0 +1 @@
Add optimisation to `StreamChangeCache.get_entities_changed(..)`.

View file

@ -165,7 +165,7 @@ class StreamChangeCache:
return False return False
def get_entities_changed( def get_entities_changed(
self, entities: Collection[EntityType], stream_pos: int self, entities: Collection[EntityType], stream_pos: int, _perf_factor: int = 1
) -> Union[Set[EntityType], FrozenSet[EntityType]]: ) -> Union[Set[EntityType], FrozenSet[EntityType]]:
""" """
Returns the subset of the given entities that have had changes after the given position. Returns the subset of the given entities that have had changes after the given position.
@ -177,6 +177,8 @@ class StreamChangeCache:
Args: Args:
entities: Entities to check for changes. entities: Entities to check for changes.
stream_pos: The stream position to check for changes after. stream_pos: The stream position to check for changes after.
_perf_factor: Used by unit tests to choose when to use each
optimisation.
Return: Return:
A subset of entities which have changed after the given stream position. A subset of entities which have changed after the given stream position.
@ -184,6 +186,22 @@ class StreamChangeCache:
This will be all entities if the given stream position is at or earlier This will be all entities if the given stream position is at or earlier
than the earliest known stream position. than the earliest known stream position.
""" """
if not self._cache or stream_pos <= self._earliest_known_stream_pos:
self.metrics.inc_misses()
return set(entities)
# If there have been tonnes of changes compared with the number of
# entities, it is faster to check each entities stream ordering
# one-by-one.
max_stream_pos, _ = self._cache.peekitem()
if max_stream_pos - stream_pos > _perf_factor * len(entities):
self.metrics.inc_hits()
return {
entity
for entity in entities
if self._entity_to_key.get(entity, -1) > stream_pos
}
cache_result = self.get_all_entities_changed(stream_pos) cache_result = self.get_all_entities_changed(stream_pos)
if cache_result.hit: if cache_result.hit:
# We now do an intersection, trying to do so in the most efficient # We now do an intersection, trying to do so in the most efficient

View file

@ -1,3 +1,5 @@
from parameterized import parameterized
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests import unittest from tests import unittest
@ -161,7 +163,8 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertFalse(cache.has_any_entity_changed(2)) self.assertFalse(cache.has_any_entity_changed(2))
self.assertFalse(cache.has_any_entity_changed(3)) self.assertFalse(cache.has_any_entity_changed(3))
def test_get_entities_changed(self) -> None: @parameterized.expand([(0,), (1000000000,)])
def test_get_entities_changed(self, perf_factor: int) -> None:
""" """
StreamChangeCache.get_entities_changed will return the entities in the StreamChangeCache.get_entities_changed will return the entities in the
given list that have changed since the provided stream ID. If the given list that have changed since the provided stream ID. If the
@ -178,7 +181,9 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# get the ones after that point. # get the ones after that point.
self.assertEqual( self.assertEqual(
cache.get_entities_changed( cache.get_entities_changed(
["user@foo.com", "bar@baz.net", "user@elsewhere.org"], stream_pos=2 ["user@foo.com", "bar@baz.net", "user@elsewhere.org"],
stream_pos=2,
_perf_factor=perf_factor,
), ),
{"bar@baz.net", "user@elsewhere.org"}, {"bar@baz.net", "user@elsewhere.org"},
) )
@ -195,6 +200,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"not@here.website", "not@here.website",
], ],
stream_pos=2, stream_pos=2,
_perf_factor=perf_factor,
), ),
{"bar@baz.net", "user@elsewhere.org"}, {"bar@baz.net", "user@elsewhere.org"},
) )
@ -210,6 +216,7 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
"not@here.website", "not@here.website",
], ],
stream_pos=0, stream_pos=0,
_perf_factor=perf_factor,
), ),
{"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"}, {"user@foo.com", "bar@baz.net", "user@elsewhere.org", "not@here.website"},
) )
@ -217,7 +224,11 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
# Query a subset of the entries mid-way through the stream. We should # Query a subset of the entries mid-way through the stream. We should
# only get back the subset. # only get back the subset.
self.assertEqual( self.assertEqual(
cache.get_entities_changed(["bar@baz.net"], stream_pos=2), cache.get_entities_changed(
["bar@baz.net"],
stream_pos=2,
_perf_factor=perf_factor,
),
{"bar@baz.net"}, {"bar@baz.net"},
) )