Separate creating an event context from persisting it in the federation handler (#9800)

This refactoring allows adding logic that uses the event context
before persisting it.
This commit is contained in:
Patrick Cloke 2021-04-14 12:35:28 -04:00 committed by GitHub
parent e8816c6ace
commit 936e69825a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 118 additions and 67 deletions

1
changelog.d/9800.feature Normal file
View file

@ -0,0 +1 @@
Update experimental support for [MSC3083](https://github.com/matrix-org/matrix-doc/pull/3083): restricting room access via group membership.

View file

@ -103,7 +103,7 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True) @attr.s(slots=True)
class _NewEventInfo: class _NewEventInfo:
"""Holds information about a received event, ready for passing to _handle_new_events """Holds information about a received event, ready for passing to _auth_and_persist_events
Attributes: Attributes:
event: the received event event: the received event
@ -807,7 +807,10 @@ class FederationHandler(BaseHandler):
logger.debug("Processing event: %s", event) logger.debug("Processing event: %s", event)
try: try:
await self._handle_new_event(origin, event, state=state) context = await self.state_handler.compute_event_context(
event, old_state=state
)
await self._auth_and_persist_event(origin, event, context, state=state)
except AuthError as e: except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id) raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
@ -1010,7 +1013,9 @@ class FederationHandler(BaseHandler):
) )
if ev_infos: if ev_infos:
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True) await self._auth_and_persist_events(
dest, room_id, ev_infos, backfilled=True
)
# Step 2: Persist the rest of the events in the chunk one by one # Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth) events.sort(key=lambda e: e.depth)
@ -1023,10 +1028,12 @@ class FederationHandler(BaseHandler):
# non-outliers # non-outliers
assert not event.internal_metadata.is_outlier() assert not event.internal_metadata.is_outlier()
context = await self.state_handler.compute_event_context(event)
# We store these one at a time since each event depends on the # We store these one at a time since each event depends on the
# previous to work out the state. # previous to work out the state.
# TODO: We can probably do something more clever here. # TODO: We can probably do something more clever here.
await self._handle_new_event(dest, event, backfilled=True) await self._auth_and_persist_event(dest, event, context, backfilled=True)
return events return events
@ -1360,7 +1367,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth)) event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events( await self._auth_and_persist_events(
destination, destination,
room_id, room_id,
event_infos, event_infos,
@ -1666,10 +1673,11 @@ class FederationHandler(BaseHandler):
# would introduce the danger of backwards-compatibility problems. # would introduce the danger of backwards-compatibility problems.
event.internal_metadata.send_on_behalf_of = origin event.internal_metadata.send_on_behalf_of = origin
context = await self._handle_new_event(origin, event) context = await self.state_handler.compute_event_context(event)
context = await self._auth_and_persist_event(origin, event, context)
logger.debug( logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s", "on_send_join_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id, event.event_id,
event.signatures, event.signatures,
) )
@ -1878,10 +1886,11 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False event.internal_metadata.outlier = False
await self._handle_new_event(origin, event) context = await self.state_handler.compute_event_context(event)
await self._auth_and_persist_event(origin, event, context)
logger.debug( logger.debug(
"on_send_leave_request: After _handle_new_event: %s, sigs: %s", "on_send_leave_request: After _auth_and_persist_event: %s, sigs: %s",
event.event_id, event.event_id,
event.signatures, event.signatures,
) )
@ -1989,16 +1998,47 @@ class FederationHandler(BaseHandler):
async def get_min_depth_for_context(self, context: str) -> int: async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context) return await self.store.get_min_depth(context)
async def _handle_new_event( async def _auth_and_persist_event(
self, self,
origin: str, origin: str,
event: EventBase, event: EventBase,
context: EventContext,
state: Optional[Iterable[EventBase]] = None, state: Optional[Iterable[EventBase]] = None,
auth_events: Optional[MutableStateMap[EventBase]] = None, auth_events: Optional[MutableStateMap[EventBase]] = None,
backfilled: bool = False, backfilled: bool = False,
) -> EventContext: ) -> EventContext:
context = await self._prep_event( """
origin, event, state=state, auth_events=auth_events, backfilled=backfilled Process an event by performing auth checks and then persisting to the database.
Args:
origin: The host the event originates from.
event: The event itself.
context:
The event context.
NB that this function potentially modifies it.
state:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
auth_events:
Map from (event_type, state_key) to event
Normally, our calculated auth_events based on the state of the room
at the event's position in the DAG, though occasionally (eg if the
event is an outlier), may be the auth events claimed by the remote
server.
backfilled: True if the event was backfilled.
Returns:
The event context.
"""
context = await self._check_event_auth(
origin,
event,
context,
state=state,
auth_events=auth_events,
backfilled=backfilled,
) )
try: try:
@ -2022,7 +2062,7 @@ class FederationHandler(BaseHandler):
return context return context
async def _handle_new_events( async def _auth_and_persist_events(
self, self,
origin: str, origin: str,
room_id: str, room_id: str,
@ -2040,9 +2080,13 @@ class FederationHandler(BaseHandler):
async def prep(ev_info: _NewEventInfo): async def prep(ev_info: _NewEventInfo):
event = ev_info.event event = ev_info.event
with nested_logging_context(suffix=event.event_id): with nested_logging_context(suffix=event.event_id):
res = await self._prep_event( res = await self.state_handler.compute_event_context(
event, old_state=ev_info.state
)
res = await self._check_event_auth(
origin, origin,
event, event,
res,
state=ev_info.state, state=ev_info.state,
auth_events=ev_info.auth_events, auth_events=ev_info.auth_events,
backfilled=backfilled, backfilled=backfilled,
@ -2177,49 +2221,6 @@ class FederationHandler(BaseHandler):
room_id, [(event, new_event_context)] room_id, [(event, new_event_context)]
) )
async def _prep_event(
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
context = await self.state_handler.compute_event_context(event, old_state=state)
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
context = await self.do_auth(origin, event, context, auth_events=auth_events)
if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)
return context
async def _check_for_soft_fail( async def _check_for_soft_fail(
self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
) -> None: ) -> None:
@ -2330,19 +2331,28 @@ class FederationHandler(BaseHandler):
return missing_events return missing_events
async def do_auth( async def _check_event_auth(
self, self,
origin: str, origin: str,
event: EventBase, event: EventBase,
context: EventContext, context: EventContext,
auth_events: MutableStateMap[EventBase], state: Optional[Iterable[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext: ) -> EventContext:
""" """
Checks whether an event should be rejected (for failing auth checks).
Args: Args:
origin: origin: The host the event originates from.
event: event: The event itself.
context: context:
The event context.
NB that this function potentially modifies it.
state:
The state events used to check the event for soft-fail. If this is
not provided the current state events will be used.
auth_events: auth_events:
Map from (event_type, state_key) to event Map from (event_type, state_key) to event
@ -2352,12 +2362,34 @@ class FederationHandler(BaseHandler):
server. server.
Also NB that this function adds entries to it. Also NB that this function adds entries to it.
If this is not provided, it is calculated from the previous state IDs.
backfilled: True if the event was backfilled.
Returns: Returns:
updated context object The updated context object.
""" """
room_version = await self.store.get_room_version_id(event.room_id) room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version] room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
if not auth_events:
prev_state_ids = await context.get_prev_state_ids()
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_x = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_x.values()}
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_event_ids():
if len(event.prev_event_ids()) == 1 and event.depth < 5:
c = await self.store.get_event(
event.prev_event_ids()[0], allow_none=True
)
if c and c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
try: try:
context = await self._update_auth_events_and_context_for_auth( context = await self._update_auth_events_and_context_for_auth(
origin, event, context, auth_events origin, event, context, auth_events
@ -2379,6 +2411,17 @@ class FederationHandler(BaseHandler):
logger.warning("Failed auth resolution for %r because %s", event, e) logger.warning("Failed auth resolution for %r because %s", event, e)
context.rejected = RejectedReason.AUTH_ERROR context.rejected = RejectedReason.AUTH_ERROR
if not context.rejected:
await self._check_for_soft_fail(event, state, backfilled)
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
# If we are going to send this event over federation we precaclculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)
return context return context
async def _update_auth_events_and_context_for_auth( async def _update_auth_events_and_context_for_auth(
@ -2388,7 +2431,7 @@ class FederationHandler(BaseHandler):
context: EventContext, context: EventContext,
auth_events: MutableStateMap[EventBase], auth_events: MutableStateMap[EventBase],
) -> EventContext: ) -> EventContext:
"""Helper for do_auth. See there for docs. """Helper for _check_event_auth. See there for docs.
Checks whether a given event has the expected auth events. If it Checks whether a given event has the expected auth events. If it
doesn't then we talk to the remote server to compare state to see if doesn't then we talk to the remote server to compare state to see if
@ -2468,9 +2511,14 @@ class FederationHandler(BaseHandler):
e.internal_metadata.outlier = True e.internal_metadata.outlier = True
logger.debug( logger.debug(
"do_auth %s missing_auth: %s", event.event_id, e.event_id "_check_event_auth %s missing_auth: %s",
event.event_id,
e.event_id,
)
context = await self.state_handler.compute_event_context(e)
await self._auth_and_persist_event(
origin, e, context, auth_events=auth
) )
await self._handle_new_event(origin, e, auth_events=auth)
if e.event_id in event_auth_events: if e.event_id in event_auth_events:
auth_events[(e.type, e.state_key)] = e auth_events[(e.type, e.state_key)] = e

View file

@ -75,8 +75,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
) )
self.handler = self.homeserver.get_federation_handler() self.handler = self.homeserver.get_federation_handler()
self.handler.do_auth = lambda origin, event, context, auth_events: succeed( self.handler._check_event_auth = (
context lambda origin, event, context, state, auth_events, backfilled: succeed(
context
)
) )
self.client = self.homeserver.get_federation_client() self.client = self.homeserver.get_federation_client()
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed( self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(