Accept a start & end event ID when creating a receipt.

This commit is contained in:
Patrick Cloke 2022-05-26 15:55:06 -04:00
parent 48b2d6c9ef
commit 1ec3885aa9
4 changed files with 66 additions and 36 deletions

View file

@ -19,7 +19,9 @@ from synapse.appservice import ApplicationService
from synapse.streams import EventSource
from synapse.types import (
JsonDict,
RangedReadReceipt,
ReadReceipt,
Receipt,
StreamKeyType,
UserID,
get_domain_from_id,
@ -65,7 +67,7 @@ class ReceiptsHandler:
async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
"""Called when we receive an EDU of type m.receipt from a remote HS."""
receipts = []
receipts: List[Receipt] = []
for room_id, room_values in content.items():
# If we're not in the room just ditch the event entirely. This is
# probably an old server that has come back and thinks we're still in
@ -103,7 +105,7 @@ class ReceiptsHandler:
await self._handle_new_receipts(receipts)
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
async def _handle_new_receipts(self, receipts: List[Receipt]) -> bool:
"""Takes a list of receipts, stores them and informs the notifier."""
min_batch_id: Optional[int] = None
max_batch_id: Optional[int] = None
@ -140,24 +142,45 @@ class ReceiptsHandler:
return True
async def received_client_receipt(
self, room_id: str, receipt_type: str, user_id: str, event_id: str
self,
room_id: str,
receipt_type: str,
user_id: str,
end_event_id: str,
start_event_id: Optional[str] = None,
) -> None:
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
receipt = ReadReceipt(
room_id=room_id,
receipt_type=receipt_type,
user_id=user_id,
event_ids=[event_id],
data={"ts": int(self.clock.time_msec())},
)
if start_event_id:
receipt: Receipt = RangedReadReceipt(
room_id=room_id,
receipt_type=receipt_type,
user_id=user_id,
start_event_id=start_event_id,
end_event_id=end_event_id,
data={"ts": int(self.clock.time_msec())},
)
else:
receipt = ReadReceipt(
room_id=room_id,
receipt_type=receipt_type,
user_id=user_id,
event_ids=[end_event_id],
data={"ts": int(self.clock.time_msec())},
)
is_new = await self._handle_new_receipts([receipt])
if not is_new:
return
if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE:
# XXX How to handle this for a ranged read receipt.
if (
isinstance(receipt, ReadReceipt)
and self.federation_sender
and receipt_type != ReceiptTypes.READ_PRIVATE
):
await self.federation_sender.send_read_receipt(receipt)

View file

@ -71,7 +71,7 @@ class ReadMarkerRestServlet(RestServlet):
room_id,
ReceiptTypes.READ,
user_id=requester.user.to_string(),
event_id=read_event_id,
end_event_id=read_event_id,
)
read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None)
@ -80,7 +80,7 @@ class ReadMarkerRestServlet(RestServlet):
room_id,
ReceiptTypes.READ_PRIVATE,
user_id=requester.user.to_string(),
event_id=read_private_event_id,
end_event_id=read_private_event_id,
)
read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None)

View file

@ -80,7 +80,7 @@ class ReceiptRestServlet(RestServlet):
room_id,
receipt_type,
user_id=requester.user.to_string(),
event_id=event_id,
end_event_id=event_id,
)
return 200, {}

View file

@ -42,7 +42,7 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import JsonDict, ReadReceipt
from synapse.types import JsonDict, RangedReadReceipt, ReadReceipt, Receipt
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -725,7 +725,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
else:
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
async def insert_receipt(self, receipt: ReadReceipt) -> Optional[Tuple[int, int]]:
async def insert_receipt(self, receipt: Receipt) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
Automatically does conversion between linearized and graph
@ -737,19 +737,25 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
assert self._can_write_to_receipts
if not receipt.event_ids:
return None
if isinstance(receipt, ReadReceipt):
event_ids = receipt.event_ids
if not event_ids:
return None
if len(receipt.event_ids) == 1:
linearized_event_id = receipt.event_ids[0]
if len(event_ids) == 1:
linearized_event_id = event_ids[0]
else:
# we need to points in graph -> linearized form.
linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv",
self._graph_to_linear,
receipt.room_id,
event_ids,
)
elif isinstance(receipt, RangedReadReceipt):
linearized_event_id = receipt.end_event_id
else:
# we need to points in graph -> linearized form.
linearized_event_id = await self.db_pool.runInteraction(
"insert_receipt_conv",
self._graph_to_linear,
receipt.room_id,
receipt.event_ids,
)
raise ValueError("Unexpected receipt type: %s", type(receipt))
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
event_ts = await self.db_pool.runInteraction(
@ -779,15 +785,16 @@ class ReceiptsWorkerStore(SQLBaseStore):
now - event_ts,
)
await self.db_pool.runInteraction(
"insert_graph_receipt",
self._insert_graph_receipt_txn,
receipt.room_id,
receipt.receipt_type,
receipt.user_id,
receipt.event_ids,
receipt.data,
)
# XXX These aren't really used right now, go away.
# await self.db_pool.runInteraction(
# "insert_graph_receipt",
# self._insert_graph_receipt_txn,
# room_id,
# receipt_type,
# user_id,
# event_ids,
# data,
# )
max_persisted_id = self._receipts_id_gen.get_current_token()