Fix bug where sync could get stuck when using workers (#17438)

This happened because we serialized the token incorrectly when the instance
map contained entries at or below the minimum stream position.
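The failure mode, roughly: when every per-writer position in the instance map sat at or below the minimum stream position, the serializer emitted a token of the form `m5~` (an "m" token with no per-instance entries), which the parser then rejected, so `/sync` kept failing for that user. A minimal, runnable sketch of why `m5~` defeats the old parsing logic (this mirrors the `parse()` code changed below, not the actual Synapse call path):

    string = "m5~"
    parts = string[1:].split("~")  # ["5", ""] - the trailing "~" leaves an empty part
    stream = int(parts[0])         # 5
    for part in parts[1:]:
        # Without the `if not part: continue` guard added in this commit, the
        # empty part has no "." to split on, so this raises ValueError and the
        # token is reported as invalid.
        key, value = part.split(".")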
Erik Johnston 2024-07-15 16:13:04 +01:00 committed by GitHub
parent d88ba45db9
commit df11af14db
4 changed files with 138 additions and 10 deletions

changelog.d/17438.bugfix (new file)

@@ -0,0 +1 @@
+Fix rare bug where `/sync` would break for a user when using workers with multiple stream writers.


@@ -699,10 +699,17 @@ class SlidingSyncHandler:
                 instance_to_max_stream_ordering_map[instance_name] = stream_ordering

         # Then assemble the `RoomStreamToken`
+        min_stream_pos = min(instance_to_max_stream_ordering_map.values())
         membership_snapshot_token = RoomStreamToken(
             # Minimum position in the `instance_map`
-            stream=min(instance_to_max_stream_ordering_map.values()),
-            instance_map=immutabledict(instance_to_max_stream_ordering_map),
+            stream=min_stream_pos,
+            instance_map=immutabledict(
+                {
+                    instance_name: stream_pos
+                    for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
+                    if stream_pos > min_stream_pos
+                }
+            ),
         )

         # Since we fetched the users room list at some point in time after the from/to
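The effect of this hunk: writers whose maximum stream ordering equals the overall minimum are no longer copied into `instance_map`, which keeps the token consistent with the invariant enforced in `synapse.types` below. A runnable sketch of the filtering, with hypothetical writer names and positions:

    instance_to_max_stream_ordering_map = {"writer1": 5, "writer2": 8}

    min_stream_pos = min(instance_to_max_stream_ordering_map.values())  # 5
    instance_map = {
        instance_name: stream_pos
        for instance_name, stream_pos in instance_to_max_stream_ordering_map.items()
        if stream_pos > min_stream_pos
    }
    # instance_map == {"writer2": 8}; "writer1" is dropped because its position
    # is already covered by the minimum stream position.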


@@ -20,6 +20,7 @@
 #
 #
 import abc
+import logging
 import re
 import string
 from enum import Enum
@@ -74,6 +75,9 @@ if TYPE_CHECKING:
     from synapse.storage.databases.main import DataStore, PurgeEventsStore
     from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore

+logger = logging.getLogger(__name__)
+
+
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
 T = TypeVar("T")
@@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
     represented by a default `stream` attribute and a map of instance name to
     stream position of any writers that are ahead of the default stream
     position.
+
+    The values in `instance_map` must be greater than the `stream` attribute.
     """

     stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True)
@@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
         kw_only=True,
     )

+    def __attrs_post_init__(self) -> None:
+        # Enforce that all instances have a value greater than the min stream
+        # position.
+        for i, v in self.instance_map.items():
+            if v <= self.stream:
+                raise ValueError(
+                    f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}"
+                )
+
     @classmethod
     @abc.abstractmethod
     async def parse(cls, store: "DataStore", string: str) -> "Self":
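A quick sketch of the new invariant in practice (the writer name and positions are illustrative, not taken from the real code):

    from immutabledict import immutabledict
    from synapse.types import RoomStreamToken

    # Fine: the entry is strictly ahead of the minimum stream position.
    RoomStreamToken(stream=5, instance_map=immutabledict({"writer1": 6}))

    # Rejected: an entry at (or below) `stream` carries no extra information, and
    # is exactly the shape that used to serialize to the malformed `m5~` form.
    RoomStreamToken(stream=5, instance_map=immutabledict({"writer1": 5}))  # raises ValueError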
@@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
             for instance in set(self.instance_map).union(other.instance_map)
         }

+        # Filter out any redundant entries.
+        instance_map = {i: s for i, s in instance_map.items() if s > max_stream}
+
         return attr.evolve(
             self, stream=max_stream, instance_map=immutabledict(instance_map)
         )
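With redundant entries filtered out, advancing a token can now shrink the instance map instead of carrying stale entries forward. A sketch with hypothetical positions, assuming `copy_and_advance` keeps the per-writer maximum and the larger of the two stream positions:

    from immutabledict import immutabledict
    from synapse.types import RoomStreamToken

    t1 = RoomStreamToken(stream=3, instance_map=immutabledict({"writer1": 6}))
    t2 = RoomStreamToken(stream=6)

    merged = t1.copy_and_advance(t2)
    # merged.stream == 6 and the "writer1" entry is dropped: its position (6) is
    # no longer ahead of the new minimum stream position.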
@@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta):
     def bound_stream_token(self, max_stream: int) -> "Self":
         """Bound the stream positions to a maximum value"""
+        min_pos = min(self.stream, max_stream)
         return type(self)(
-            stream=min(self.stream, max_stream),
+            stream=min_pos,
             instance_map=immutabledict(
-                {k: min(s, max_stream) for k, s in self.instance_map.items()}
+                {
+                    k: min(s, max_stream)
+                    for k, s in self.instance_map.items()
+                    if min(s, max_stream) > min_pos
+                }
             ),
         )
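Similarly, bounding a token now drops any entry whose clamped position would collide with the minimum, rather than keeping a redundant (and, under the new `__attrs_post_init__` check, invalid) entry. A sketch with hypothetical positions:

    from immutabledict import immutabledict
    from synapse.types import RoomStreamToken

    t = RoomStreamToken(stream=3, instance_map=immutabledict({"writer1": 8}))

    t.bound_stream_token(5)  # stream=3, instance_map={"writer1": 5}
    t.bound_stream_token(3)  # stream=3, empty instance_map - the clamped entry
                             # would equal the stream position, so it is dropped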
@@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
                 "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'."
             )

+        super().__attrs_post_init__()
+
     @classmethod
     async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
         try:
@@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
                 instance_map = {}
                 for part in parts[1:]:
+                    if not part:
+                        # Handle tokens of the form `m5~`, which were created by
+                        # a bug
+                        continue
+
                     key, value = part.split(".")
                     instance_id = int(key)
                     pos = int(value)
@@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here as even though this *might* be a client
+            # handing a bad token, its more likely that Synapse returned a bad
+            # token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)

         raise SynapseError(400, "Invalid room stream token %r" % (string,))

     @classmethod
@@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         return self.instance_map.get(instance_name, self.stream)

     async def to_string(self, store: "DataStore") -> str:
+        """See class level docstring for information about the format."""
+
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
         elif self.instance_map:
@@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")

-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return f"s{self.stream}"
         else:
             return "s%d" % (self.stream,)
@@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
                 instance_map = {}
                 for part in parts[1:]:
+                    if not part:
+                        # Handle tokens of the form `m5~`, which were created by
+                        # a bug
+                        continue
+
                     key, value = part.split(".")
                     instance_id = int(key)
                     pos = int(value)
@@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here as even though this *might* be a client
+            # handing a bad token, its more likely that Synapse returned a bad
+            # token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)

         raise SynapseError(400, "Invalid stream token %r" % (string,))

     async def to_string(self, store: "DataStore") -> str:
+        """See class level docstring for information about the format."""
+
         if self.instance_map:
             entries = []
             for name, pos in self.instance_map.items():
@@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")

-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return str(self.stream)
         else:
             return str(self.stream)


@@ -19,9 +19,18 @@
 #
 #

+from typing import Type
+from unittest import skipUnless
+
+from immutabledict import immutabledict
+from parameterized import parameterized_class
+
 from synapse.api.errors import SynapseError
 from synapse.types import (
+    AbstractMultiWriterStreamToken,
+    MultiWriterStreamToken,
     RoomAlias,
+    RoomStreamToken,
     UserID,
     get_domain_from_id,
     get_localpart_from_id,
@@ -29,6 +38,7 @@ from synapse.types import (
 )

 from tests import unittest
+from tests.utils import USE_POSTGRES_FOR_TESTS


 class IsMineIDTests(unittest.HomeserverTestCase):
@@ -127,3 +137,64 @@ class MapUsernameTestCase(unittest.TestCase):
         # this should work with either a unicode or a bytes
         self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast")
         self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast")
+
+
+@parameterized_class(
+    ("token_type",),
+    [
+        (MultiWriterStreamToken,),
+        (RoomStreamToken,),
+    ],
+    class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}",
+)
+class MultiWriterTokenTestCase(unittest.HomeserverTestCase):
+    """Tests for the different types of multi writer tokens."""
+
+    token_type: Type[AbstractMultiWriterStreamToken]
+
+    def test_basic_token(self) -> None:
+        """Test that a simple stream token can be serialized and unserialized"""
+        store = self.hs.get_datastores().main
+
+        token = self.token_type(stream=5)
+
+        string_token = self.get_success(token.to_string(store))
+
+        if isinstance(token, RoomStreamToken):
+            self.assertEqual(string_token, "s5")
+        else:
+            self.assertEqual(string_token, "5")
+
+        parsed_token = self.get_success(self.token_type.parse(store, string_token))
+        self.assertEqual(parsed_token, token)
+
+    @skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres")
+    def test_instance_map(self) -> None:
+        """Test for stream token with instance map"""
+        store = self.hs.get_datastores().main
+
+        token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6}))
+
+        string_token = self.get_success(token.to_string(store))
+        self.assertEqual(string_token, "m5~1.6")
+
+        parsed_token = self.get_success(self.token_type.parse(store, string_token))
+        self.assertEqual(parsed_token, token)
+
+    def test_instance_map_assertion(self) -> None:
+        """Test that we assert values in the instance map are greater than the
+        min stream position"""
+
+        with self.assertRaises(ValueError):
+            self.token_type(stream=5, instance_map=immutabledict({"foo": 4}))
+
+        with self.assertRaises(ValueError):
+            self.token_type(stream=5, instance_map=immutabledict({"foo": 5}))
+
+    def test_parse_bad_token(self) -> None:
+        """Test that we can parse tokens produced by a bug in Synapse of the
+        form `m5~`"""
+        store = self.hs.get_datastores().main
+
+        parsed_token = self.get_success(self.token_type.parse(store, "m5~"))
+        self.assertEqual(parsed_token, self.token_type(stream=5))