diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 685792dbdc..dc5b6ef79d 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -32,6 +32,7 @@ from .appservice import ApplicationServicesHandler from .sync import SyncHandler from .auth import AuthHandler from .identity import IdentityHandler +from .receipts import ReceiptsHandler class Handlers(object): @@ -57,6 +58,7 @@ class Handlers(object): self.directory_handler = DirectoryHandler(hs) self.typing_notification_handler = TypingNotificationHandler(hs) self.admin_handler = AdminHandler(hs) + self.receipts_handler = ReceiptsHandler(hs) asapi = ApplicationServiceApi(hs) self.appservice_handler = ApplicationServicesHandler( hs, asapi, AppServiceScheduler( diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index d8b117612d..9d6d4f0978 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -334,6 +334,11 @@ class MessageHandler(BaseHandler): user, pagination_config.get_source_config("presence"), None ) + receipt_stream = self.hs.get_event_sources().sources["receipt"] + receipt, _ = yield receipt_stream.get_pagination_rows( + user, pagination_config.get_source_config("receipt"), None + ) + public_room_ids = yield self.store.get_public_room_ids() limit = pagin_config.limit @@ -404,7 +409,8 @@ class MessageHandler(BaseHandler): ret = { "rooms": rooms_ret, "presence": presence, - "end": now_token.to_string() + "receipts": receipt, + "end": now_token.to_string(), } defer.returnValue(ret) @@ -465,9 +471,12 @@ class MessageHandler(BaseHandler): defer.returnValue([p for success, p in presence_defs if success]) - presence, (messages, token) = yield defer.gatherResults( + receipts_handler = self.hs.get_handlers().receipts_handler + + presence, receipts, (messages, token) = yield defer.gatherResults( [ get_presence(), + receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key), self.store.get_recent_events_for_room( room_id, limit=limit, @@ -495,5 +504,6 @@ class MessageHandler(BaseHandler): "end": end_token.to_string(), }, "state": state, - "presence": presence + "presence": presence, + "receipts": receipts, }) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index 7c03198313..341a516da2 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -992,7 +992,7 @@ class PresenceHandler(BaseHandler): room_ids([str]): List of room_ids to notify. """ with PreserveLoggingContext(): - self.notifier.on_new_user_event( + self.notifier.on_new_event( "presence_key", self._user_cachemap_latest_serial, users_to_push, diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py new file mode 100644 index 0000000000..5b3df6932b --- /dev/null +++ b/synapse/handlers/receipts.py @@ -0,0 +1,207 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseHandler + +from twisted.internet import defer + +from synapse.util.logcontext import PreserveLoggingContext + +import logging + + +logger = logging.getLogger(__name__) + + +class ReceiptsHandler(BaseHandler): + def __init__(self, hs): + super(ReceiptsHandler, self).__init__(hs) + + self.hs = hs + self.federation = hs.get_replication_layer() + self.federation.register_edu_handler( + "m.receipt", self._received_remote_receipt + ) + self.clock = self.hs.get_clock() + + self._receipt_cache = None + + @defer.inlineCallbacks + def received_client_receipt(self, room_id, receipt_type, user_id, + event_id): + """Called when a client tells us a local user has read up to the given + event_id in the room. + """ + receipt = { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_ids": [event_id], + "data": { + "ts": int(self.clock.time_msec()), + } + } + + is_new = yield self._handle_new_receipts([receipt]) + + if is_new: + self._push_remotes([receipt]) + + @defer.inlineCallbacks + def _received_remote_receipt(self, origin, content): + """Called when we receive an EDU of type m.receipt from a remote HS. + """ + receipts = [ + { + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_ids": user_values["event_ids"], + "data": user_values.get("data", {}), + } + for room_id, room_values in content.items() + for receipt_type, users in room_values.items() + for user_id, user_values in users.items() + ] + + yield self._handle_new_receipts(receipts) + + @defer.inlineCallbacks + def _handle_new_receipts(self, receipts): + """Takes a list of receipts, stores them and informs the notifier. + """ + for receipt in receipts: + room_id = receipt["room_id"] + receipt_type = receipt["receipt_type"] + user_id = receipt["user_id"] + event_ids = receipt["event_ids"] + data = receipt["data"] + + res = yield self.store.insert_receipt( + room_id, receipt_type, user_id, event_ids, data + ) + + if not res: + # res will be None if this read receipt is 'old' + defer.returnValue(False) + + stream_id, max_persisted_id = res + + with PreserveLoggingContext(): + self.notifier.on_new_event( + "receipt_key", max_persisted_id, rooms=[room_id] + ) + + defer.returnValue(True) + + @defer.inlineCallbacks + def _push_remotes(self, receipts): + """Given a list of receipts, works out which remote servers should be + poked and pokes them. + """ + # TODO: Some of this stuff should be coallesced. + for receipt in receipts: + room_id = receipt["room_id"] + receipt_type = receipt["receipt_type"] + user_id = receipt["user_id"] + event_ids = receipt["event_ids"] + data = receipt["data"] + + remotedomains = set() + + rm_handler = self.hs.get_handlers().room_member_handler + yield rm_handler.fetch_room_distributions_into( + room_id, localusers=None, remotedomains=remotedomains + ) + + logger.debug("Sending receipt to: %r", remotedomains) + + for domain in remotedomains: + self.federation.send_edu( + destination=domain, + edu_type="m.receipt", + content={ + room_id: { + receipt_type: { + user_id: { + "event_ids": event_ids, + "data": data, + } + } + }, + }, + ) + + @defer.inlineCallbacks + def get_receipts_for_room(self, room_id, to_key): + """Gets all receipts for a room, upto the given key. + """ + result = yield self.store.get_linearized_receipts_for_room( + room_id, + to_key=to_key, + ) + + if not result: + defer.returnValue([]) + + event = { + "type": "m.receipt", + "room_id": room_id, + "content": result, + } + + defer.returnValue([event]) + + +class ReceiptEventSource(object): + def __init__(self, hs): + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def get_new_events_for_user(self, user, from_key, limit): + from_key = int(from_key) + to_key = yield self.get_current_key() + + rooms = yield self.store.get_rooms_for_user(user.to_string()) + rooms = [room.room_id for room in rooms] + events = yield self.store.get_linearized_receipts_for_rooms( + rooms, + from_key=from_key, + to_key=to_key, + ) + + defer.returnValue((events, to_key)) + + def get_current_key(self, direction='f'): + return self.store.get_max_receipt_stream_id() + + @defer.inlineCallbacks + def get_pagination_rows(self, user, config, key): + to_key = int(config.from_key) + + if config.to_key: + from_key = int(config.to_key) + else: + from_key = None + + rooms = yield self.store.get_rooms_for_user(user.to_string()) + rooms = [room.room_id for room in rooms] + events = yield self.store.get_linearized_receipts_for_rooms( + rooms, + from_key=from_key, + to_key=to_key, + ) + + defer.returnValue((events, to_key)) diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index a9895292c2..026bd2b9d4 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -218,7 +218,7 @@ class TypingNotificationHandler(BaseHandler): self._room_serials[room_id] = self._latest_room_serial with PreserveLoggingContext(): - self.notifier.on_new_user_event( + self.notifier.on_new_event( "typing_key", self._latest_room_serial, rooms=[room_id] ) diff --git a/synapse/notifier.py b/synapse/notifier.py index bdd03dcbe8..85ae343135 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -221,16 +221,7 @@ class Notifier(object): event ) - room_id = event.room_id - - room_user_streams = self.room_to_user_streams.get(room_id, set()) - - user_streams = room_user_streams.copy() - - for user in extra_users: - user_stream = self.user_to_user_stream.get(str(user)) - if user_stream is not None: - user_streams.add(user_stream) + app_streams = set() for appservice in self.appservice_to_user_streams: # TODO (kegan): Redundant appservice listener checks? @@ -242,24 +233,20 @@ class Notifier(object): app_user_streams = self.appservice_to_user_streams.get( appservice, set() ) - user_streams |= app_user_streams + app_streams |= app_user_streams - logger.debug("on_new_room_event listeners %s", user_streams) - - time_now_ms = self.clock.time_msec() - for user_stream in user_streams: - try: - user_stream.notify( - "room_key", "s%d" % (room_stream_id,), time_now_ms - ) - except: - logger.exception("Failed to notify listener") + self.on_new_event( + "room_key", room_stream_id, + users=extra_users, + rooms=[event.room_id], + extra_streams=app_streams, + ) @defer.inlineCallbacks @log_function - def on_new_user_event(self, stream_key, new_token, users=[], rooms=[]): - """ Used to inform listeners that something has happend - presence/user event wise. + def on_new_event(self, stream_key, new_token, users=[], rooms=[], + extra_streams=set()): + """ Used to inform listeners that something has happend event wise. Will wake up all listeners for the given users and rooms. """ @@ -283,7 +270,7 @@ class Notifier(object): @defer.inlineCallbacks def wait_for_events(self, user, rooms, timeout, callback, - from_token=StreamToken("s0", "0", "0")): + from_token=StreamToken("s0", "0", "0", "0")): """Wait until the callback returns a non empty response or the timeout fires. """ diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 17587170c8..115bee8c41 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -31,6 +31,7 @@ REQUIREMENTS = { "pillow": ["PIL"], "pydenticon": ["pydenticon"], "ujson": ["ujson"], + "blist": ["blist"], "pysaml2": ["saml2"], } CONDITIONAL_REQUIREMENTS = { diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index c3323d2a8a..33f961e898 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -19,6 +19,7 @@ from . import ( account, register, auth, + receipts, keys, ) @@ -39,4 +40,5 @@ class ClientV2AlphaRestResource(JsonResource): account.register_servlets(hs, client_resource) register.register_servlets(hs, client_resource) auth.register_servlets(hs, client_resource) + receipts.register_servlets(hs, client_resource) keys.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/receipts.py b/synapse/rest/client/v2_alpha/receipts.py new file mode 100644 index 0000000000..40406e2ede --- /dev/null +++ b/synapse/rest/client/v2_alpha/receipts.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet +from ._base import client_v2_pattern + +import logging + + +logger = logging.getLogger(__name__) + + +class ReceiptRestServlet(RestServlet): + PATTERN = client_v2_pattern( + "/rooms/(?P[^/]*)" + "/receipt/(?P[^/]*)" + "/(?P[^/]*)$" + ) + + def __init__(self, hs): + super(ReceiptRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.receipts_handler = hs.get_handlers().receipts_handler + + @defer.inlineCallbacks + def on_POST(self, request, room_id, receipt_type, event_id): + user, client = yield self.auth.get_user_by_req(request) + + yield self.receipts_handler.received_client_receipt( + room_id, + receipt_type, + user_id=user.to_string(), + event_id=event_id + ) + + defer.returnValue((200, {})) + + +def register_servlets(hs, http_server): + ReceiptRestServlet(hs).register(http_server) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e089d81675..71d5d92500 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -39,6 +39,8 @@ from .signatures import SignatureStore from .filtering import FilteringStore from .end_to_end_keys import EndToEndKeyStore +from .receipts import ReceiptsStore + import fnmatch import imp @@ -75,6 +77,7 @@ class DataStore(RoomMemberStore, RoomStore, PushRuleStore, ApplicationServiceTransactionStore, EventsStore, + ReceiptsStore, EndToEndKeyStore, ): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8d33def6c6..8f812f0fd7 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -329,13 +329,14 @@ class SQLBaseStore(object): self.database_engine = hs.database_engine - self._stream_id_gen = StreamIdGenerator() + self._stream_id_gen = StreamIdGenerator("events", "stream_ordering") self._transaction_id_gen = IdGenerator("sent_transactions", "id", self) self._state_groups_id_gen = IdGenerator("state_groups", "id", self) self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self) self._pushers_id_gen = IdGenerator("pushers", "id", self) self._push_rule_id_gen = IdGenerator("push_rules", "id", self) self._push_rules_enable_id_gen = IdGenerator("push_rules_enable", "id", self) + self._receipts_id_gen = StreamIdGenerator("receipts_linearized", "stream_id") def start_profiling(self): self._previous_loop_ts = self._clock.time_msec() diff --git a/synapse/storage/receipts.py b/synapse/storage/receipts.py new file mode 100644 index 0000000000..d515a0a15c --- /dev/null +++ b/synapse/storage/receipts.py @@ -0,0 +1,348 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore, cached + +from twisted.internet import defer + +from synapse.util import unwrapFirstError + +from blist import sorteddict +import logging +import ujson as json + + +logger = logging.getLogger(__name__) + + +class ReceiptsStore(SQLBaseStore): + def __init__(self, hs): + super(ReceiptsStore, self).__init__(hs) + + self._receipts_stream_cache = _RoomStreamChangeCache() + + @defer.inlineCallbacks + def get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + """Get receipts for multiple rooms for sending to clients. + + Args: + room_ids (list): List of room_ids. + to_key (int): Max stream id to fetch receipts upto. + from_key (int): Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + list: A list of receipts. + """ + room_ids = set(room_ids) + + if from_key: + room_ids = yield self._receipts_stream_cache.get_rooms_changed( + self, room_ids, from_key + ) + + results = yield defer.gatherResults( + [ + self.get_linearized_receipts_for_room( + room_id, to_key, from_key=from_key + ) + for room_id in room_ids + ], + consumeErrors=True, + ).addErrback(unwrapFirstError) + + defer.returnValue([ev for res in results for ev in res]) + + @defer.inlineCallbacks + def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): + """Get receipts for a single room for sending to clients. + + Args: + room_ids (str): The room id. + to_key (int): Max stream id to fetch receipts upto. + from_key (int): Min stream id to fetch receipts from. None fetches + from the start. + + Returns: + list: A list of receipts. + """ + def f(txn): + if from_key: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id = ? AND stream_id > ? AND stream_id <= ?" + ) + + txn.execute( + sql, + (room_id, from_key, to_key) + ) + else: + sql = ( + "SELECT * FROM receipts_linearized WHERE" + " room_id = ? AND stream_id <= ?" + ) + + txn.execute( + sql, + (room_id, to_key) + ) + + rows = self.cursor_to_dict(txn) + + return rows + + rows = yield self.runInteraction( + "get_linearized_receipts_for_room", f + ) + + if not rows: + defer.returnValue([]) + + content = {} + for row in rows: + content.setdefault( + row["event_id"], {} + ).setdefault( + row["receipt_type"], {} + )[row["user_id"]] = json.loads(row["data"]) + + defer.returnValue([{ + "type": "m.receipt", + "room_id": room_id, + "content": content, + }]) + + def get_max_receipt_stream_id(self): + return self._receipts_id_gen.get_max_token(self) + + @cached + @defer.inlineCallbacks + def get_graph_receipts_for_room(self, room_id): + """Get receipts for sending to remote servers. + """ + rows = yield self._simple_select_list( + table="receipts_graph", + keyvalues={"room_id": room_id}, + retcols=["receipt_type", "user_id", "event_id"], + desc="get_linearized_receipts_for_room", + ) + + result = {} + for row in rows: + result.setdefault( + row["user_id"], {} + ).setdefault( + row["receipt_type"], [] + ).append(row["event_id"]) + + defer.returnValue(result) + + def insert_linearized_receipt_txn(self, txn, room_id, receipt_type, + user_id, event_id, data, stream_id): + + # We don't want to clobber receipts for more recent events, so we + # have to compare orderings of existing receipts + sql = ( + "SELECT topological_ordering, stream_ordering, event_id FROM events" + " INNER JOIN receipts_linearized as r USING (event_id, room_id)" + " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" + ) + + txn.execute(sql, (room_id, receipt_type, user_id)) + results = txn.fetchall() + + if results: + res = self._simple_select_one_txn( + txn, + table="events", + retcols=["topological_ordering", "stream_ordering"], + keyvalues={"event_id": event_id}, + ) + topological_ordering = int(res["topological_ordering"]) + stream_ordering = int(res["stream_ordering"]) + + for to, so, _ in results: + if int(to) > topological_ordering: + return False + elif int(to) == topological_ordering and int(so) >= stream_ordering: + return False + + self._simple_delete_txn( + txn, + table="receipts_linearized", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + ) + + self._simple_insert_txn( + txn, + table="receipts_linearized", + values={ + "stream_id": stream_id, + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_id": event_id, + "data": json.dumps(data), + } + ) + + return True + + @defer.inlineCallbacks + def insert_receipt(self, room_id, receipt_type, user_id, event_ids, data): + """Insert a receipt, either from local client or remote server. + + Automatically does conversion between linearized and graph + representations. + """ + if not event_ids: + return + + if len(event_ids) == 1: + linearized_event_id = event_ids[0] + else: + # we need to points in graph -> linearized form. + # TODO: Make this better. + def graph_to_linear(txn): + query = ( + "SELECT event_id WHERE room_id = ? AND stream_ordering IN (" + " SELECT max(stream_ordering) WHERE event_id IN (%s)" + ")" + ) % (",".join(["?"] * len(event_ids))) + + txn.execute(query, [room_id] + event_ids) + rows = txn.fetchall() + if rows: + return rows[0][0] + else: + raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,)) + + linearized_event_id = yield self.runInteraction( + "insert_receipt_conv", graph_to_linear + ) + + stream_id_manager = yield self._receipts_id_gen.get_next(self) + with stream_id_manager as stream_id: + yield self._receipts_stream_cache.room_has_changed( + self, room_id, stream_id + ) + have_persisted = yield self.runInteraction( + "insert_linearized_receipt", + self.insert_linearized_receipt_txn, + room_id, receipt_type, user_id, linearized_event_id, + data, + stream_id=stream_id, + ) + + if not have_persisted: + defer.returnValue(None) + + yield self.insert_graph_receipt( + room_id, receipt_type, user_id, event_ids, data + ) + + max_persisted_id = yield self._stream_id_gen.get_max_token(self) + defer.returnValue((stream_id, max_persisted_id)) + + def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, + data): + return self.runInteraction( + "insert_graph_receipt", + self.insert_graph_receipt_txn, + room_id, receipt_type, user_id, event_ids, data + ) + + def insert_graph_receipt_txn(self, txn, room_id, receipt_type, + user_id, event_ids, data): + self._simple_delete_txn( + txn, + table="receipts_graph", + keyvalues={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + } + ) + self._simple_insert_txn( + txn, + table="receipts_graph", + values={ + "room_id": room_id, + "receipt_type": receipt_type, + "user_id": user_id, + "event_ids": json.dumps(event_ids), + "data": json.dumps(data), + } + ) + + +class _RoomStreamChangeCache(object): + """Keeps track of the stream_id of the latest change in rooms. + + Given a list of rooms and stream key, it will give a subset of rooms that + may have changed since that key. If the key is too old then the cache + will simply return all rooms. + """ + def __init__(self, size_of_cache=1000): + self._size_of_cache = size_of_cache + self._room_to_key = {} + self._cache = sorteddict() + self._earliest_key = None + + @defer.inlineCallbacks + def get_rooms_changed(self, store, room_ids, key): + """Returns subset of room ids that have had new receipts since the + given key. If the key is too old it will just return the given list. + """ + if key > (yield self._get_earliest_key(store)): + keys = self._cache.keys() + i = keys.bisect_right(key) + + result = set( + self._cache[k] for k in keys[i:] + ).intersection(room_ids) + else: + result = room_ids + + defer.returnValue(result) + + @defer.inlineCallbacks + def room_has_changed(self, store, room_id, key): + """Informs the cache that the room has been changed at the given key. + """ + if key > (yield self._get_earliest_key(store)): + old_key = self._room_to_key.get(room_id, None) + if old_key: + key = max(key, old_key) + self._cache.pop(old_key, None) + self._cache[key] = room_id + + while len(self._cache) > self._size_of_cache: + k, r = self._cache.popitem() + self._earliest_key = max(k, self._earliest_key) + self._room_to_key.pop(r, None) + + @defer.inlineCallbacks + def _get_earliest_key(self, store): + if self._earliest_key is None: + self._earliest_key = yield store.get_max_receipt_stream_id() + self._earliest_key = int(self._earliest_key) + + defer.returnValue(self._earliest_key) diff --git a/synapse/storage/schema/delta/21/receipts.sql b/synapse/storage/schema/delta/21/receipts.sql new file mode 100644 index 0000000000..2f64d609fc --- /dev/null +++ b/synapse/storage/schema/delta/21/receipts.sql @@ -0,0 +1,38 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +CREATE TABLE IF NOT EXISTS receipts_graph( + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_ids TEXT NOT NULL, + data TEXT NOT NULL, + CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id) +); + +CREATE TABLE IF NOT EXISTS receipts_linearized ( + stream_id BIGINT NOT NULL, + room_id TEXT NOT NULL, + receipt_type TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL, + data TEXT NOT NULL, + CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id) +); + +CREATE INDEX receipts_linearized_id ON receipts_linearized( + stream_id +); diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index 83eab63098..e956df62c7 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -72,7 +72,10 @@ class StreamIdGenerator(object): with stream_id_gen.get_next_txn(txn) as stream_id: # ... persist event ... """ - def __init__(self): + def __init__(self, table, column): + self.table = table + self.column = column + self._lock = threading.Lock() self._current_max = None @@ -157,7 +160,7 @@ class StreamIdGenerator(object): def _get_or_compute_current_max(self, txn): with self._lock: - txn.execute("SELECT MAX(stream_ordering) FROM events") + txn.execute("SELECT MAX(%s) FROM %s" % (self.column, self.table)) rows = txn.fetchall() val, = rows[0] diff --git a/synapse/streams/events.py b/synapse/streams/events.py index dff7970bea..aaa3609aa5 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -20,6 +20,7 @@ from synapse.types import StreamToken from synapse.handlers.presence import PresenceEventSource from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource +from synapse.handlers.receipts import ReceiptEventSource class NullSource(object): @@ -43,6 +44,7 @@ class EventSources(object): "room": RoomEventSource, "presence": PresenceEventSource, "typing": TypingNotificationEventSource, + "receipt": ReceiptEventSource, } def __init__(self, hs): @@ -62,7 +64,10 @@ class EventSources(object): ), typing_key=( yield self.sources["typing"].get_current_key() - ) + ), + receipt_key=( + yield self.sources["receipt"].get_current_key() + ), ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 1b21160c57..dd1b10d646 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -100,7 +100,7 @@ class EventID(DomainSpecificString): class StreamToken( namedtuple( "Token", - ("room_key", "presence_key", "typing_key") + ("room_key", "presence_key", "typing_key", "receipt_key") ) ): _SEPARATOR = "_" @@ -109,6 +109,9 @@ class StreamToken( def from_string(cls, string): try: keys = string.split(cls._SEPARATOR) + if len(keys) == len(cls._fields) - 1: + # i.e. old token from before receipt_key + keys.append("0") return cls(*keys) except: raise SynapseError(400, "Invalid Token") @@ -131,6 +134,7 @@ class StreamToken( (other_token.room_stream_id < self.room_stream_id) or (int(other_token.presence_key) < int(self.presence_key)) or (int(other_token.typing_key) < int(self.typing_key)) + or (int(other_token.receipt_key) < int(self.receipt_key)) ) def copy_and_advance(self, key, new_value): diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 7ccbe2ea9c..41bb08b7ca 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -66,8 +66,8 @@ class TypingNotificationsTestCase(unittest.TestCase): self.mock_federation_resource = MockHttpResource() - mock_notifier = Mock(spec=["on_new_user_event"]) - self.on_new_user_event = mock_notifier.on_new_user_event + mock_notifier = Mock(spec=["on_new_event"]) + self.on_new_event = mock_notifier.on_new_event self.auth = Mock(spec=[]) @@ -182,7 +182,7 @@ class TypingNotificationsTestCase(unittest.TestCase): timeout=20000, ) - self.on_new_user_event.assert_has_calls([ + self.on_new_event.assert_has_calls([ call('typing_key', 1, rooms=[self.room_id]), ]) @@ -245,7 +245,7 @@ class TypingNotificationsTestCase(unittest.TestCase): ) ) - self.on_new_user_event.assert_has_calls([ + self.on_new_event.assert_has_calls([ call('typing_key', 1, rooms=[self.room_id]), ]) @@ -299,7 +299,7 @@ class TypingNotificationsTestCase(unittest.TestCase): room_id=self.room_id, ) - self.on_new_user_event.assert_has_calls([ + self.on_new_event.assert_has_calls([ call('typing_key', 1, rooms=[self.room_id]), ]) @@ -331,10 +331,10 @@ class TypingNotificationsTestCase(unittest.TestCase): timeout=10000, ) - self.on_new_user_event.assert_has_calls([ + self.on_new_event.assert_has_calls([ call('typing_key', 1, rooms=[self.room_id]), ]) - self.on_new_user_event.reset_mock() + self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 1) events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) @@ -351,7 +351,7 @@ class TypingNotificationsTestCase(unittest.TestCase): self.clock.advance_time(11) - self.on_new_user_event.assert_has_calls([ + self.on_new_event.assert_has_calls([ call('typing_key', 2, rooms=[self.room_id]), ]) @@ -377,10 +377,10 @@ class TypingNotificationsTestCase(unittest.TestCase): timeout=10000, ) - self.on_new_user_event.assert_has_calls([ + self.on_new_event.assert_has_calls([ call('typing_key', 3, rooms=[self.room_id]), ]) - self.on_new_user_event.reset_mock() + self.on_new_event.reset_mock() self.assertEquals(self.event_source.get_current_key(), 3) events = yield self.event_source.get_new_events_for_user(self.u_apple, 0, None) diff --git a/tests/rest/client/v1/test_events.py b/tests/rest/client/v1/test_events.py index 445272e323..ac3b0b58ac 100644 --- a/tests/rest/client/v1/test_events.py +++ b/tests/rest/client/v1/test_events.py @@ -183,7 +183,17 @@ class EventStreamPermissionsTestCase(RestTestCase): ) self.assertEquals(200, code, msg=str(response)) - self.assertEquals(0, len(response["chunk"])) + # We may get a presence event for ourselves down + self.assertEquals( + 0, + len([ + c for c in response["chunk"] + if not ( + c.get("type") == "m.presence" + and c["content"].get("user_id") == self.user_id + ) + ]) + ) # joined room (expect all content for room) yield self.join(room=room_id, user=self.user_id, tok=self.token) diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 4b32c7a203..089a71568c 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -357,7 +357,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): # all be ours # I'll already get my own presence state change - self.assertEquals({"start": "0_1_0", "end": "0_1_0", "chunk": []}, + self.assertEquals({"start": "0_1_0_0", "end": "0_1_0_0", "chunk": []}, response ) @@ -376,7 +376,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): "/events?from=s0_1_0&timeout=0", None) self.assertEquals(200, code) - self.assertEquals({"start": "s0_1_0", "end": "s0_2_0", "chunk": [ + self.assertEquals({"start": "s0_1_0_0", "end": "s0_2_0_0", "chunk": [ {"type": "m.presence", "content": { "user_id": "@banana:test",