diff --git a/changelog.d/12338.misc b/changelog.d/12338.misc new file mode 100644 index 0000000000..376089f327 --- /dev/null +++ b/changelog.d/12338.misc @@ -0,0 +1 @@ +Refactor relations code to remove an unnecessary class. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index b9497ff3f3..a36936b520 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Dict, Iterable, Optional, cast +from typing import TYPE_CHECKING, Dict, Iterable, Optional import attr from frozendict import frozendict @@ -25,7 +25,6 @@ from synapse.visibility import filter_events_for_client if TYPE_CHECKING: from synapse.server import HomeServer - from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -116,7 +115,7 @@ class RelationsHandler: if event is None: raise SynapseError(404, "Unknown parent event.") - pagination_chunk = await self._main_store.get_relations_for_event( + related_events, next_token = await self._main_store.get_relations_for_event( event_id=event_id, event=event, room_id=room_id, @@ -129,9 +128,7 @@ class RelationsHandler: to_token=to_token, ) - events = await self._main_store.get_events_as_list( - [c["event_id"] for c in pagination_chunk.chunk] - ) + events = await self._main_store.get_events_as_list(related_events) events = await filter_events_for_client( self._storage, user_id, events, is_peeking=(member_event_id is None) @@ -152,9 +149,16 @@ class RelationsHandler: events, now, bundle_aggregations=aggregations ) - return_value = await pagination_chunk.to_dict(self._main_store) - return_value["chunk"] = serialized_events - return_value["original_event"] = original_event + return_value = { + "chunk": serialized_events, + "original_event": original_event, + } + + if next_token: + return_value["next_batch"] = await next_token.to_string(self._main_store) + + if from_token: + return_value["prev_batch"] = await from_token.to_string(self._main_store) return return_value @@ -196,11 +200,18 @@ class RelationsHandler: if annotations: aggregations.annotations = {"chunk": annotations} - references = await self._main_store.get_relations_for_event( + references, next_token = await self._main_store.get_relations_for_event( event_id, event, room_id, RelationTypes.REFERENCE, direction="f" ) - if references.chunk: - aggregations.references = await references.to_dict(cast("DataStore", self)) + if references: + aggregations.references = { + "chunk": [{"event_id": event_id} for event_id in references] + } + + if next_token: + aggregations.references["next_batch"] = await next_token.to_string( + self._main_store + ) # Store the bundled aggregations in the event metadata for later use. return aggregations diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 3285450742..64a7808140 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -37,7 +37,6 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine -from synapse.storage.relations import PaginationChunk from synapse.types import JsonDict, RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList @@ -71,7 +70,7 @@ class RelationsWorkerStore(SQLBaseStore): direction: str = "b", from_token: Optional[StreamToken] = None, to_token: Optional[StreamToken] = None, - ) -> PaginationChunk: + ) -> Tuple[List[str], Optional[StreamToken]]: """Get a list of relations for an event, ordered by topological ordering. Args: @@ -88,8 +87,10 @@ class RelationsWorkerStore(SQLBaseStore): to_token: Fetch rows up to the given token, or up to the end if None. Returns: - List of event IDs that match relations requested. The rows are of - the form `{"event_id": "..."}`. + A tuple of: + A list of related event IDs + + The next stream token, if one exists. """ # We don't use `event_id`, it's there so that we can cache based on # it. The `event_id` must match the `event.event_id`. @@ -144,7 +145,7 @@ class RelationsWorkerStore(SQLBaseStore): def _get_recent_references_for_event_txn( txn: LoggingTransaction, - ) -> PaginationChunk: + ) -> Tuple[List[str], Optional[StreamToken]]: txn.execute(sql, where_args + [limit + 1]) last_topo_id = None @@ -154,7 +155,7 @@ class RelationsWorkerStore(SQLBaseStore): # Do not include edits for redacted events as they leak event # content. if not is_redacted or row[1] != RelationTypes.REPLACE: - events.append({"event_id": row[0]}) + events.append(row[0]) last_topo_id = row[2] last_stream_id = row[3] @@ -177,9 +178,7 @@ class RelationsWorkerStore(SQLBaseStore): groups_key=0, ) - return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token - ) + return events[:limit], next_token return await self.db_pool.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py deleted file mode 100644 index b9d2b46799..0000000000 --- a/synapse/storage/relations.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright 2019 New Vector Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from typing import TYPE_CHECKING, Any, Dict, List, Optional - -import attr - -from synapse.types import JsonDict - -if TYPE_CHECKING: - from synapse.storage.databases.main import DataStore - -logger = logging.getLogger(__name__) - - -@attr.s(slots=True, auto_attribs=True) -class PaginationChunk: - """Returned by relation pagination APIs. - - Attributes: - chunk: The rows returned by pagination - next_batch: Token to fetch next set of results with, if - None then there are no more results. - prev_batch: Token to fetch previous set of results with, if - None then there are no previous results. - """ - - chunk: List[JsonDict] - next_batch: Optional[Any] = None - prev_batch: Optional[Any] = None - - async def to_dict(self, store: "DataStore") -> Dict[str, Any]: - d = {"chunk": self.chunk} - - if self.next_batch: - d["next_batch"] = await self.next_batch.to_string(store) - - if self.prev_batch: - d["prev_batch"] = await self.prev_batch.to_string(store) - - return d