diff --git a/synapse/handlers/private_user_data.py b/synapse/handlers/private_user_data.py new file mode 100644 index 0000000000..1778c71325 --- /dev/null +++ b/synapse/handlers/private_user_data.py @@ -0,0 +1,46 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + + +class PrivateUserDataEventSource(object): + def __init__(self, hs): + self.store = hs.get_datastore() + + def get_current_key(self, direction='f'): + return self.store.get_max_private_user_data_stream_id() + + @defer.inlineCallbacks + def get_new_events_for_user(self, user, from_key, limit): + user_id = user.to_string() + last_stream_id = from_key + + current_stream_id = yield self.store.get_max_private_user_data_stream_id() + tags = yield self.store.get_updated_tags(user_id, last_stream_id) + + results = [] + for room_id, room_tags in tags.items(): + results.append({ + "type": "m.tag", + "content": {"tags": room_tags}, + "room_id": room_id, + }) + + defer.returnValue((results, current_stream_id)) + + @defer.inlineCallbacks + def get_pagination_rows(self, user, config, key): + defer.returnValue(([], config.to_id)) diff --git a/synapse/notifier.py b/synapse/notifier.py index f998fc83bf..a78ee3c1e7 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -270,7 +270,7 @@ class Notifier(object): @defer.inlineCallbacks def wait_for_events(self, user, rooms, timeout, callback, - from_token=StreamToken("s0", "0", "0", "0")): + from_token=StreamToken("s0", "0", "0", "0", "0")): """Wait until the callback returns a non empty response or the timeout fires. """ diff --git a/synapse/rest/client/v2_alpha/tags.py b/synapse/rest/client/v2_alpha/tags.py index 15c9347fc1..486add9909 100644 --- a/synapse/rest/client/v2_alpha/tags.py +++ b/synapse/rest/client/v2_alpha/tags.py @@ -62,6 +62,7 @@ class TagServlet(RestServlet): super(TagServlet, self).__init__() self.auth = hs.get_auth() self.store = hs.get_datastore() + self.notifier = hs.get_notifier() @defer.inlineCallbacks def on_PUT(self, request, user_id, room_id, tag): @@ -69,9 +70,12 @@ class TagServlet(RestServlet): if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - yield self.store.add_tag_to_room(user_id, room_id, tag) + max_id = yield self.store.add_tag_to_room(user_id, room_id, tag) + + yield self.notifier.on_new_event( + "private_user_data_key", max_id, users=[user_id] + ) - # TODO: poke the notifier. defer.returnValue((200, {})) @defer.inlineCallbacks @@ -80,7 +84,11 @@ class TagServlet(RestServlet): if user_id != auth_user.to_string(): raise AuthError(403, "Cannot add tags for other users.") - yield self.store.remove_tag_from_room(user_id, room_id, tag) + max_id = yield self.store.remove_tag_from_room(user_id, room_id, tag) + + yield self.notifier.on_new_event( + "private_user_data_key", max_id, users=[user_id] + ) # TODO: poke the notifier. defer.returnValue((200, {})) diff --git a/synapse/storage/tags.py b/synapse/storage/tags.py index 64c65fc321..2d5c49144a 100644 --- a/synapse/storage/tags.py +++ b/synapse/storage/tags.py @@ -31,6 +31,14 @@ class TagsStore(SQLBaseStore): "private_user_data_max_stream_id", "stream_id" ) + def get_max_private_user_data_stream_id(self): + """Get the current max stream id for the private user data stream + + Returns: + A deferred int. + """ + return self._private_user_data_id_gen.get_max_token(self) + @cached() def get_tags_for_user(self, user_id): """Get all the tags for a user. @@ -83,7 +91,7 @@ class TagsStore(SQLBaseStore): results = {} if room_ids: - tags_by_room = yield self.get_tags_for_user(self, user_id) + tags_by_room = yield self.get_tags_for_user(user_id) for room_id in room_ids: results[room_id] = tags_by_room[room_id] @@ -129,6 +137,9 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) + result = yield self._private_user_data_id_gen.get_max_token(self) + defer.returnValue(result) + @defer.inlineCallbacks def remove_tag_from_room(self, user_id, room_id, tag): """Remove a tag from a room for a user. @@ -148,6 +159,9 @@ class TagsStore(SQLBaseStore): self.get_tags_for_user.invalidate((user_id,)) + result = yield self._private_user_data_id_gen.get_max_token(self) + defer.returnValue(result) + def _update_revision_txn(self, txn, user_id, room_id, next_id): """Update the latest revision of the tags for the given user and room. diff --git a/synapse/streams/events.py b/synapse/streams/events.py index 699083ae12..f0d68b5bf2 100644 --- a/synapse/streams/events.py +++ b/synapse/streams/events.py @@ -21,6 +21,7 @@ from synapse.handlers.presence import PresenceEventSource from synapse.handlers.room import RoomEventSource from synapse.handlers.typing import TypingNotificationEventSource from synapse.handlers.receipts import ReceiptEventSource +from synapse.handlers.private_user_data import PrivateUserDataEventSource class EventSources(object): @@ -29,6 +30,7 @@ class EventSources(object): "presence": PresenceEventSource, "typing": TypingNotificationEventSource, "receipt": ReceiptEventSource, + "private_user_data": PrivateUserDataEventSource, } def __init__(self, hs): @@ -52,5 +54,8 @@ class EventSources(object): receipt_key=( yield self.sources["receipt"].get_current_key() ), + private_user_data_key=( + yield self.sources["private_user_data"].get_current_key() + ), ) defer.returnValue(token) diff --git a/synapse/types.py b/synapse/types.py index 9cffc33d27..8d3a8d88cc 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -98,10 +98,13 @@ class EventID(DomainSpecificString): class StreamToken( - namedtuple( - "Token", - ("room_key", "presence_key", "typing_key", "receipt_key") - ) + namedtuple("Token", ( + "room_key", + "presence_key", + "typing_key", + "receipt_key", + "private_user_data_key", + )) ): _SEPARATOR = "_" @@ -128,13 +131,14 @@ class StreamToken( else: return int(self.room_key[1:].split("-")[-1]) - def is_after(self, other_token): + def is_after(self, other): """Does this token contain events that the other doesn't?""" return ( - (other_token.room_stream_id < self.room_stream_id) - or (int(other_token.presence_key) < int(self.presence_key)) - or (int(other_token.typing_key) < int(self.typing_key)) - or (int(other_token.receipt_key) < int(self.receipt_key)) + (other.room_stream_id < self.room_stream_id) + or (int(other.presence_key) < int(self.presence_key)) + or (int(other.typing_key) < int(self.typing_key)) + or (int(other.receipt_key) < int(self.receipt_key)) + or (int(other.private_user_data_key) < int(self.private_user_data_key)) ) def copy_and_advance(self, key, new_value):