Compare commits

...

5 commits

Author SHA1 Message Date
Erik Johnston bd91b8c7c1
Merge ca79b4d87d into 3aae60f17b 2024-06-14 21:11:36 +02:00
Richard van der Hoff 3aae60f17b
Enable cross-signing key upload without UIA (#17284)
Per MSC3967, which is now stable, we should not require UIA when
uploading cross-signing keys for the first time.

Fixes: #17227
2024-06-14 11:14:56 +01:00
Erik Johnston ca79b4d87d Use a sortedset instead 2024-05-09 10:58:00 +01:00
Erik Johnston 202a09cdb3 Newsfile 2024-05-08 16:05:24 +01:00
Erik Johnston db25e30a25 Perf improvement to getting auth chains 2024-05-08 16:04:35 +01:00
10 changed files with 123 additions and 128 deletions

1
changelog.d/17169.misc Normal file
View file

@ -0,0 +1 @@
Add database performance improvement when fetching auth chains.

changelog.d/17284.misc Normal file

View file

@ -0,0 +1 @@
Do not require user-interactive authentication for uploading cross-signing keys for the first time, per MSC3967.

View file

@ -393,9 +393,6 @@ class ExperimentalConfig(Config):
# MSC3391: Removing account data. # MSC3391: Removing account data.
self.msc3391_enabled = experimental.get("msc3391_enabled", False) self.msc3391_enabled = experimental.get("msc3391_enabled", False)
# MSC3967: Do not require UIA when first uploading cross signing keys
self.msc3967_enabled = experimental.get("msc3967_enabled", False)
# MSC3861: Matrix architecture change to delegate authentication via OIDC # MSC3861: Matrix architecture change to delegate authentication via OIDC
try: try:
self.msc3861 = MSC3861(**experimental.get("msc3861", {})) self.msc3861 = MSC3861(**experimental.get("msc3861", {}))

View file

@ -41,7 +41,6 @@ class ExperimentalFeature(str, Enum):
MSC3026 = "msc3026" MSC3026 = "msc3026"
MSC3881 = "msc3881" MSC3881 = "msc3881"
MSC3967 = "msc3967"
class ExperimentalFeaturesRestServlet(RestServlet): class ExperimentalFeaturesRestServlet(RestServlet):

View file

@ -382,44 +382,35 @@ class SigningKeyUploadServlet(RestServlet):
master_key_updatable_without_uia, master_key_updatable_without_uia,
) = await self.e2e_keys_handler.check_cross_signing_setup(user_id) ) = await self.e2e_keys_handler.check_cross_signing_setup(user_id)
# Before MSC3967 we required UIA both when setting up cross signing for the # Resending exactly the same keys should just 200 OK without doing a UIA prompt.
# first time and when resetting the device signing key. With MSC3967 we only keys_are_different = await self.e2e_keys_handler.has_different_keys(
# require UIA when resetting cross-signing, and not when setting up the first user_id, body
# time. Because there is no UIA in MSC3861, for now we throw an error if the )
# user tries to reset the device signing key when MSC3861 is enabled, but allow if not keys_are_different:
# first-time setup. return 200, {}
if self.hs.config.experimental.msc3861.enabled:
# The auth service has to explicitly mark the master key as replaceable
# without UIA to reset the device signing key with MSC3861.
if is_cross_signing_setup and not master_key_updatable_without_uia:
config = self.hs.config.experimental.msc3861
if config.account_management_url is not None:
url = f"{config.account_management_url}?action=org.matrix.cross_signing_reset"
else:
url = config.issuer
raise SynapseError( # The keys are different; is x-signing set up? If no, then this is first-time
HTTPStatus.NOT_IMPLEMENTED, # setup, and that is allowed without UIA, per MSC3967.
"To reset your end-to-end encryption cross-signing identity, " # If yes, then we need to authenticate the change.
f"you first need to approve it at {url} and then try again.", if is_cross_signing_setup:
Codes.UNRECOGNIZED, # With MSC3861, UIA is not possible. Instead, the auth service has to
) # explicitly mark the master key as replaceable.
# But first-time setup is fine if self.hs.config.experimental.msc3861.enabled:
if not master_key_updatable_without_uia:
config = self.hs.config.experimental.msc3861
if config.account_management_url is not None:
url = f"{config.account_management_url}?action=org.matrix.cross_signing_reset"
else:
url = config.issuer
elif self.hs.config.experimental.msc3967_enabled: raise SynapseError(
# MSC3967 allows this endpoint to 200 OK for idempotency. Resending exactly the same HTTPStatus.NOT_IMPLEMENTED,
# keys should just 200 OK without doing a UIA prompt. "To reset your end-to-end encryption cross-signing identity, "
keys_are_different = await self.e2e_keys_handler.has_different_keys( f"you first need to approve it at {url} and then try again.",
user_id, body Codes.UNRECOGNIZED,
) )
if not keys_are_different: else:
# FIXME: we do not fallthrough to upload_signing_keys_for_user because confusingly # Without MSC3861, we require UIA.
# if we do, we 500 as it looks like it tries to INSERT the same key twice, causing a
# unique key constraint violation. This sounds like a bug?
return 200, {}
# the keys are different, is x-signing set up? If no, then the keys don't exist which is
# why they are different. If yes, then we need to UIA to change them.
if is_cross_signing_setup:
await self.auth_handler.validate_user_via_ui_auth( await self.auth_handler.validate_user_via_ui_auth(
requester, requester,
request, request,
@ -428,18 +419,6 @@ class SigningKeyUploadServlet(RestServlet):
# Do not allow skipping of UIA auth. # Do not allow skipping of UIA auth.
can_skip_ui_auth=False, can_skip_ui_auth=False,
) )
# Otherwise we don't require UIA since we are setting up cross signing for first time
else:
# Previous behaviour is to always require UIA but allow it to be skipped
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
"add a device signing key to your account",
# Allow skipping of UI auth since this is frequently called directly
# after login and it is silly to ask users to re-auth immediately.
can_skip_ui_auth=True,
)
result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body) result = await self.e2e_keys_handler.upload_signing_keys_for_user(user_id, body)
return 200, result return 200, result

View file

@ -39,6 +39,7 @@ from typing import (
import attr import attr
from prometheus_client import Counter, Gauge from prometheus_client import Counter, Gauge
from sortedcontainers import SortedSet
from synapse.api.constants import MAX_DEPTH from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
@ -283,7 +284,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# A map from chain ID to max sequence number *reachable* from any event ID. # A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {} chains: Dict[int, int] = {}
for links in self._get_chain_links(txn, set(event_chains.keys())): for links in self._get_chain_links(txn, event_chains.keys()):
for chain_id in links: for chain_id in links:
if chain_id not in event_chains: if chain_id not in event_chains:
continue continue
@ -335,7 +336,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
@classmethod @classmethod
def _get_chain_links( def _get_chain_links(
cls, txn: LoggingTransaction, chains_to_fetch: Set[int] cls, txn: LoggingTransaction, chains_to_fetch: Collection[int]
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]: ) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
"""Fetch all auth chain links from the given set of chains, and all """Fetch all auth chain links from the given set of chains, and all
links from those chains, recursively. links from those chains, recursively.
@ -371,9 +372,19 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id) INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
""" """
while chains_to_fetch: # We fetch the links in batches. Separate batches will likely fetch the
batch2 = tuple(itertools.islice(chains_to_fetch, 1000)) # same set of links (e.g. they'll always pull in the links to the create
chains_to_fetch.difference_update(batch2) # event). To try and minimize the amount of redundant links, we query
# the chain IDs in reverse order, as there will be a correlation between
# the order of chain IDs and links (i.e., higher chain IDs are more
# likely to depend on lower chain IDs than vice versa).
BATCH_SIZE = 1000
chains_to_fetch_sorted = SortedSet(chains_to_fetch)
while chains_to_fetch_sorted:
batch2 = list(chains_to_fetch_sorted.islice(-BATCH_SIZE))
chains_to_fetch_sorted.difference_update(batch2)
clause, args = make_in_list_sql_clause( clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2 txn.database_engine, "origin_chain_id", batch2
) )
@ -391,7 +402,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
(origin_sequence_number, target_chain_id, target_sequence_number) (origin_sequence_number, target_chain_id, target_sequence_number)
) )
chains_to_fetch.difference_update(links) chains_to_fetch_sorted.difference_update(links)
yield links yield links
@ -581,7 +592,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# are reachable from any event. # are reachable from any event.
# (We need to take a copy of `seen_chains` as the function mutates it) # (We need to take a copy of `seen_chains` as the function mutates it)
for links in self._get_chain_links(txn, set(seen_chains)): for links in self._get_chain_links(txn, seen_chains):
for chains in set_to_chain: for chains in set_to_chain:
for chain_id in links: for chain_id in links:
if chain_id not in chains: if chain_id not in chains:

View file

@ -541,6 +541,8 @@ class MSC3861OAuthDelegation(HomeserverTestCase):
self.assertEqual(channel.code, 200, channel.json_body) self.assertEqual(channel.code, 200, channel.json_body)
# Try uploading *different* keys; it should cause a 501 error.
keys_upload_body = self.make_device_keys(USER_ID, DEVICE)
channel = self.make_request( channel = self.make_request(
"POST", "POST",
"/_matrix/client/v3/keys/device_signing/upload", "/_matrix/client/v3/keys/device_signing/upload",

View file

@ -435,10 +435,6 @@ class ExperimentalFeaturesTestCase(unittest.HomeserverTestCase):
True, True,
channel.json_body["features"]["msc3881"], channel.json_body["features"]["msc3881"],
) )
self.assertEqual(
False,
channel.json_body["features"]["msc3967"],
)
# test nothing blows up if you try to disable a feature that isn't already enabled # test nothing blows up if you try to disable a feature that isn't already enabled
url = f"{self.url}/{self.other_user}" url = f"{self.url}/{self.other_user}"

View file

@ -155,71 +155,6 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
} }
def test_device_signing_with_uia(self) -> None: def test_device_signing_with_uia(self) -> None:
"""Device signing key upload requires UIA."""
password = "wonderland"
device_id = "ABCDEFGHI"
alice_id = self.register_user("alice", password)
alice_token = self.login("alice", password, device_id=device_id)
content = self.make_device_keys(alice_id, device_id)
channel = self.make_request(
"POST",
"/_matrix/client/v3/keys/device_signing/upload",
content,
alice_token,
)
self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
# Grab the session
session = channel.json_body["session"]
# Ensure that flows are what is expected.
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
# add UI auth
content["auth"] = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": alice_id},
"password": password,
"session": session,
}
channel = self.make_request(
"POST",
"/_matrix/client/v3/keys/device_signing/upload",
content,
alice_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@override_config({"ui_auth": {"session_timeout": "15m"}})
def test_device_signing_with_uia_session_timeout(self) -> None:
"""Device signing key upload requires UIA buy passes with grace period."""
password = "wonderland"
device_id = "ABCDEFGHI"
alice_id = self.register_user("alice", password)
alice_token = self.login("alice", password, device_id=device_id)
content = self.make_device_keys(alice_id, device_id)
channel = self.make_request(
"POST",
"/_matrix/client/v3/keys/device_signing/upload",
content,
alice_token,
)
self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
@override_config(
{
"experimental_features": {"msc3967_enabled": True},
"ui_auth": {"session_timeout": "15s"},
}
)
def test_device_signing_with_msc3967(self) -> None:
"""Device signing key follows MSC3967 behaviour when enabled."""
password = "wonderland" password = "wonderland"
device_id = "ABCDEFGHI" device_id = "ABCDEFGHI"
alice_id = self.register_user("alice", password) alice_id = self.register_user("alice", password)

View file

@ -25,6 +25,7 @@ from synapse.rest.client import room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.util import Clock from synapse.util import Clock
from tests.test_utils.event_injection import inject_event
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@ -128,3 +129,76 @@ class PurgeTests(HomeserverTestCase):
self.store._invalidate_local_get_event_cache(create_event.event_id) self.store._invalidate_local_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError) self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError) self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
def test_state_groups_state_decreases(self) -> None:
    """Purging batches of a room's history should shrink the
    `state_groups_state` table.

    We build a room with many forked state events (so state resolution
    creates plenty of state groups), then purge the history batch by batch
    and check that the row count for the room actually goes down.
    """
    response = self.helper.send(self.room_id, body="first")
    first_event_id = response["event_id"]

    # Topological tokens marking the start of each batch we will purge.
    batches = []

    previous_event_id = first_event_id
    for i in range(50):
        # Inject two conflicting state events with the same prev event so
        # that the DAG forks and state resolution produces extra state
        # groups.
        state_event1 = self.get_success(
            inject_event(
                self.hs,
                type="test.state",
                sender=self.user_id,
                state_key="",
                room_id=self.room_id,
                content={"key": i, "e": 1},
                prev_event_ids=[previous_event_id],
                origin_server_ts=1,
            )
        )

        state_event2 = self.get_success(
            inject_event(
                self.hs,
                type="test.state",
                sender=self.user_id,
                state_key="",
                room_id=self.room_id,
                content={"key": i, "e": 2},
                prev_event_ids=[previous_event_id],
                origin_server_ts=2,
            )
        )

        # A message event that merges the fork back together.
        message_event = self.get_success(
            inject_event(
                self.hs,
                type="dummy_event",
                sender=self.user_id,
                room_id=self.room_id,
                content={},
                prev_event_ids=[state_event1.event_id, state_event2.event_id],
            )
        )

        token = self.get_success(
            self.store.get_topological_token_for_event(state_event1.event_id)
        )
        batches.append(token)

        previous_event_id = message_event.event_id

    self.helper.send(self.room_id, body="last event")

    def count_state_groups() -> int:
        # Count the rows for this room directly in the database.
        sql = "SELECT COUNT(*) FROM state_groups_state WHERE room_id = ?"
        rows = self.get_success(
            self.store.db_pool.execute(
                "test_state_groups_state_decreases", sql, self.room_id
            )
        )
        return rows[0][0]

    count_before = count_state_groups()

    for token in batches:
        token_str = self.get_success(token.to_string(self.hs.get_datastores().main))
        self.get_success(
            self._storage_controllers.purge_events.purge_history(
                self.room_id, token_str, False
            )
        )

    # Purging the history should have removed state groups (purge may
    # de-delta the groups at the purge boundary, but the net row count for
    # the room must still drop).
    count_after = count_state_groups()
    self.assertLess(count_after, count_before)