Compare commits

...

5 commits

Author SHA1 Message Date
Erik Johnston a7c7c190b7
Merge f237303c32 into 3aae60f17b 2024-06-15 09:34:14 -07:00
Richard van der Hoff 3aae60f17b
Enable cross-signing key upload without UIA (#17284)
Per MSC3967, which is now stable, we should not require UIA when
uploading cross-signing keys for the first time.

Fixes: #17227
2024-06-14 11:14:56 +01:00
Erik Johnston f237303c32 Handle previously persisted events properly 2024-06-10 16:58:58 +01:00
Erik Johnston db67cd6893 Newsfile 2024-06-10 14:24:16 +01:00
Erik Johnston 5a287d9d14 Handle large chain calc better
We calculate the auth chain links outside of the main persist event
transaction to ensure that we do not block other event sending during
the calculation.
2024-06-10 14:22:58 +01:00
12 changed files with 253 additions and 209 deletions

View file

@ -0,0 +1 @@
Do not require user-interactive authentication for uploading cross-signing keys for the first time, per MSC3967.

1
changelog.d/17291.misc Normal file
View file

@ -0,0 +1 @@
Do not block event sending/receiving while calulating large event auth chains.

View file

@ -393,9 +393,6 @@ class ExperimentalConfig(Config):
# MSC3391: Removing account data. # MSC3391: Removing account data.
self.msc3391_enabled = experimental.get("msc3391_enabled", False) self.msc3391_enabled = experimental.get("msc3391_enabled", False)
# MSC3967: Do not require UIA when first uploading cross signing keys
self.msc3967_enabled = experimental.get("msc3967_enabled", False)
# MSC3861: Matrix architecture change to delegate authentication via OIDC # MSC3861: Matrix architecture change to delegate authentication via OIDC
try: try:
self.msc3861 = MSC3861(**experimental.get("msc3861", {})) self.msc3861 = MSC3861(**experimental.get("msc3861", {}))

View file

@ -41,7 +41,6 @@ class ExperimentalFeature(str, Enum):
MSC3026 = "msc3026" MSC3026 = "msc3026"
MSC3881 = "msc3881" MSC3881 = "msc3881"
MSC3967 = "msc3967"
class ExperimentalFeaturesRestServlet(RestServlet): class ExperimentalFeaturesRestServlet(RestServlet):

View file

@ -382,44 +382,35 @@ class SigningKeyUploadServlet(RestServlet):
master_key_updatable_without_uia, master_key_updatable_without_uia,
) = await self.e2e_keys_handler.check_cross_signing_setup(user_id) ) = await self.e2e_keys_handler.check_cross_signing_setup(user_id)
# Before MSC3967 we required UIA both when setting up cross signing for the # Resending exactly the same keys should just 200 OK without doing a UIA prompt.
# first time and when resetting the device signing key. With MSC3967 we only keys_are_different = await self.e2e_keys_handler.has_different_keys(
# require UIA when resetting cross-signing, and not when setting up the first user_id, body
# time. Because there is no UIA in MSC3861, for now we throw an error if the )
# user tries to reset the device signing key when MSC3861 is enabled, but allow if not keys_are_different:
# first-time setup. return 200, {}
if self.hs.config.experimental.msc3861.enabled:
# The auth service has to explicitly mark the master key as replaceable
# without UIA to reset the device signing key with MSC3861.
if is_cross_signing_setup and not master_key_updatable_without_uia:
config = self.hs.config.experimental.msc3861
if config.account_management_url is not None:
url = f"{config.account_management_url}?action=org.matrix.cross_signing_reset"
else:
url = config.issuer
raise SynapseError( # The keys are different; is x-signing set up? If no, then this is first-time
HTTPStatus.NOT_IMPLEMENTED, # setup, and that is allowed without UIA, per MSC3967.
"To reset your end-to-end encryption cross-signing identity, " # If yes, then we need to authenticate the change.
f"you first need to approve it at {url} and then try again.", if is_cross_signing_setup:
Codes.UNRECOGNIZED, # With MSC3861, UIA is not possible. Instead, the auth service has to
) # explicitly mark the master key as replaceable.
# But first-time setup is fine if self.hs.config.experimental.msc3861.enabled:
if not master_key_updatable_without_uia:
config = self.hs.config.experimental.msc3861
if config.account_management_url is not None:
url = f"{config.account_management_url}?action=org.matrix.cross_signing_reset"
else:
url = config.issuer
elif self.hs.config.experimental.msc3967_enabled: raise SynapseError(
# MSC3967 allows this endpoint to 200 OK for idempotency. Resending exactly the same HTTPStatus.NOT_IMPLEMENTED,
# keys should just 200 OK without doing a UIA prompt. "To reset your end-to-end encryption cross-signing identity, "
keys_are_different = await self.e2e_keys_handler.has_different_keys( f"you first need to approve it at {url} and then try again.",
user_id, body Codes.UNRECOGNIZED,
) )
if not keys_are_different: else:
# FIXME: we do not fallthrough to upload_signing_keys_for_user because confusingly # Without MSC3861, we require UIA.
# if we do, we 500 as it looks like it tries to INSERT the same key twice, causing a
# unique key constraint violation. This sounds like a bug?
return 200, {}
# the keys are different, is x-signing set up? If no, then the keys don't exist which is
# why they are different. If yes, then we need to UIA to change them.
if is_cross_signing_setup:
await self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, requester,
request, request,
@ -428,18 +419,6 @@ class SigningKeyUploadServlet(RestServlet):
# Do not allow skipping of UIA auth. # Do not allow skipping of UIA auth.
can_skip_ui_auth=False, can_skip_ui_auth=False,
) )
# Otherwise we don't require UIA since we are setting up cross signing for first time
else:
# Previous behaviour is to always require UIA but allow it to be skipped
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
"add a device signing key to your account",
# Allow skipping of UI auth since this is frequently called directly
# after login and it is silly to ask users to re-auth immediately.
can_skip_ui_auth=True,
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
return 200, result return 200, result

View file

@ -617,6 +617,17 @@ class EventsPersistenceStorageController:
room_id, chunk room_id, chunk
) )
with Measure(self._clock, "calculate_chain_cover_index_for_events"):
# We now calculate chain ID/sequence numbers for any state events we're
# persisting. We ignore out of band memberships as we're not in the room
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
new_event_links = await self.persist_events_store.calculate_chain_cover_index_for_events(
room_id, [e for e, _ in chunk]
)
await self.persist_events_store._persist_events_and_state_updates( await self.persist_events_store._persist_events_and_state_updates(
room_id, room_id,
chunk, chunk,
@ -624,6 +635,7 @@ class EventsPersistenceStorageController:
new_forward_extremities=new_forward_extremities, new_forward_extremities=new_forward_extremities,
use_negative_stream_ordering=backfilled, use_negative_stream_ordering=backfilled,
inhibit_local_membership_updates=backfilled, inhibit_local_membership_updates=backfilled,
new_event_links=new_event_links,
) )
return replaced_events return replaced_events

View file

@ -34,7 +34,6 @@ from typing import (
Optional, Optional,
Set, Set,
Tuple, Tuple,
Union,
cast, cast,
) )
@ -100,6 +99,14 @@ class DeltaState:
return not self.to_delete and not self.to_insert and not self.no_longer_in_room return not self.to_delete and not self.to_insert and not self.no_longer_in_room
@attr.s(slots=True, auto_attribs=True)
class NewEventChainLinks:
chain_id: int
sequence_number: int
links: List[Tuple[int, int]] = attr.Factory(list)
class PersistEventsStore: class PersistEventsStore:
"""Contains all the functions for writing events to the database. """Contains all the functions for writing events to the database.
@ -148,6 +155,7 @@ class PersistEventsStore:
*, *,
state_delta_for_room: Optional[DeltaState], state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]], new_forward_extremities: Optional[Set[str]],
new_event_links: Dict[str, NewEventChainLinks],
use_negative_stream_ordering: bool = False, use_negative_stream_ordering: bool = False,
inhibit_local_membership_updates: bool = False, inhibit_local_membership_updates: bool = False,
) -> None: ) -> None:
@ -217,6 +225,7 @@ class PersistEventsStore:
inhibit_local_membership_updates=inhibit_local_membership_updates, inhibit_local_membership_updates=inhibit_local_membership_updates,
state_delta_for_room=state_delta_for_room, state_delta_for_room=state_delta_for_room,
new_forward_extremities=new_forward_extremities, new_forward_extremities=new_forward_extremities,
new_event_links=new_event_links,
) )
persist_event_counter.inc(len(events_and_contexts)) persist_event_counter.inc(len(events_and_contexts))
@ -243,6 +252,87 @@ class PersistEventsStore:
(room_id,), frozenset(new_forward_extremities) (room_id,), frozenset(new_forward_extremities)
) )
async def calculate_chain_cover_index_for_events(
self, room_id: str, events: Collection[EventBase]
) -> Dict[str, NewEventChainLinks]:
# Filter to state events, and ensure there are no duplicates.
state_events = []
seen_events = set()
for event in events:
if not event.is_state() or event.event_id in seen_events:
continue
state_events.append(event)
seen_events.add(event.event_id)
if not state_events:
return {}
return await self.db_pool.runInteraction(
"_calculate_chain_cover_index_for_events",
self.calculate_chain_cover_index_for_events_txn,
room_id,
state_events,
)
def calculate_chain_cover_index_for_events_txn(
self, txn: LoggingTransaction, room_id: str, state_events: Collection[EventBase]
) -> Dict[str, NewEventChainLinks]:
# We now calculate chain ID/sequence numbers for any state events we're
# persisting. We ignore out of band memberships as we're not in the room
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
row = self.db_pool.simple_select_one_txn(
txn,
table="rooms",
keyvalues={"room_id": room_id},
retcols=("room_id", "has_auth_chain_index"),
allow_none=True,
)
if row is None:
return {}
# Filter out already persisted events.
rows = self.db_pool.simple_select_many_txn(
txn,
table="events",
column="event_id",
iterable=[e.event_id for e in state_events],
keyvalues={},
retcols=("event_id",),
)
already_persisted_events = {event_id for event_id, in rows}
state_events = [
event
for event in state_events
if event.event_id in already_persisted_events
]
if not state_events:
return {}
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
# we don't need the overhead of fetching/parsing the full event JSON.
event_to_types = {e.event_id: (e.type, e.state_key) for e in state_events}
event_to_auth_chain = {e.event_id: e.auth_event_ids() for e in state_events}
event_to_room_id = {e.event_id: e.room_id for e in state_events}
return self._calculate_chain_cover_index(
txn,
self.db_pool,
self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]: async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
"""Filter the supplied list of event_ids to get those which are prev_events of """Filter the supplied list of event_ids to get those which are prev_events of
existing (non-outlier/rejected) events. existing (non-outlier/rejected) events.
@ -358,6 +448,7 @@ class PersistEventsStore:
inhibit_local_membership_updates: bool, inhibit_local_membership_updates: bool,
state_delta_for_room: Optional[DeltaState], state_delta_for_room: Optional[DeltaState],
new_forward_extremities: Optional[Set[str]], new_forward_extremities: Optional[Set[str]],
new_event_links: Dict[str, NewEventChainLinks],
) -> None: ) -> None:
"""Insert some number of room events into the necessary database tables. """Insert some number of room events into the necessary database tables.
@ -466,7 +557,9 @@ class PersistEventsStore:
# Insert into event_to_state_groups. # Insert into event_to_state_groups.
self._store_event_state_mappings_txn(txn, events_and_contexts) self._store_event_state_mappings_txn(txn, events_and_contexts)
self._persist_event_auth_chain_txn(txn, [e for e, _ in events_and_contexts]) self._persist_event_auth_chain_txn(
txn, [e for e, _ in events_and_contexts], new_event_links
)
# _store_rejected_events_txn filters out any events which were # _store_rejected_events_txn filters out any events which were
# rejected, and returns the filtered list. # rejected, and returns the filtered list.
@ -496,6 +589,7 @@ class PersistEventsStore:
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,
events: List[EventBase], events: List[EventBase],
new_event_links: Dict[str, NewEventChainLinks],
) -> None: ) -> None:
# We only care about state events, so this if there are no state events. # We only care about state events, so this if there are no state events.
if not any(e.is_state() for e in events): if not any(e.is_state() for e in events):
@ -519,59 +613,8 @@ class PersistEventsStore:
], ],
) )
# We now calculate chain ID/sequence numbers for any state events we're if new_event_links:
# persisting. We ignore out of band memberships as we're not in the room self._persist_chain_cover_index(txn, self.db_pool, new_event_links)
# and won't have their auth chain (we'll fix it up later if we join the
# room).
#
# See: docs/auth_chain_difference_algorithm.md
# We ignore legacy rooms that we aren't filling the chain cover index
# for.
rows = cast(
List[Tuple[str, Optional[Union[int, bool]]]],
self.db_pool.simple_select_many_txn(
txn,
table="rooms",
column="room_id",
iterable={event.room_id for event in events if event.is_state()},
keyvalues={},
retcols=("room_id", "has_auth_chain_index"),
),
)
rooms_using_chain_index = {
room_id for room_id, has_auth_chain_index in rows if has_auth_chain_index
}
state_events = {
event.event_id: event
for event in events
if event.is_state() and event.room_id in rooms_using_chain_index
}
if not state_events:
return
# We need to know the type/state_key and auth events of the events we're
# calculating chain IDs for. We don't rely on having the full Event
# instances as we'll potentially be pulling more events from the DB and
# we don't need the overhead of fetching/parsing the full event JSON.
event_to_types = {
e.event_id: (e.type, e.state_key) for e in state_events.values()
}
event_to_auth_chain = {
e.event_id: e.auth_event_ids() for e in state_events.values()
}
event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
self._add_chain_cover_index(
txn,
self.db_pool,
self.store.event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
@classmethod @classmethod
def _add_chain_cover_index( def _add_chain_cover_index(
@ -583,6 +626,35 @@ class PersistEventsStore:
event_to_types: Dict[str, Tuple[str, str]], event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, StrCollection], event_to_auth_chain: Dict[str, StrCollection],
) -> None: ) -> None:
"""Calculate and persist the chain cover index for the given events.
Args:
event_to_room_id: Event ID to the room ID of the event
event_to_types: Event ID to type and state_key of the event
event_to_auth_chain: Event ID to list of auth event IDs of the
event (events with no auth events can be excluded).
"""
new_event_links = cls._calculate_chain_cover_index(
txn,
db_pool,
event_chain_id_gen,
event_to_room_id,
event_to_types,
event_to_auth_chain,
)
cls._persist_chain_cover_index(txn, db_pool, new_event_links)
@classmethod
def _calculate_chain_cover_index(
cls,
txn: LoggingTransaction,
db_pool: DatabasePool,
event_chain_id_gen: SequenceGenerator,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, StrCollection],
) -> Dict[str, NewEventChainLinks]:
"""Calculate the chain cover index for the given events. """Calculate the chain cover index for the given events.
Args: Args:
@ -708,11 +780,11 @@ class PersistEventsStore:
room_id = event_to_room_id.get(event_id) room_id = event_to_room_id.get(event_id)
if room_id: if room_id:
e_type, state_key = event_to_types[event_id] e_type, state_key = event_to_types[event_id]
db_pool.simple_insert_txn( db_pool.simple_upsert_txn(
txn, txn,
table="event_auth_chain_to_calculate", table="event_auth_chain_to_calculate",
keyvalues={"event_id": event_id},
values={ values={
"event_id": event_id,
"room_id": room_id, "room_id": room_id,
"type": e_type, "type": e_type,
"state_key": state_key, "state_key": state_key,
@ -724,7 +796,7 @@ class PersistEventsStore:
break break
if not events_to_calc_chain_id_for: if not events_to_calc_chain_id_for:
return return {}
# Allocate chain ID/sequence numbers to each new event. # Allocate chain ID/sequence numbers to each new event.
new_chain_tuples = cls._allocate_chain_ids( new_chain_tuples = cls._allocate_chain_ids(
@ -739,23 +811,10 @@ class PersistEventsStore:
) )
chain_map.update(new_chain_tuples) chain_map.update(new_chain_tuples)
db_pool.simple_insert_many_txn( to_return = {
txn, event_id: NewEventChainLinks(chain_id, sequence_number)
table="event_auth_chains", for event_id, (chain_id, sequence_number) in new_chain_tuples.items()
keys=("event_id", "chain_id", "sequence_number"), }
values=[
(event_id, c_id, seq)
for event_id, (c_id, seq) in new_chain_tuples.items()
],
)
db_pool.simple_delete_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
values=new_chain_tuples,
)
# Now we need to calculate any new links between chains caused by # Now we need to calculate any new links between chains caused by
# the new events. # the new events.
@ -825,10 +884,38 @@ class PersistEventsStore:
auth_chain_id, auth_sequence_number = chain_map[auth_id] auth_chain_id, auth_sequence_number = chain_map[auth_id]
# Step 2a, add link between the event and auth event # Step 2a, add link between the event and auth event
to_return[event_id].links.append((auth_chain_id, auth_sequence_number))
chain_links.add_link( chain_links.add_link(
(chain_id, sequence_number), (auth_chain_id, auth_sequence_number) (chain_id, sequence_number), (auth_chain_id, auth_sequence_number)
) )
return to_return
@classmethod
def _persist_chain_cover_index(
cls,
txn: LoggingTransaction,
db_pool: DatabasePool,
new_event_links: Dict[str, NewEventChainLinks],
) -> None:
db_pool.simple_insert_many_txn(
txn,
table="event_auth_chains",
keys=("event_id", "chain_id", "sequence_number"),
values=[
(event_id, new_links.chain_id, new_links.sequence_number)
for event_id, new_links in new_event_links.items()
],
)
db_pool.simple_delete_many_txn(
txn,
table="event_auth_chain_to_calculate",
keyvalues={},
column="event_id",
values=new_event_links,
)
db_pool.simple_insert_many_txn( db_pool.simple_insert_many_txn(
txn, txn,
table="event_auth_chain_links", table="event_auth_chain_links",
@ -838,7 +925,16 @@ class PersistEventsStore:
"target_chain_id", "target_chain_id",
"target_sequence_number", "target_sequence_number",
), ),
values=list(chain_links.get_additions()), values=[
(
new_links.chain_id,
new_links.sequence_number,
target_chain_id,
target_sequence_number,
)
for new_links in new_event_links.values()
for (target_chain_id, target_sequence_number) in new_links.links
],
) )
@staticmethod @staticmethod

View file

@ -541,6 +541,8 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
# Try uploading *different* keys; it should cause a 501 error.
keys_upload_body = self.make_device_keys(USER_ID, DEVICE)
channel = self.make_request( channel = self.make_request(
"POST", "POST",
"/_matrix/client/v3/keys/device_signing/upload", "/_matrix/client/v3/keys/device_signing/upload",

View file

@ -435,10 +435,6 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
True, True,
channel.json_body["features"]["msc3881"], channel.json_body["features"]["msc3881"],
) )
self.assertEqual(
False,
channel.json_body["features"]["msc3967"],
)
# test nothing blows up if you try to disable a feature that isn't already enabled # test nothing blows up if you try to disable a feature that isn't already enabled
url = f"{self.url}/{self.other_user}" url = f"{self.url}/{self.other_user}"

View file

@ -155,71 +155,6 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
} }
def test_device_signing_with_uia(self) -> None: def test_device_signing_with_uia(self) -> None:
"""Device signing key upload requires UIA."""
password = "wonderland"
device_id = "ABCDEFGHI"
alice_id = self.register_user("alice", password)
alice_token = self.login("alice", password, device_id=device_id)
content = self.make_device_keys(alice_id, device_id)
channel = self.make_request(
"POST",
"/_matrix/client/v3/keys/device_signing/upload",
content,
alice_token,
)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
# Grab the session
session = channel.json_body["session"]
# Ensure that flows are what is expected.
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
# add UI auth
content["auth"] = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": alice_id},
"password": password,
"session": session,
}
channel = self.make_request(
"POST",
"/_matrix/client/v3/keys/device_signing/upload",
content,
alice_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@override_config({"ui_auth": {"session_timeout": "15m"}})
def test_device_signing_with_uia_session_timeout(self) -> None:
"""Device signing key upload requires UIA buy passes with grace period."""
password = "wonderland"
device_id = "ABCDEFGHI"
alice_id = self.register_user("alice", password)
alice_token = self.login("alice", password, device_id=device_id)
content = self.make_device_keys(alice_id, device_id)
channel = self.make_request(
"POST",
"/_matrix/client/v3/keys/device_signing/upload",
content,
alice_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@override_config(
{
"experimental_features": {"msc3967_enabled": True},
"ui_auth": {"session_timeout": "15s"},
}
)
def test_device_signing_with_msc3967(self) -> None:
"""Device signing key follows MSC3967 behaviour when enabled."""
password = "wonderland" password = "wonderland"
device_id = "ABCDEFGHI" device_id = "ABCDEFGHI"
alice_id = self.register_user("alice", password) alice_id = self.register_user("alice", password)

View file

@ -447,7 +447,14 @@ class EventChainStoreTestCase(HomeserverTestCase):
) )
# Actually call the function that calculates the auth chain stuff. # Actually call the function that calculates the auth chain stuff.
persist_events_store._persist_event_auth_chain_txn(txn, events) new_event_links = (
persist_events_store.calculate_chain_cover_index_for_events_txn(
txn, events[0].room_id, [e for e in events if e.is_state()]
)
)
persist_events_store._persist_event_auth_chain_txn(
txn, events, new_event_links
)
self.get_success( self.get_success(
persist_events_store.db_pool.runInteraction( persist_events_store.db_pool.runInteraction(

View file

@ -365,12 +365,19 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
}, },
) )
events = [
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
for event_id in AUTH_GRAPH
]
new_event_links = (
self.persist_events.calculate_chain_cover_index_for_events_txn(
txn, room_id, [e for e in events if e.is_state()]
)
)
self.persist_events._persist_event_auth_chain_txn( self.persist_events._persist_event_auth_chain_txn(
txn, txn,
[ events,
cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id])) new_event_links,
for event_id in AUTH_GRAPH
],
) )
self.get_success( self.get_success(
@ -628,13 +635,20 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
) )
# Insert all events apart from 'B' # Insert all events apart from 'B'
events = [
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
for event_id in auth_graph
if event_id != "b"
]
new_event_links = (
self.persist_events.calculate_chain_cover_index_for_events_txn(
txn, room_id, [e for e in events if e.is_state()]
)
)
self.persist_events._persist_event_auth_chain_txn( self.persist_events._persist_event_auth_chain_txn(
txn, txn,
[ events,
cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id])) new_event_links,
for event_id in auth_graph
if event_id != "b"
],
) )
# Now we insert the event 'B' without a chain cover, by temporarily # Now we insert the event 'B' without a chain cover, by temporarily
@ -647,9 +661,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
updatevalues={"has_auth_chain_index": False}, updatevalues={"has_auth_chain_index": False},
) )
events = [cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))]
new_event_links = (
self.persist_events.calculate_chain_cover_index_for_events_txn(
txn, room_id, [e for e in events if e.is_state()]
)
)
self.persist_events._persist_event_auth_chain_txn( self.persist_events._persist_event_auth_chain_txn(
txn, txn, events, new_event_links
[cast(EventBase, FakeEvent("b", room_id, auth_graph["b"]))],
) )
self.store.db_pool.simple_update_txn( self.store.db_pool.simple_update_txn(