Correctly handle tags changing in paginated sync

This commit is contained in:
Erik Johnston 2016-06-23 13:43:25 +01:00
parent a90140358b
commit 8c3fca8b28
5 changed files with 127 additions and 4 deletions

View file

@ -22,6 +22,7 @@ from synapse.push.clientformat import format_push_rules_for_user
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
from synapse.types import SyncNextBatchToken, SyncPaginationState from synapse.types import SyncNextBatchToken, SyncPaginationState
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.storage.tags import (TAG_CHANGE_NEWLY_TAGGED, TAG_CHANGE_ALL_REMOVED)
from twisted.internet import defer from twisted.internet import defer
@ -774,12 +775,33 @@ class SyncHandler(object):
all_tags = yield self.store.get_tags_for_user(user_id) all_tags = yield self.store.get_tags_for_user(user_id)
if sync_result_builder.since_token:
stream_id = sync_result_builder.since_token.account_data_key
tag_changes = yield self.store.get_room_tags_changed(user_id, stream_id)
else:
tag_changes = {}
if missing_state: if missing_state:
for r in room_entries: for r in room_entries:
if r.room_id in missing_state: if r.room_id in missing_state:
if include_all_tags and r.room_id in all_tags: if include_all_tags:
r.always_include = True change = tag_changes.get(r.room_id)
continue if change == TAG_CHANGE_NEWLY_TAGGED:
r.since_token = None
r.always_include = True
r.full_state = True
r.would_require_resync = True
r.events = None
r.synced = True
continue
elif change == TAG_CHANGE_ALL_REMOVED:
r.always_include = True
r.synced = False
continue
elif r.room_id in all_tags:
r.always_include = True
continue
if r.room_id in include_map: if r.room_id in include_map:
since = include_map[r.room_id].get("since", None) since = include_map[r.room_id].get("since", None)
if since: if since:

View file

@ -51,6 +51,9 @@ class SlavedAccountDataStore(BaseSlavedStore):
get_updated_account_data_for_user = ( get_updated_account_data_for_user = (
DataStore.get_updated_account_data_for_user.__func__ DataStore.get_updated_account_data_for_user.__func__
) )
get_room_tags_changed = (
DataStore.get_room_tags_changed.__func__
)
def get_max_account_data_stream_id(self): def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token() return self._account_data_id_gen.get_current_token()

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 32 SCHEMA_VERSION = 33
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -0,0 +1,24 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE room_tags_change_revisions(
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
stream_id BIGINT NOT NULL,
change TEXT NOT NULL
);
CREATE INDEX room_tags_change_revisions_rm_idx ON room_tags_change_revisions(user_id, room_id, stream_id);
CREATE INDEX room_tags_change_revisions_idx ON room_tags_change_revisions(user_id, stream_id);

View file

@ -17,12 +17,18 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from twisted.internet import defer from twisted.internet import defer
from collections import Counter
import ujson as json import ujson as json
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
TAG_CHANGE_NEWLY_TAGGED = "newly_tagged"
TAG_CHANGE_ALL_REMOVED = "all_removed"
class TagsStore(SQLBaseStore): class TagsStore(SQLBaseStore):
def get_max_account_data_stream_id(self): def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream """Get the current max stream id for the private user data stream
@ -170,6 +176,39 @@ class TagsStore(SQLBaseStore):
row["tag"]: json.loads(row["content"]) for row in rows row["tag"]: json.loads(row["content"]) for row in rows
}) })
def get_room_tags_changed(self, user_id, stream_id):
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
return {}
def _get_room_tags_changed(txn):
txn.execute(
"SELECT room_id, change FROM room_tags_change_revisions"
" WHERE user_id = ? AND stream_id > ?",
(user_id, stream_id)
)
results = Counter()
for room_id, change in txn.fetchall():
if change == TAG_CHANGE_NEWLY_TAGGED:
results[room_id] += 1
elif change == TAG_CHANGE_ALL_REMOVED:
results[room_id] -= 1
else:
logger.warn("Unexpected tag change: %r", change)
return {
room_id: TAG_CHANGE_NEWLY_TAGGED if count > 0 else TAG_CHANGE_ALL_REMOVED
for room_id, count in results.items()
if count
}
return self.runInteraction("get_room_tags_changed", _get_room_tags_changed)
@defer.inlineCallbacks @defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content): def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user. """Add a tag to a room for a user.
@ -184,6 +223,12 @@ class TagsStore(SQLBaseStore):
content_json = json.dumps(content) content_json = json.dumps(content)
def add_tag_txn(txn, next_id): def add_tag_txn(txn, next_id):
txn.execute(
"SELECT count(*) FROM room_tags WHERE user_id = ? AND room_id = ?",
(user_id, room_id),
)
existing_tags, = txn.fetchone()
self._simple_upsert_txn( self._simple_upsert_txn(
txn, txn,
table="room_tags", table="room_tags",
@ -197,6 +242,17 @@ class TagsStore(SQLBaseStore):
} }
) )
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
if not existing_tags:
self._simple_insert_txn(
txn,
table="room_tags_change_revisions",
values={
"user_id": user_id,
"room_id": room_id,
"stream_id": next_id,
"change": TAG_CHANGE_NEWLY_TAGGED,
}
)
with self._account_data_id_gen.get_next() as next_id: with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id) yield self.runInteraction("add_tag", add_tag_txn, next_id)
@ -218,6 +274,24 @@ class TagsStore(SQLBaseStore):
" WHERE user_id = ? AND room_id = ? AND tag = ?" " WHERE user_id = ? AND room_id = ? AND tag = ?"
) )
txn.execute(sql, (user_id, room_id, tag)) txn.execute(sql, (user_id, room_id, tag))
if txn.rowcount > 0:
txn.execute(
"SELECT count(*) FROM room_tags WHERE user_id = ? AND room_id = ?",
(user_id, room_id),
)
existing_tags, = txn.fetchone()
if not existing_tags:
self._simple_insert_txn(
txn,
table="room_tags_change_revisions",
values={
"user_id": user_id,
"room_id": room_id,
"stream_id": next_id,
"change": TAG_CHANGE_ALL_REMOVED,
}
)
self._update_revision_txn(txn, user_id, room_id, next_id) self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id: with self._account_data_id_gen.get_next() as next_id: