Correctly handle tags changing in paginated sync

This commit is contained in:
Erik Johnston 2016-06-23 13:43:25 +01:00
parent a90140358b
commit 8c3fca8b28
5 changed files with 127 additions and 4 deletions

View file

@ -22,6 +22,7 @@ from synapse.push.clientformat import format_push_rules_for_user
from synapse.visibility import filter_events_for_client
from synapse.types import SyncNextBatchToken, SyncPaginationState
from synapse.api.errors import Codes, SynapseError
from synapse.storage.tags import (TAG_CHANGE_NEWLY_TAGGED, TAG_CHANGE_ALL_REMOVED)
from twisted.internet import defer
@ -774,12 +775,33 @@ class SyncHandler(object):
all_tags = yield self.store.get_tags_for_user(user_id)
if sync_result_builder.since_token:
stream_id = sync_result_builder.since_token.account_data_key
tag_changes = yield self.store.get_room_tags_changed(user_id, stream_id)
else:
tag_changes = {}
if missing_state:
for r in room_entries:
if r.room_id in missing_state:
if include_all_tags and r.room_id in all_tags:
r.always_include = True
continue
if include_all_tags:
change = tag_changes.get(r.room_id)
if change == TAG_CHANGE_NEWLY_TAGGED:
r.since_token = None
r.always_include = True
r.full_state = True
r.would_require_resync = True
r.events = None
r.synced = True
continue
elif change == TAG_CHANGE_ALL_REMOVED:
r.always_include = True
r.synced = False
continue
elif r.room_id in all_tags:
r.always_include = True
continue
if r.room_id in include_map:
since = include_map[r.room_id].get("since", None)
if since:

View file

@ -51,6 +51,9 @@ class SlavedAccountDataStore(BaseSlavedStore):
get_updated_account_data_for_user = (
DataStore.get_updated_account_data_for_user.__func__
)
get_room_tags_changed = (
DataStore.get_room_tags_changed.__func__
)
def get_max_account_data_stream_id(self):
return self._account_data_id_gen.get_current_token()

View file

@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 32
SCHEMA_VERSION = 33
dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -0,0 +1,24 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE room_tags_change_revisions(
user_id TEXT NOT NULL,
room_id TEXT NOT NULL,
stream_id BIGINT NOT NULL,
change TEXT NOT NULL
);
CREATE INDEX room_tags_change_revisions_rm_idx ON room_tags_change_revisions(user_id, room_id, stream_id);
CREATE INDEX room_tags_change_revisions_idx ON room_tags_change_revisions(user_id, stream_id);

View file

@ -17,12 +17,18 @@ from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached
from twisted.internet import defer
from collections import Counter
import ujson as json
import logging
logger = logging.getLogger(__name__)
TAG_CHANGE_NEWLY_TAGGED = "newly_tagged"
TAG_CHANGE_ALL_REMOVED = "all_removed"
class TagsStore(SQLBaseStore):
def get_max_account_data_stream_id(self):
"""Get the current max stream id for the private user data stream
@ -170,6 +176,39 @@ class TagsStore(SQLBaseStore):
row["tag"]: json.loads(row["content"]) for row in rows
})
def get_room_tags_changed(self, user_id, stream_id):
changed = self._account_data_stream_cache.has_entity_changed(
user_id, int(stream_id)
)
if not changed:
return {}
def _get_room_tags_changed(txn):
txn.execute(
"SELECT room_id, change FROM room_tags_change_revisions"
" WHERE user_id = ? AND stream_id > ?",
(user_id, stream_id)
)
results = Counter()
for room_id, change in txn.fetchall():
if change == TAG_CHANGE_NEWLY_TAGGED:
results[room_id] += 1
elif change == TAG_CHANGE_ALL_REMOVED:
results[room_id] -= 1
else:
logger.warn("Unexpected tag change: %r", change)
return {
room_id: TAG_CHANGE_NEWLY_TAGGED if count > 0 else TAG_CHANGE_ALL_REMOVED
for room_id, count in results.items()
if count
}
return self.runInteraction("get_room_tags_changed", _get_room_tags_changed)
@defer.inlineCallbacks
def add_tag_to_room(self, user_id, room_id, tag, content):
"""Add a tag to a room for a user.
@ -184,6 +223,12 @@ class TagsStore(SQLBaseStore):
content_json = json.dumps(content)
def add_tag_txn(txn, next_id):
txn.execute(
"SELECT count(*) FROM room_tags WHERE user_id = ? AND room_id = ?",
(user_id, room_id),
)
existing_tags, = txn.fetchone()
self._simple_upsert_txn(
txn,
table="room_tags",
@ -197,6 +242,17 @@ class TagsStore(SQLBaseStore):
}
)
self._update_revision_txn(txn, user_id, room_id, next_id)
if not existing_tags:
self._simple_insert_txn(
txn,
table="room_tags_change_revisions",
values={
"user_id": user_id,
"room_id": room_id,
"stream_id": next_id,
"change": TAG_CHANGE_NEWLY_TAGGED,
}
)
with self._account_data_id_gen.get_next() as next_id:
yield self.runInteraction("add_tag", add_tag_txn, next_id)
@ -218,6 +274,24 @@ class TagsStore(SQLBaseStore):
" WHERE user_id = ? AND room_id = ? AND tag = ?"
)
txn.execute(sql, (user_id, room_id, tag))
if txn.rowcount > 0:
txn.execute(
"SELECT count(*) FROM room_tags WHERE user_id = ? AND room_id = ?",
(user_id, room_id),
)
existing_tags, = txn.fetchone()
if not existing_tags:
self._simple_insert_txn(
txn,
table="room_tags_change_revisions",
values={
"user_id": user_id,
"room_id": room_id,
"stream_id": next_id,
"change": TAG_CHANGE_ALL_REMOVED,
}
)
self._update_revision_txn(txn, user_id, room_id, next_id)
with self._account_data_id_gen.get_next() as next_id: