From eb2733c2d8f584f2d0d52995e24fb73429add53f Mon Sep 17 00:00:00 2001 From: "H. Shay" Date: Mon, 19 Dec 2022 17:58:50 -0800 Subject: [PATCH] allow third party rules to rebuild a batched event --- synapse/handlers/message.py | 81 ++++++++++++++++++++++++++++--------- 1 file changed, 62 insertions(+), 19 deletions(-) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 5cbe89f4fd..8c006d777e 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -1256,12 +1256,22 @@ class EventCreationHandler: raise SynapseError( 403, "This event is not allowed in this context", Codes.FORBIDDEN ) + # the third-party rules want to replace the event. We'll need to build a new + # event. elif new_content is not None: - # the third-party rules want to replace the event. We'll need to build a new - # event. - event, context = await self._rebuild_event_after_third_party_rules( - new_content, event - ) + if for_batch: + event, context = await self._rebuild_event_after_third_party_rules( + new_content, + event, + for_batch=True, + state_map=state_map, + prev_event_ids=prev_event_ids, + current_state_group=current_state_group, + ) + else: + event, context = await self._rebuild_event_after_third_party_rules( + new_content, event + ) self.validator.validate_new(event, self.config) await self._validate_event_relation(event) @@ -2047,10 +2057,27 @@ class EventCreationHandler: del self._rooms_to_exclude_from_dummy_event_insertion[room_id] async def _rebuild_event_after_third_party_rules( - self, third_party_result: dict, original_event: EventBase + self, + third_party_result: dict, + original_event: EventBase, + for_batch: bool = False, + state_map: Optional[StateMap[str]] = None, + prev_event_ids: Optional[List[str]] = None, + current_state_group: Optional[int] = None, ) -> Tuple[EventBase, EventContext]: - # the third_party_event_rules want to replace the event. - # we do some basic checks, and then return the replacement event and context. + """ + The third_party_event_rules want to replace the event. + we do some basic checks, and then return the replacement event and context. + If for_batch is true, state_map, prev_event_ids, and current_state_group must be + provided. + third_party_result: An event dict provided by third-party rules + original_event: the original event being replaced + for_batch: whether the event being created is part of a batch being created + for batch persisting to the DB + state_map: the current state prior to the new event + prev_event_ids: the prev_events of the new event + current_state_group: the state group prior to the event + """ # Construct a new EventBuilder and validate it, which helps with the # rest of these checks. @@ -2093,16 +2120,32 @@ class EventCreationHandler: for k, v in original_event.internal_metadata.get_dict().items(): setattr(builder.internal_metadata, k, v) - # modules can send new state events, so we re-calculate the auth events just in - # case. - prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) + if for_batch: + assert state_map is not None + assert current_state_group is not None + assert prev_event_ids is not None + auth_event_ids = self._event_auth_handler.compute_auth_events( + builder, state_map + ) + event = await builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=auth_event_ids, + ) + context = await self.state.compute_event_context_for_batched( + event, + state_ids_before_event=state_map, + current_state_group=current_state_group, + ) + else: + # modules can send new state events, so we re-calculate the auth events just in + # case. + prev_event_ids = await self.store.get_prev_events_for_room(builder.room_id) + event = await builder.build( + prev_event_ids=prev_event_ids, + auth_event_ids=None, + ) + # we rebuild the event context, to be on the safe side. If nothing else, + # delta_ids might need an update. + context = await self.state.compute_event_context(event) - event = await builder.build( - prev_event_ids=prev_event_ids, - auth_event_ids=None, - ) - - # we rebuild the event context, to be on the safe side. If nothing else, - # delta_ids might need an update. - context = await self.state.compute_event_context(event) return event, context