Merge remote-tracking branch 'origin/develop' into rav/remove_who_forgot_in_room

This commit is contained in:
Richard van der Hoff 2018-07-23 17:15:12 +01:00
commit 4f5cc8e4e7
30 changed files with 480 additions and 199 deletions

1
changelog.d/3520.bugfix Normal file
View file

@ -0,0 +1 @@
Correctly announce deleted devices over federation

0
changelog.d/3562.misc Normal file
View file

1
changelog.d/3572.misc Normal file
View file

@ -0,0 +1 @@
Merge Linearizer and Limiter

0
changelog.d/3577.misc Normal file
View file

1
changelog.d/3579.misc Normal file
View file

@ -0,0 +1 @@
Lazily load state on master process when using workers to reduce DB consumption

1
changelog.d/3581.misc Normal file
View file

@ -0,0 +1 @@
Lazily load state on master process when using workers to reduce DB consumption

1
changelog.d/3582.misc Normal file
View file

@ -0,0 +1 @@
Lazily load state on master process when using workers to reduce DB consumption

View file

@ -65,8 +65,9 @@ class Auth(object):
@defer.inlineCallbacks
def check_from_context(self, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@ -544,7 +545,8 @@ class Auth(object):
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
auth_ids = yield self.compute_auth_events(builder, context.prev_state_ids)
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_ids = yield self.compute_auth_events(builder, prev_state_ids)
auth_events_entries = yield self.store.add_event_hashes(
auth_ids

View file

@ -18,6 +18,8 @@ import logging
import os
import sys
from six import iteritems
from twisted.application import service
from twisted.internet import defer, reactor
from twisted.web.resource import EncodingResourceWrapper, NoResource
@ -442,7 +444,7 @@ def run(hs):
stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
for name, count in daily_user_type_results.iteritems():
for name, count in iteritems(daily_user_type_results):
stats["daily_user_type_" + name] = count
room_count = yield hs.get_datastore().get_room_count()
@ -453,7 +455,7 @@ def run(hs):
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users()
for name, count in r30_results.iteritems():
for name, count in iteritems(r30_results):
stats["r30_users_" + name] = count
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()

View file

@ -25,6 +25,8 @@ import subprocess
import sys
import time
from six import iteritems
import yaml
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
@ -173,7 +175,7 @@ def main():
os.environ["SYNAPSE_CACHE_FACTOR"] = str(cache_factor)
cache_factors = config.get("synctl_cache_factors", {})
for cache_name, factor in cache_factors.iteritems():
for cache_name, factor in iteritems(cache_factors):
os.environ["SYNAPSE_CACHE_FACTOR_" + cache_name.upper()] = str(factor)
worker_configfiles = []

View file

@ -13,22 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from six import iteritems
from frozendict import frozendict
from twisted.internet import defer
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
class EventContext(object):
"""
Attributes:
current_state_ids (dict[(str, str), str]):
The current state map including the current event.
(type, state_key) -> event_id
prev_state_ids (dict[(str, str), str]):
The current state map excluding the current event.
(type, state_key) -> event_id
state_group (int|None): state group id, if the state has been stored
as a state group. This is usually only None if e.g. the event is
an outlier.
@ -45,38 +41,77 @@ class EventContext(object):
prev_state_events (?): XXX: is this ever set to anything other than
the empty list?
_current_state_ids (dict[(str, str), str]|None):
The current state map including the current event. None if outlier
or we haven't fetched the state from DB yet.
(type, state_key) -> event_id
_prev_state_ids (dict[(str, str), str]|None):
The current state map excluding the current event. None if outlier
or we haven't fetched the state from DB yet.
(type, state_key) -> event_id
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
been calculated. None if we haven't started calculating yet
_event_type (str): The type of the event the context is associated with.
Only set when state has not been fetched yet.
_event_state_key (str|None): The state_key of the event the context is
associated with. Only set when state has not been fetched yet.
_prev_state_id (str|None): If the event associated with the context is
a state event, then `_prev_state_id` is the event_id of the state
that was replaced.
Only set when state has not been fetched yet.
"""
__slots__ = [
"current_state_ids",
"prev_state_ids",
"state_group",
"rejected",
"prev_group",
"delta_ids",
"prev_state_events",
"app_service",
"_current_state_ids",
"_prev_state_ids",
"_prev_state_id",
"_event_type",
"_event_state_key",
"_fetching_state_deferred",
]
def __init__(self):
# The current state including the current event
self.current_state_ids = None
# The current state excluding the current event
self.prev_state_ids = None
self.state_group = None
self.prev_state_events = []
self.rejected = False
self.app_service = None
@staticmethod
def with_state(state_group, current_state_ids, prev_state_ids,
prev_group=None, delta_ids=None):
context = EventContext()
# The current state including the current event
context._current_state_ids = current_state_ids
# The current state excluding the current event
context._prev_state_ids = prev_state_ids
context.state_group = state_group
context._prev_state_id = None
context._event_type = None
context._event_state_key = None
context._fetching_state_deferred = defer.succeed(None)
# A previously persisted state group and a delta between that
# and this state.
self.prev_group = None
self.delta_ids = None
context.prev_group = prev_group
context.delta_ids = delta_ids
self.prev_state_events = None
return context
self.app_service = None
def serialize(self, event):
@defer.inlineCallbacks
def serialize(self, event, store):
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@ -92,11 +127,12 @@ class EventContext(object):
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
prev_state_id = self.prev_state_ids.get((event.type, event.state_key))
prev_state_ids = yield self.get_prev_state_ids(store)
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
return {
defer.returnValue({
"prev_state_id": prev_state_id,
"event_type": event.type,
"event_state_key": event.state_key if event.is_state() else None,
@ -106,10 +142,9 @@ class EventContext(object):
"delta_ids": _encode_state_dict(self.delta_ids),
"prev_state_events": self.prev_state_events,
"app_service_id": self.app_service.id if self.app_service else None
}
})
@staticmethod
@defer.inlineCallbacks
def deserialize(store, input):
"""Converts a dict that was produced by `serialize` back into a
EventContext.
@ -122,32 +157,100 @@ class EventContext(object):
EventContext
"""
context = EventContext()
context.state_group = input["state_group"]
context.rejected = input["rejected"]
context.prev_group = input["prev_group"]
context.delta_ids = _decode_state_dict(input["delta_ids"])
context.prev_state_events = input["prev_state_events"]
# We use the state_group and prev_state_id stuff to pull the
# current_state_ids out of the DB and construct prev_state_ids.
prev_state_id = input["prev_state_id"]
event_type = input["event_type"]
event_state_key = input["event_state_key"]
context._prev_state_id = input["prev_state_id"]
context._event_type = input["event_type"]
context._event_state_key = input["event_state_key"]
context._fetching_state_deferred = None
context.current_state_ids = yield store.get_state_ids_for_group(
context.state_group,
)
if prev_state_id and event_state_key:
context.prev_state_ids = dict(context.current_state_ids)
context.prev_state_ids[(event_type, event_state_key)] = prev_state_id
else:
context.prev_state_ids = context.current_state_ids
context.state_group = input["state_group"]
context.prev_group = input["prev_group"]
context.delta_ids = _decode_state_dict(input["delta_ids"])
context.rejected = input["rejected"]
context.prev_state_events = input["prev_state_events"]
app_service_id = input["app_service_id"]
if app_service_id:
context.app_service = store.get_app_service_by_id(app_service_id)
defer.returnValue(context)
return context
@defer.inlineCallbacks
def get_current_state_ids(self, store):
"""Gets the current state IDs
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
"""
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store,
)
yield make_deferred_yieldable(self._fetching_state_deferred)
defer.returnValue(self._current_state_ids)
@defer.inlineCallbacks
def get_prev_state_ids(self, store):
"""Gets the prev state IDs
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
"""
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(
self._fill_out_state, store,
)
yield make_deferred_yieldable(self._fetching_state_deferred)
defer.returnValue(self._prev_state_ids)
@defer.inlineCallbacks
def _fill_out_state(self, store):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return
self._current_state_ids = yield store.get_state_ids_for_group(
self.state_group,
)
if self._prev_state_id and self._event_state_key is not None:
self._prev_state_ids = dict(self._current_state_ids)
key = (self._event_type, self._event_state_key)
self._prev_state_ids[key] = self._prev_state_id
else:
self._prev_state_ids = self._current_state_ids
@defer.inlineCallbacks
def update_state(self, state_group, prev_state_ids, current_state_ids,
delta_ids):
"""Replace the state in the context
"""
# We need to make sure we wait for any ongoing fetching of state
# to complete so that the updated state doesn't get clobbered
if self._fetching_state_deferred:
yield make_deferred_yieldable(self._fetching_state_deferred)
self.state_group = state_group
self._prev_state_ids = prev_state_ids
self._current_state_ids = current_state_ids
self.delta_ids = delta_ids
# We need to ensure that that we've marked as having fetched the state
self._fetching_state_deferred = defer.succeed(None)
def _encode_state_dict(state_dict):
@ -159,7 +262,7 @@ def _encode_state_dict(state_dict):
return [
(etype, state_key, v)
for (etype, state_key), v in state_dict.iteritems()
for (etype, state_key), v in iteritems(state_dict)
]

View file

@ -112,8 +112,9 @@ class BaseHandler(object):
guest_access = event.content.get("guest_access", "forbidden")
if guest_access != "can_join":
if context:
current_state_ids = yield context.get_current_state_ids(self.store)
current_state = yield self.store.get_events(
list(context.current_state_ids.values())
list(current_state_ids.values())
)
else:
current_state = yield self.state_handler.get_current_state(

View file

@ -21,8 +21,8 @@ import logging
import sys
import six
from six import iteritems
from six.moves import http_client
from six import iteritems, itervalues
from six.moves import http_client, zip
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json
@ -486,7 +486,10 @@ class FederationHandler(BaseHandler):
# joined the room. Don't bother if the user is just
# changing their profile info.
newly_joined = True
prev_state_id = context.prev_state_ids.get(
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_id = prev_state_ids.get(
(event.type, event.state_key)
)
if prev_state_id:
@ -731,7 +734,7 @@ class FederationHandler(BaseHandler):
"""
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in state.iteritems()
for (e_type, state_key), event in iteritems(state)
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
@ -748,7 +751,7 @@ class FederationHandler(BaseHandler):
except Exception:
pass
return sorted(joined_domains.iteritems(), key=lambda d: d[1])
return sorted(joined_domains.items(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state)
@ -811,7 +814,7 @@ class FederationHandler(BaseHandler):
tried_domains = set(likely_domains)
tried_domains.add(self.server_name)
event_ids = list(extremities.iterkeys())
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = logcontext.preserve_fn(
@ -827,15 +830,15 @@ class FederationHandler(BaseHandler):
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
[e_id for ids in states.itervalues() for e_id in ids.itervalues()],
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False
)
states = {
key: {
k: state_map[e_id]
for k, e_id in state_dict.iteritems()
for k, e_id in iteritems(state_dict)
if e_id in state_map
} for key, state_dict in states.iteritems()
} for key, state_dict in iteritems(states)
}
for e_id, _ in sorted_extremeties_tuple:
@ -1106,10 +1109,12 @@ class FederationHandler(BaseHandler):
user = UserID.from_string(event.state_key)
yield user_joined_room(self.distributor, user, event.room_id)
state_ids = list(context.prev_state_ids.values())
prev_state_ids = yield context.get_prev_state_ids(self.store)
state_ids = list(prev_state_ids.values())
auth_chain = yield self.store.get_auth_chain(state_ids)
state = yield self.store.get_events(list(context.prev_state_ids.values()))
state = yield self.store.get_events(list(prev_state_ids.values()))
defer.returnValue({
"state": list(state.values()),
@ -1515,7 +1520,7 @@ class FederationHandler(BaseHandler):
yield self.store.persist_events(
[
(ev_info["event"], context)
for ev_info, context in itertools.izip(event_infos, contexts)
for ev_info, context in zip(event_infos, contexts)
],
backfilled=backfilled,
)
@ -1635,8 +1640,9 @@ class FederationHandler(BaseHandler):
)
if not auth_events:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@ -1876,9 +1882,10 @@ class FederationHandler(BaseHandler):
break
if do_resolution:
prev_state_ids = yield context.get_prev_state_ids(self.store)
# 1. Get what we think is the auth chain.
auth_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids
event, prev_state_ids
)
local_auth_chain = yield self.store.get_auth_chain(
auth_ids, include_given=True
@ -1968,21 +1975,35 @@ class FederationHandler(BaseHandler):
k: a.event_id for k, a in iteritems(auth_events)
if k != event_key
}
context.current_state_ids = dict(context.current_state_ids)
context.current_state_ids.update(state_updates)
current_state_ids = yield context.get_current_state_ids(self.store)
current_state_ids = dict(current_state_ids)
current_state_ids.update(state_updates)
if context.delta_ids is not None:
context.delta_ids = dict(context.delta_ids)
context.delta_ids.update(state_updates)
context.prev_state_ids = dict(context.prev_state_ids)
context.prev_state_ids.update({
delta_ids = dict(context.delta_ids)
delta_ids.update(state_updates)
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = dict(prev_state_ids)
prev_state_ids.update({
k: a.event_id for k, a in iteritems(auth_events)
})
context.state_group = yield self.store.store_state_group(
state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
delta_ids=delta_ids,
current_state_ids=current_state_ids,
)
yield context.update_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
delta_ids=delta_ids,
)
@defer.inlineCallbacks
@ -2222,7 +2243,8 @@ class FederationHandler(BaseHandler):
event.content["third_party_invite"]["signed"]["token"]
)
original_invite = None
original_invite_id = context.prev_state_ids.get(key)
prev_state_ids = yield context.get_prev_state_ids(self.store)
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = yield self.store.get_event(
original_invite_id, allow_none=True
@ -2264,7 +2286,8 @@ class FederationHandler(BaseHandler):
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
invite_event_id = context.prev_state_ids.get(
prev_state_ids = yield context.get_prev_state_ids(self.store)
invite_event_id = prev_state_ids.get(
(EventTypes.ThirdPartyInvite, token,)
)

View file

@ -630,7 +630,8 @@ class EventCreationHandler(object):
If so, returns the version of the event in context.
Otherwise, returns None.
"""
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_event_id = prev_state_ids.get((event.type, event.state_key))
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
if not prev_event:
return
@ -752,8 +753,8 @@ class EventCreationHandler(object):
event = builder.build()
logger.debug(
"Created event %s with state: %s",
event.event_id, context.prev_state_ids,
"Created event %s",
event.event_id,
)
defer.returnValue(
@ -806,8 +807,9 @@ class EventCreationHandler(object):
# If we're a worker we need to hit out to the master.
if self.config.worker_app:
yield send_event_to_master(
self.hs.get_clock(),
self.http_client,
clock=self.hs.get_clock(),
store=self.store,
client=self.http_client,
host=self.config.worker_replication_host,
port=self.config.worker_replication_http_port,
requester=requester,
@ -884,9 +886,11 @@ class EventCreationHandler(object):
e.sender == event.sender
)
current_state_ids = yield context.get_current_state_ids(self.store)
state_to_include_ids = [
e_id
for k, e_id in iteritems(context.current_state_ids)
for k, e_id in iteritems(current_state_ids)
if k[0] in self.hs.config.room_invite_state_types
or k == (EventTypes.Member, event.sender)
]
@ -922,8 +926,9 @@ class EventCreationHandler(object):
)
if event.type == EventTypes.Redaction:
prev_state_ids = yield context.get_prev_state_ids(self.store)
auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=True,
event, prev_state_ids, for_verification=True,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@ -943,11 +948,13 @@ class EventCreationHandler(object):
"You don't have permission to redact events"
)
if event.type == EventTypes.Create and context.prev_state_ids:
raise AuthError(
403,
"Changing the room create event is forbidden",
)
if event.type == EventTypes.Create:
prev_state_ids = yield context.get_prev_state_ids(self.store)
if prev_state_ids:
raise AuthError(
403,
"Changing the room create event is forbidden",
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context

View file

@ -201,7 +201,9 @@ class RoomMemberHandler(object):
ratelimit=ratelimit,
)
prev_member_event_id = context.prev_state_ids.get(
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, target.to_string()),
None
)
@ -496,9 +498,10 @@ class RoomMemberHandler(object):
if prev_event is not None:
return
prev_state_ids = yield context.get_prev_state_ids(self.store)
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = yield self._can_guest_join(context.prev_state_ids)
guest_can_join = yield self._can_guest_join(prev_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
@ -517,7 +520,7 @@ class RoomMemberHandler(object):
ratelimit=ratelimit,
)
prev_member_event_id = context.prev_state_ids.get(
prev_member_event_id = prev_state_ids.get(
(EventTypes.Member, event.state_key),
None
)

View file

@ -274,7 +274,7 @@ class Notifier(object):
logger.exception("Error notifying application services of event")
def on_new_event(self, stream_key, new_token, users=[], rooms=[]):
""" Used to inform listeners that something has happend event wise.
""" Used to inform listeners that something has happened event wise.
Will wake up all listeners for the given users and rooms.
"""

View file

@ -112,7 +112,8 @@ class BulkPushRuleEvaluator(object):
@defer.inlineCallbacks
def _get_power_levels_and_sender_level(self, event, context):
pl_event_id = context.prev_state_ids.get(POWER_KEY)
prev_state_ids = yield context.get_prev_state_ids(self.store)
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
@ -120,7 +121,7 @@ class BulkPushRuleEvaluator(object):
auth_events = {POWER_KEY: pl_event}
else:
auth_events_ids = yield self.auth.compute_auth_events(
event, context.prev_state_ids, for_verification=False,
event, prev_state_ids, for_verification=False,
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = {
@ -304,7 +305,7 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = context.current_state_ids
current_state_ids = yield context.get_current_state_ids(self.store)
push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids))

View file

@ -34,12 +34,13 @@ logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def send_event_to_master(clock, client, host, port, requester, event, context,
def send_event_to_master(clock, store, client, host, port, requester, event, context,
ratelimit, extra_users):
"""Send event to be handled on the master
Args:
clock (synapse.util.Clock)
store (DataStore)
client (SimpleHttpClient)
host (str): host of master
port (int): port on master listening for HTTP replication
@ -53,11 +54,13 @@ def send_event_to_master(clock, client, host, port, requester, event, context,
host, port, event.event_id,
)
serialized_context = yield context.serialize(event, store)
payload = {
"event": event.get_pdu_json(),
"internal_metadata": event.internal_metadata.get_dict(),
"rejected_reason": event.rejected_reason,
"context": context.serialize(event),
"context": serialized_context,
"requester": requester.serialize(),
"ratelimit": ratelimit,
"extra_users": [u.to_string() for u in extra_users],

View file

@ -18,7 +18,7 @@ import hashlib
import logging
from collections import namedtuple
from six import iteritems, itervalues
from six import iteritems, iterkeys, itervalues
from frozendict import frozendict
@ -203,25 +203,27 @@ class StateHandler(object):
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
context = EventContext()
if old_state:
context.prev_state_ids = {
prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
if event.is_state():
context.current_state_ids = dict(context.prev_state_ids)
current_state_ids = dict(prev_state_ids)
key = (event.type, event.state_key)
context.current_state_ids[key] = event.event_id
current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
current_state_ids = prev_state_ids
else:
context.current_state_ids = {}
context.prev_state_ids = {}
context.prev_state_events = []
current_state_ids = {}
prev_state_ids = {}
# We don't store state for outliers, so we don't generate a state
# froup for it.
context.state_group = None
# group for it.
context = EventContext.with_state(
state_group=None,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
)
defer.returnValue(context)
@ -230,31 +232,35 @@ class StateHandler(object):
# Let's just correctly fill out the context and create a
# new state group for it.
context = EventContext()
context.prev_state_ids = {
prev_state_ids = {
(s.type, s.state_key): s.event_id for s in old_state
}
if event.is_state():
key = (event.type, event.state_key)
if key in context.prev_state_ids:
replaces = context.prev_state_ids[key]
if key in prev_state_ids:
replaces = prev_state_ids[key]
if replaces != event.event_id: # Paranoia check
event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
current_state_ids = dict(prev_state_ids)
current_state_ids[key] = event.event_id
else:
context.current_state_ids = context.prev_state_ids
current_state_ids = prev_state_ids
context.state_group = yield self.store.store_state_group(
state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=None,
delta_ids=None,
current_state_ids=context.current_state_ids,
current_state_ids=current_state_ids,
)
context = EventContext.with_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
)
context.prev_state_events = []
defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context")
@ -262,47 +268,47 @@ class StateHandler(object):
event.room_id, [e for e, _ in event.prev_events],
)
curr_state = entry.state
prev_state_ids = entry.state
prev_group = None
delta_ids = None
context = EventContext()
context.prev_state_ids = curr_state
if event.is_state():
# If this is a state event then we need to create a new state
# group for the state after this event.
key = (event.type, event.state_key)
if key in context.prev_state_ids:
replaces = context.prev_state_ids[key]
if key in prev_state_ids:
replaces = prev_state_ids[key]
event.unsigned["replaces_state"] = replaces
context.current_state_ids = dict(context.prev_state_ids)
context.current_state_ids[key] = event.event_id
current_state_ids = dict(prev_state_ids)
current_state_ids[key] = event.event_id
if entry.state_group:
# If the state at the event has a state group assigned then
# we can use that as the prev group
context.prev_group = entry.state_group
context.delta_ids = {
prev_group = entry.state_group
delta_ids = {
key: event.event_id
}
elif entry.prev_group:
# If the state at the event only has a prev group, then we can
# use that as a prev group too.
context.prev_group = entry.prev_group
context.delta_ids = dict(entry.delta_ids)
context.delta_ids[key] = event.event_id
prev_group = entry.prev_group
delta_ids = dict(entry.delta_ids)
delta_ids[key] = event.event_id
context.state_group = yield self.store.store_state_group(
state_group = yield self.store.store_state_group(
event.event_id,
event.room_id,
prev_group=context.prev_group,
delta_ids=context.delta_ids,
current_state_ids=context.current_state_ids,
prev_group=prev_group,
delta_ids=delta_ids,
current_state_ids=current_state_ids,
)
else:
context.current_state_ids = context.prev_state_ids
context.prev_group = entry.prev_group
context.delta_ids = entry.delta_ids
current_state_ids = prev_state_ids
prev_group = entry.prev_group
delta_ids = entry.delta_ids
if entry.state_group is None:
entry.state_group = yield self.store.store_state_group(
@ -310,13 +316,20 @@ class StateHandler(object):
event.room_id,
prev_group=entry.prev_group,
delta_ids=entry.delta_ids,
current_state_ids=context.current_state_ids,
current_state_ids=current_state_ids,
)
entry.state_id = entry.state_group
context.state_group = entry.state_group
state_group = entry.state_group
context = EventContext.with_state(
state_group=state_group,
current_state_ids=current_state_ids,
prev_state_ids=prev_state_ids,
prev_group=prev_group,
delta_ids=delta_ids,
)
context.prev_state_events = []
defer.returnValue(context)
@defer.inlineCallbacks
@ -647,7 +660,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
for event_id in event_ids
)
if event_map is not None:
needed_events -= set(event_map.iterkeys())
needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d conflicted events", len(needed_events))
@ -668,7 +681,7 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
new_needed_events = set(itervalues(auth_events))
new_needed_events -= needed_events
if event_map is not None:
new_needed_events -= set(event_map.iterkeys())
new_needed_events -= set(iterkeys(event_map))
logger.info("Asking for %d auth events", len(new_needed_events))

View file

@ -248,17 +248,31 @@ class DeviceStore(SQLBaseStore):
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"content": json.dumps(content),
}
)
if content.get("deleted"):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
)
txn.call_after(
self.device_id_exists_cache.invalidate, (user_id, device_id,)
)
else:
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"content": json.dumps(content),
}
)
txn.call_after(self._get_cached_user_device.invalidate, (user_id, device_id,))
txn.call_after(self._get_cached_devices_for_user.invalidate, (user_id,))
@ -366,7 +380,7 @@ class DeviceStore(SQLBaseStore):
now_stream_id = max(stream_id for stream_id in itervalues(query_map))
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
txn, query_map.keys(), include_all_devices=True, include_deleted_devices=True
)
prev_sent_id_sql = """
@ -393,12 +407,15 @@ class DeviceStore(SQLBaseStore):
prev_id = stream_id
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
if device is not None:
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
else:
result["deleted"] = True
results.append(result)

View file

@ -64,12 +64,18 @@ class EndToEndKeyStore(SQLBaseStore):
)
@defer.inlineCallbacks
def get_e2e_device_keys(self, query_list, include_all_devices=False):
def get_e2e_device_keys(
self, query_list, include_all_devices=False,
include_deleted_devices=False,
):
"""Fetch a list of device keys.
Args:
query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
include_deleted_devices (bool): whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name".
@ -79,7 +85,7 @@ class EndToEndKeyStore(SQLBaseStore):
results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn,
query_list, include_all_devices,
query_list, include_all_devices, include_deleted_devices,
)
for user_id, device_keys in iteritems(results):
@ -88,10 +94,19 @@ class EndToEndKeyStore(SQLBaseStore):
defer.returnValue(results)
def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
def _get_e2e_device_keys_txn(
self, txn, query_list, include_all_devices=False,
include_deleted_devices=False,
):
query_clauses = []
query_params = []
if include_all_devices is False:
include_deleted_devices = False
if include_deleted_devices:
deleted_devices = set(query_list)
for (user_id, device_id) in query_list:
query_clause = "user_id = ?"
query_params.append(user_id)
@ -119,8 +134,14 @@ class EndToEndKeyStore(SQLBaseStore):
result = {}
for row in rows:
if include_deleted_devices:
deleted_devices.remove((row["user_id"], row["device_id"]))
result.setdefault(row["user_id"], {})[row["device_id"]] = row
if include_deleted_devices:
for user_id, device_id in deleted_devices:
result.setdefault(user_id, {})[device_id] = None
return result
@defer.inlineCallbacks

View file

@ -549,7 +549,7 @@ class EventsStore(EventsWorkerStore):
if ctx.state_group in state_groups_map:
continue
state_groups_map[ctx.state_group] = ctx.current_state_ids
state_groups_map[ctx.state_group] = yield ctx.get_current_state_ids(self)
# We need to map the event_ids to their state groups. First, let's
# check if the event is one we're persisting, in which case we can

View file

@ -185,6 +185,7 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
defer.returnValue(results)
@defer.inlineCallbacks
def bulk_get_push_rules_for_room(self, event, context):
state_group = context.state_group
if not state_group:
@ -194,9 +195,11 @@ class PushRulesWorkerStore(ApplicationServiceWorkerStore,
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._bulk_get_push_rules_for_room(
event.room_id, state_group, context.current_state_ids, event=event
current_state_ids = yield context.get_current_state_ids(self)
result = yield self._bulk_get_push_rules_for_room(
event.room_id, state_group, current_state_ids, event=event
)
defer.returnValue(result)
@cachedInlineCallbacks(num_args=2, cache_context=True)
def _bulk_get_push_rules_for_room(self, room_id, state_group, current_state_ids,

View file

@ -232,6 +232,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
defer.returnValue(user_who_share_room)
@defer.inlineCallbacks
def get_joined_users_from_context(self, event, context):
state_group = context.state_group
if not state_group:
@ -241,11 +242,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# To do this we set the state_group to a new object as object() != object()
state_group = object()
return self._get_joined_users_from_context(
event.room_id, state_group, context.current_state_ids,
current_state_ids = yield context.get_current_state_ids(self)
result = yield self._get_joined_users_from_context(
event.room_id, state_group, current_state_ids,
event=event,
context=context,
)
defer.returnValue(result)
def get_joined_users_from_state(self, room_id, state_entry):
state_group = state_entry.state_group

View file

@ -184,13 +184,13 @@ class Linearizer(object):
# key_to_defer is a map from the key to a 2 element list where
# the first element is the number of things executing, and
# the second element is a deque of deferreds for the things blocked from
# executing.
# the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing.
self.key_to_defer = {}
@defer.inlineCallbacks
def queue(self, key):
entry = self.key_to_defer.setdefault(key, [0, collections.deque()])
entry = self.key_to_defer.setdefault(key, [0, collections.OrderedDict()])
# If the number of things executing is greater than the maximum
# then add a deferred to the list of blocked items
@ -198,12 +198,28 @@ class Linearizer(object):
# this item so that it can continue executing.
if entry[0] >= self.max_count:
new_defer = defer.Deferred()
entry[1].append(new_defer)
entry[1][new_defer] = 1
logger.info(
"Waiting to acquire linearizer lock %r for key %r", self.name, key,
)
yield make_deferred_yieldable(new_defer)
try:
yield make_deferred_yieldable(new_defer)
except Exception as e:
if isinstance(e, CancelledError):
logger.info(
"Cancelling wait for linearizer lock %r for key %r",
self.name, key,
)
else:
logger.warn(
"Unexpected exception waiting for linearizer lock %r for key %r",
self.name, key,
)
# we just have to take ourselves back out of the queue.
del entry[1][new_defer]
raise
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
entry[0] += 1
@ -238,7 +254,7 @@ class Linearizer(object):
entry[0] -= 1
if entry[1]:
next_def = entry[1].popleft()
(next_def, _) = entry[1].popitem(last=False)
# we need to run the next thing in the sentinel context.
with PreserveLoggingContext():

View file

@ -12,11 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import logging
import operator
import six
from six import iteritems, itervalues
from six.moves import map
from twisted.internet import defer
@ -204,7 +205,7 @@ def filter_events_for_client(store, user_id, events, is_peeking=False,
return event
# check each event: gives an iterable[None|EventBase]
filtered_events = itertools.imap(allowed, events)
filtered_events = map(allowed, events)
# remove the None entries
filtered_events = filter(operator.truth, filtered_events)
@ -244,7 +245,7 @@ def filter_events_for_server(store, server_name, events):
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
for ev in state.itervalues():
for ev in itervalues(state):
if ev.type != EventTypes.Member:
continue
try:
@ -278,7 +279,7 @@ def filter_events_for_server(store, server_name, events):
)
visibility_ids = set()
for sids in event_to_state_ids.itervalues():
for sids in itervalues(event_to_state_ids):
hist = sids.get((EventTypes.RoomHistoryVisibility, ""))
if hist:
visibility_ids.add(hist)
@ -291,7 +292,7 @@ def filter_events_for_server(store, server_name, events):
event_map = yield store.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in event_map.itervalues()
for e in itervalues(event_map)
)
if all_open:
@ -329,7 +330,7 @@ def filter_events_for_server(store, server_name, events):
#
state_key_to_event_id_set = {
e
for key_to_eid in six.itervalues(event_to_state_ids)
for key_to_eid in itervalues(event_to_state_ids)
for e in key_to_eid.items()
}
@ -352,10 +353,10 @@ def filter_events_for_server(store, server_name, events):
event_to_state = {
e_id: {
key: event_map[inner_e_id]
for key, inner_e_id in key_to_eid.iteritems()
for key, inner_e_id in iteritems(key_to_eid)
if inner_e_id in event_map
}
for e_id, key_to_eid in event_to_state_ids.iteritems()
for e_id, key_to_eid in iteritems(event_to_state_ids)
}
defer.returnValue([

View file

@ -222,9 +222,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_ids = {
key: e.event_id for key, e in state.items()
}
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
context = EventContext.with_state(
state_group=None,
current_state_ids=state_ids,
prev_state_ids=state_ids
)
else:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)

View file

@ -137,7 +137,6 @@ class MessageAcceptTests(unittest.TestCase):
)
self.assertEqual(self.successResultOf(extrem)[0], "$join:test.serv")
@unittest.DEBUG
def test_cant_hide_past_history(self):
"""
If you send a message, you must be able to provide the direct
@ -178,7 +177,7 @@ class MessageAcceptTests(unittest.TestCase):
for x, y in d.items()
if x == ("m.room.member", "@us:test")
],
"auth_chain_ids": d.values(),
"auth_chain_ids": list(d.values()),
}
)

View file

@ -204,7 +204,8 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
self.assertEqual(2, len(context_store["D"].prev_state_ids))
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertEqual(2, len(prev_state_ids))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
@ -255,9 +256,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "C"},
{e_id for e_id in context_store["D"].prev_state_ids.values()}
{e_id for e_id in prev_state_ids.values()}
)
@defer.inlineCallbacks
@ -318,9 +321,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
self.assertSetEqual(
{"START", "A", "B", "C"},
{e for e in context_store["E"].prev_state_ids.values()}
{e for e in prev_state_ids.values()}
)
@defer.inlineCallbacks
@ -398,9 +403,11 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"},
{e for e in context_store["D"].prev_state_ids.values()}
{e for e in prev_state_ids.values()}
)
def _add_depths(self, nodes, edges):
@ -429,8 +436,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
set(e.event_id for e in old_state), set(context.current_state_ids.values())
set(e.event_id for e in old_state), set(current_state_ids.values())
)
self.assertIsNotNone(context.state_group)
@ -449,8 +458,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertEqual(
set(e.event_id for e in old_state), set(context.prev_state_ids.values())
set(e.event_id for e in old_state), set(prev_state_ids.values())
)
@defer.inlineCallbacks
@ -475,9 +486,11 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
set([e.event_id for e in old_state]),
set(context.current_state_ids.values())
set(current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@ -504,9 +517,11 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertEqual(
set([e.event_id for e in old_state]),
set(context.prev_state_ids.values())
set(prev_state_ids.values())
)
self.assertIsNotNone(context.state_group)
@ -545,7 +560,9 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
self.assertEqual(len(context.current_state_ids), 6)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(len(current_state_ids), 6)
self.assertIsNotNone(context.state_group)
@ -585,7 +602,9 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
self.assertEqual(len(context.current_state_ids), 6)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(len(current_state_ids), 6)
self.assertIsNotNone(context.state_group)
@ -642,8 +661,10 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
old_state_2[3].event_id, context.current_state_ids[("test1", "1")]
old_state_2[3].event_id, current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths
@ -670,8 +691,10 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
)
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
old_state_1[3].event_id, context.current_state_ids[("test1", "1")]
old_state_1[3].event_id, current_state_ids[("test1", "1")]
)
def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,

View file

@ -17,6 +17,7 @@
from six.moves import range
from twisted.internet import defer, reactor
from twisted.internet.defer import CancelledError
from synapse.util import Clock, logcontext
from synapse.util.async import Linearizer
@ -112,3 +113,33 @@ class LinearizerTestCase(unittest.TestCase):
d6 = limiter.queue(key)
with (yield d6):
pass
@defer.inlineCallbacks
def test_cancellation(self):
linearizer = Linearizer()
key = object()
d1 = linearizer.queue(key)
cm1 = yield d1
d2 = linearizer.queue(key)
self.assertFalse(d2.called)
d3 = linearizer.queue(key)
self.assertFalse(d3.called)
d2.cancel()
with cm1:
pass
self.assertTrue(d2.called)
try:
yield d2
self.fail("Expected d2 to raise CancelledError")
except CancelledError:
pass
with (yield d3):
pass