Raise an exception when getting state at an outlier (#12191)

It seems like calling `_get_state_group_for_events` for an event where the
state is unknown is an error. Accordingly, let's raise an exception rather than
silently returning an empty result.
This commit is contained in:
Richard van der Hoff 2022-04-01 13:01:49 +01:00 committed by GitHub
parent 9b43df1f7b
commit 319a805cd3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 22 deletions

1
changelog.d/12191.misc Normal file
View file

@ -0,0 +1 @@
Avoid trying to calculate the state at outlier events.

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, Iterable, Optional, Set, Tuple from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
from frozendict import frozendict from frozendict import frozendict
@ -309,9 +309,13 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
num_args=1, num_args=1,
) )
async def _get_state_group_for_events( async def _get_state_group_for_events(
self, event_ids: Iterable[str] self, event_ids: Collection[str]
) -> Dict[str, int]: ) -> Dict[str, int]:
"""Returns mapping event_id -> state_group""" """Returns mapping event_id -> state_group.
Raises:
RuntimeError if the state is unknown at any of the given events
"""
rows = await self.db_pool.simple_select_many_batch( rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups", table="event_to_state_groups",
column="event_id", column="event_id",
@ -321,7 +325,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="_get_state_group_for_events", desc="_get_state_group_for_events",
) )
return {row["event_id"]: row["state_group"] for row in rows} res = {row["event_id"]: row["state_group"] for row in rows}
for e in event_ids:
if e not in res:
raise RuntimeError("No state group for unknown or outlier event %s" % e)
return res
async def get_referenced_state_groups( async def get_referenced_state_groups(
self, state_groups: Iterable[int] self, state_groups: Iterable[int]

View file

@ -571,6 +571,10 @@ class StateGroupStorage:
Returns: Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id) dict of state_group_id -> (dict of (type, state_key) -> event id)
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
""" """
if not event_ids: if not event_ids:
return {} return {}
@ -659,6 +663,10 @@ class StateGroupStorage:
Returns: Returns:
A dict of (event_id) -> (type, state_key) -> [state_events] A dict of (event_id) -> (type, state_key) -> [state_events]
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
""" """
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
@ -696,6 +704,10 @@ class StateGroupStorage:
Returns: Returns:
A dict from event_id -> (type, state_key) -> event_id A dict from event_id -> (type, state_key) -> event_id
Raises:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
""" """
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids) event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
@ -723,6 +735,10 @@ class StateGroupStorage:
Returns: Returns:
A dict from (type, state_key) -> state_event A dict from (type, state_key) -> state_event
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
""" """
state_map = await self.get_state_for_events( state_map = await self.get_state_for_events(
[event_id], state_filter or StateFilter.all() [event_id], state_filter or StateFilter.all()
@ -741,6 +757,10 @@ class StateGroupStorage:
Returns: Returns:
A dict from (type, state_key) -> state_event_id A dict from (type, state_key) -> state_event_id
Raises:
RuntimeError if we don't have a state group for the event (ie it is an
outlier or is unknown)
""" """
state_map = await self.get_state_ids_for_events( state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all() [event_id], state_filter or StateFilter.all()

View file

@ -20,17 +20,17 @@ from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
from synapse.api.room_versions import RoomVersions from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase from synapse.events import EventBase, make_event_from_dict
from synapse.federation.federation_base import event_from_pdu_json from synapse.federation.federation_base import event_from_pdu_json
from synapse.logging.context import LoggingContext, run_in_background from synapse.logging.context import LoggingContext, run_in_background
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import create_requester
from synapse.util import Clock from synapse.util import Clock
from synapse.util.stringutils import random_string from synapse.util.stringutils import random_string
from tests import unittest from tests import unittest
from tests.test_utils import event_injection
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,7 +39,7 @@ def generate_fake_event_id() -> str:
return "$fake_" + random_string(43) return "$fake_" + random_string(43)
class FederationTestCase(unittest.HomeserverTestCase): class FederationTestCase(unittest.FederatingHomeserverTestCase):
servlets = [ servlets = [
admin.register_servlets, admin.register_servlets,
login.register_servlets, login.register_servlets,
@ -219,40 +219,76 @@ class FederationTestCase(unittest.HomeserverTestCase):
# create the room # create the room
user_id = self.register_user("kermit", "test") user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test") tok = self.login("kermit", "test")
requester = create_requester(user_id)
room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
room_version = self.get_success(self.store.get_room_version(room_id))
ev1 = self.helper.send(room_id, "first message", tok=tok) # we need a user on the remote server to be a member, so that we can send
# extremity-causing events.
self.get_success(
event_injection.inject_member_event(
self.hs, room_id, f"@user:{self.OTHER_SERVER_NAME}", "join"
)
)
send_result = self.helper.send(room_id, "first message", tok=tok)
ev1 = self.get_success(
self.store.get_event(send_result["event_id"], allow_none=False)
)
current_state = self.get_success(
self.store.get_events_as_list(
(self.get_success(self.store.get_current_state_ids(room_id))).values()
)
)
# Create "many" backward extremities. The magic number we're trying to # Create "many" backward extremities. The magic number we're trying to
# create more than is 5 which corresponds to the number of backward # create more than is 5 which corresponds to the number of backward
# extremities we slice off in `_maybe_backfill_inner` # extremities we slice off in `_maybe_backfill_inner`
federation_event_handler = self.hs.get_federation_event_handler()
for _ in range(0, 8): for _ in range(0, 8):
event_handler = self.hs.get_event_creation_handler() event = make_event_from_dict(
event, context = self.get_success( self.add_hashes_and_signatures(
event_handler.create_event(
requester,
{ {
"origin_server_ts": 1,
"type": "m.room.message", "type": "m.room.message",
"content": { "content": {
"msgtype": "m.text", "msgtype": "m.text",
"body": "message connected to fake event", "body": "message connected to fake event",
}, },
"room_id": room_id, "room_id": room_id,
"sender": user_id, "sender": f"@user:{self.OTHER_SERVER_NAME}",
"prev_events": [
ev1.event_id,
# We're creating an backward extremity each time thanks
# to this fake event
generate_fake_event_id(),
],
# lazy: *everything* is an auth event
"auth_events": [ev.event_id for ev in current_state],
"depth": ev1.depth + 1,
}, },
prev_event_ids=[ room_version,
ev1["event_id"], ),
# We're creating an backward extremity each time thanks room_version,
# to this fake event )
generate_fake_event_id(),
], # we poke this directly into _process_received_pdu, to avoid the
# federation handler wanting to backfill the fake event.
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, event, state=current_state
) )
) )
self.get_success(
event_handler.handle_new_client_event(requester, event, context) # we should now have 8 backwards extremities.
backwards_extremities = self.get_success(
self.store.db_pool.simple_select_list(
"event_backward_extremities",
keyvalues={"room_id": room_id},
retcols=["event_id"],
) )
)
self.assertEqual(len(backwards_extremities), 8)
current_depth = 1 current_depth = 1
limit = 100 limit = 100