From 1ec3885aa9b416b4f697745bcaf8f8bda51738d9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 26 May 2022 15:55:06 -0400 Subject: [PATCH] Accept a start & end event ID when creating a receipt. --- synapse/handlers/receipts.py | 45 ++++++++++++++----- synapse/rest/client/read_marker.py | 4 +- synapse/rest/client/receipts.py | 2 +- synapse/storage/databases/main/receipts.py | 51 ++++++++++++---------- 4 files changed, 66 insertions(+), 36 deletions(-) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 5588545850..989ede27c1 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -19,7 +19,9 @@ from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import ( JsonDict, + RangedReadReceipt, ReadReceipt, + Receipt, StreamKeyType, UserID, get_domain_from_id, @@ -65,7 +67,7 @@ class ReceiptsHandler: async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None: """Called when we receive an EDU of type m.receipt from a remote HS.""" - receipts = [] + receipts: List[Receipt] = [] for room_id, room_values in content.items(): # If we're not in the room just ditch the event entirely. This is # probably an old server that has come back and thinks we're still in @@ -103,7 +105,7 @@ class ReceiptsHandler: await self._handle_new_receipts(receipts) - async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool: + async def _handle_new_receipts(self, receipts: List[Receipt]) -> bool: """Takes a list of receipts, stores them and informs the notifier.""" min_batch_id: Optional[int] = None max_batch_id: Optional[int] = None @@ -140,24 +142,45 @@ class ReceiptsHandler: return True async def received_client_receipt( - self, room_id: str, receipt_type: str, user_id: str, event_id: str + self, + room_id: str, + receipt_type: str, + user_id: str, + end_event_id: str, + start_event_id: Optional[str] = None, ) -> None: """Called when a client tells us a local user has read up to the given event_id in the room. """ - receipt = ReadReceipt( - room_id=room_id, - receipt_type=receipt_type, - user_id=user_id, - event_ids=[event_id], - data={"ts": int(self.clock.time_msec())}, - ) + + if start_event_id: + receipt: Receipt = RangedReadReceipt( + room_id=room_id, + receipt_type=receipt_type, + user_id=user_id, + start_event_id=start_event_id, + end_event_id=end_event_id, + data={"ts": int(self.clock.time_msec())}, + ) + else: + receipt = ReadReceipt( + room_id=room_id, + receipt_type=receipt_type, + user_id=user_id, + event_ids=[end_event_id], + data={"ts": int(self.clock.time_msec())}, + ) is_new = await self._handle_new_receipts([receipt]) if not is_new: return - if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE: + # XXX How to handle this for a ranged read receipt. + if ( + isinstance(receipt, ReadReceipt) + and self.federation_sender + and receipt_type != ReceiptTypes.READ_PRIVATE + ): await self.federation_sender.send_read_receipt(receipt) diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 3644705e6a..15c031159e 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -71,7 +71,7 @@ class ReadMarkerRestServlet(RestServlet): room_id, ReceiptTypes.READ, user_id=requester.user.to_string(), - event_id=read_event_id, + end_event_id=read_event_id, ) read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None) @@ -80,7 +80,7 @@ class ReadMarkerRestServlet(RestServlet): room_id, ReceiptTypes.READ_PRIVATE, user_id=requester.user.to_string(), - event_id=read_private_event_id, + end_event_id=read_private_event_id, ) read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None) diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 4b03eb876b..f1acd42f5e 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -80,7 +80,7 @@ class ReceiptRestServlet(RestServlet): room_id, receipt_type, user_id=requester.user.to_string(), - event_id=event_id, + end_event_id=event_id, ) return 200, {} diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 2252dd8608..d318753650 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -42,7 +42,7 @@ from synapse.storage.util.id_generators import ( MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict, ReadReceipt +from synapse.types import JsonDict, RangedReadReceipt, ReadReceipt, Receipt from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -725,7 +725,7 @@ class ReceiptsWorkerStore(SQLBaseStore): else: raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) - async def insert_receipt(self, receipt: ReadReceipt) -> Optional[Tuple[int, int]]: + async def insert_receipt(self, receipt: Receipt) -> Optional[Tuple[int, int]]: """Insert a receipt, either from local client or remote server. Automatically does conversion between linearized and graph @@ -737,19 +737,25 @@ class ReceiptsWorkerStore(SQLBaseStore): """ assert self._can_write_to_receipts - if not receipt.event_ids: - return None + if isinstance(receipt, ReadReceipt): + event_ids = receipt.event_ids + if not event_ids: + return None - if len(receipt.event_ids) == 1: - linearized_event_id = receipt.event_ids[0] + if len(event_ids) == 1: + linearized_event_id = event_ids[0] + else: + # we need to points in graph -> linearized form. + linearized_event_id = await self.db_pool.runInteraction( + "insert_receipt_conv", + self._graph_to_linear, + receipt.room_id, + event_ids, + ) + elif isinstance(receipt, RangedReadReceipt): + linearized_event_id = receipt.end_event_id else: - # we need to points in graph -> linearized form. - linearized_event_id = await self.db_pool.runInteraction( - "insert_receipt_conv", - self._graph_to_linear, - receipt.room_id, - receipt.event_ids, - ) + raise ValueError("Unexpected receipt type: %s", type(receipt)) async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined] event_ts = await self.db_pool.runInteraction( @@ -779,15 +785,16 @@ class ReceiptsWorkerStore(SQLBaseStore): now - event_ts, ) - await self.db_pool.runInteraction( - "insert_graph_receipt", - self._insert_graph_receipt_txn, - receipt.room_id, - receipt.receipt_type, - receipt.user_id, - receipt.event_ids, - receipt.data, - ) + # XXX These aren't really used right now, go away. + # await self.db_pool.runInteraction( + # "insert_graph_receipt", + # self._insert_graph_receipt_txn, + # room_id, + # receipt_type, + # user_id, + # event_ids, + # data, + # ) max_persisted_id = self._receipts_id_gen.get_current_token()