Merge pull request #1857 from matrix-org/erikj/device_list_stream

Implement device lists updates over federation
This commit is contained in:
Erik Johnston 2017-01-30 14:35:21 +00:00 committed by GitHub
commit 9636b2407d
27 changed files with 1033 additions and 121 deletions

View file

@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.storage.engines import create_engine from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep from synapse.util.async import sleep
@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice")
class FederationSenderSlaveStore( class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore, SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore, SlavedRegistrationStore, SlavedDeviceStore,
): ):
pass pass

View file

@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore from synapse.storage.client_ips import ClientIpStore
@ -77,6 +78,7 @@ class SynchrotronSlavedStore(
SlavedFilteringStore, SlavedFilteringStore,
SlavedPresenceStore, SlavedPresenceStore,
SlavedDeviceInboxStore, SlavedDeviceInboxStore,
SlavedDeviceStore,
RoomStore, RoomStore,
BaseSlavedStore, BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different ClientIpStore, # After BaseSlavedStore because the constructor is different
@ -380,6 +382,27 @@ class SynchrotronServer(HomeServer):
stream_key, position, users=users, rooms=rooms stream_key, position, users=users, rooms=rooms
) )
@defer.inlineCallbacks
def notify_device_list_update(result):
stream = result.get("device_lists")
if not stream:
return
position_index = stream["field_names"].index("position")
user_index = stream["field_names"].index("user_id")
for row in stream["rows"]:
position = row[position_index]
user_id = row[user_index]
rooms = yield store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
@defer.inlineCallbacks
def notify(result): def notify(result):
stream = result.get("events") stream = result.get("events")
if stream: if stream:
@ -417,6 +440,7 @@ class SynchrotronServer(HomeServer):
notify_from_stream( notify_from_stream(
result, "to_device", "to_device_key", user="user_id" result, "to_device", "to_device_key", user="user_id"
) )
yield notify_device_list_update(result)
while True: while True:
try: try:
@ -427,7 +451,7 @@ class SynchrotronServer(HomeServer):
yield store.process_replication(result) yield store.process_replication(result)
typing_handler.process_replication(result) typing_handler.process_replication(result)
yield presence_handler.process_replication(result) yield presence_handler.process_replication(result)
notify(result) yield notify(result)
except: except:
logger.exception("Error replicating from %r", replication_url) logger.exception("Error replicating from %r", replication_url)
yield sleep(5) yield sleep(5)

View file

@ -126,6 +126,16 @@ class FederationClient(FederationBase):
destination, content, timeout destination, content, timeout
) )
@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.inc("user_devices")
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@log_function @log_function
def claim_client_keys(self, destination, content, timeout): def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.

View file

@ -416,6 +416,9 @@ class FederationServer(FederationBase):
def on_query_client_keys(self, origin, content): def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content) return self.on_query_request("client_keys", content)
def on_query_user_devices(self, origin, user_id):
return self.on_query_request("user_devices", user_id)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def on_claim_client_keys(self, origin, content): def on_claim_client_keys(self, origin, content):

View file

@ -100,6 +100,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {} self.pending_failures_by_dest = {}
self.last_device_stream_id_by_dest = {} self.last_device_stream_id_by_dest = {}
self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id # HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec()) self._next_txn_id = int(self.clock.time_msec())
@ -305,62 +306,74 @@ class TransactionQueue(object):
yield run_on_reactor() yield run_on_reactor()
while True: while True:
pending_pdus = self.pending_pdus_by_dest.pop(destination, []) pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, []) pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {}) pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, []) pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend( pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values() self.pending_edus_keyed_by_dest.pop(destination, {}).values()
)
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
)
device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination)
)
pending_edus.extend(device_message_edus)
if pending_presence:
pending_edus.append(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.presence",
content={
"push": [
format_user_presence_state(
presence, self.clock.time_msec()
)
for presence in pending_presence.values()
]
},
)
) )
limiter = yield get_retry_limiter( if pending_pdus:
destination, logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
self.clock, destination, len(pending_pdus))
self.store,
)
device_message_edus, device_stream_id = ( if not pending_pdus and not pending_edus and not pending_failures:
yield self._get_new_device_messages(destination) logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
) )
return
pending_edus.extend(device_message_edus) success = yield self._send_new_transaction(
if pending_presence: destination, pending_pdus, pending_edus, pending_failures,
pending_edus.append( limiter=limiter,
Edu( )
origin=self.server_name, if success:
destination=destination, # Remove the acknowledged device messages from the database
edu_type="m.presence", # Only bother if we actually sent some device messages
content={ if device_message_edus:
"push": [ yield self.store.delete_device_msgs_for_remote(
format_user_presence_state( destination, device_stream_id
presence, self.clock.time_msec() )
) logger.info("Marking as sent %r %r", destination, dev_list_id)
for presence in pending_presence.values() yield self.store.mark_as_sent_devices_by_remote(
] destination, dev_list_id
},
)
) )
if pending_pdus: self.last_device_stream_id_by_dest[destination] = device_stream_id
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", self.last_device_list_stream_id_by_dest[destination] = dev_list_id
destination, len(pending_pdus)) else:
break
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
)
return
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
device_stream_id,
should_delete_from_device_stream=bool(device_message_edus),
limiter=limiter,
)
if not success:
break
except NotRetryingDestination: except NotRetryingDestination:
logger.debug( logger.debug(
"TX [%s] not ready for retry yet - " "TX [%s] not ready for retry yet - "
@ -387,13 +400,26 @@ class TransactionQueue(object):
) )
for content in contents for content in contents
] ]
defer.returnValue((edus, stream_id))
last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0)
now_stream_id, results = yield self.store.get_devices_by_remote(
destination, last_device_list
)
edus.extend(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.device_list_update",
content=content,
)
for content in results
)
defer.returnValue((edus, stream_id, now_stream_id))
@measure_func("_send_new_transaction") @measure_func("_send_new_transaction")
@defer.inlineCallbacks @defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus, def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures, device_stream_id, pending_failures, limiter):
should_delete_from_device_stream, limiter):
# Sort based on the order field # Sort based on the order field
pending_pdus.sort(key=lambda t: t[1]) pending_pdus.sort(key=lambda t: t[1])
@ -504,13 +530,6 @@ class TransactionQueue(object):
"Failed to send event %s to %s", p.event_id, destination "Failed to send event %s to %s", p.event_id, destination
) )
success = False success = False
else:
# Remove the acknowledged device messages from the database
if should_delete_from_device_stream:
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
self.last_device_stream_id_by_dest[destination] = device_stream_id
except RuntimeError as e: except RuntimeError as e:
# We capture this here as there as nothing actually listens # We capture this here as there as nothing actually listens
# for this finishing functions deferred. # for this finishing functions deferred.

View file

@ -346,6 +346,32 @@ class TransportLayerClient(object):
) )
defer.returnValue(content) defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_user_devices(self, destination, user_id, timeout):
"""Query the devices for a user id hosted on a remote server.
Response:
{
"stream_id": "...",
"devices": [ { ... } ]
}
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/devices/" + user_id
content = yield self.client.get_json(
destination=destination,
path=path,
timeout=timeout,
)
defer.returnValue(content)
@defer.inlineCallbacks @defer.inlineCallbacks
@log_function @log_function
def claim_client_keys(self, destination, query_content, timeout): def claim_client_keys(self, destination, query_content, timeout):

View file

@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
return self.handler.on_query_client_keys(origin, content) return self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
def on_GET(self, origin, content, query, user_id):
return self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServlet): class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim" PATH = "/user/keys/claim"
@ -613,6 +620,7 @@ SERVLET_CLASSES = (
FederationGetMissingEventsServlet, FederationGetMissingEventsServlet,
FederationEventAuthServlet, FederationEventAuthServlet,
FederationClientKeysQueryServlet, FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet, FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet, FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet, On3pidBindServlet,

View file

@ -15,6 +15,8 @@
from synapse.api import errors from synapse.api import errors
from synapse.util import stringutils from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.types import get_domain_from_id
from twisted.internet import defer from twisted.internet import defer
from ._base import BaseHandler from ._base import BaseHandler
@ -27,6 +29,21 @@ class DeviceHandler(BaseHandler):
def __init__(self, hs): def __init__(self, hs):
super(DeviceHandler, self).__init__(hs) super(DeviceHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
self._remote_edue_linearizer = Linearizer(name="remote_device_list")
self.federation.register_edu_handler(
"m.device_list_update", self._incoming_device_list_update,
)
self.federation.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks @defer.inlineCallbacks
def check_device_registered(self, user_id, device_id, def check_device_registered(self, user_id, device_id,
initial_device_display_name=None): initial_device_display_name=None):
@ -45,29 +62,29 @@ class DeviceHandler(BaseHandler):
str: device id (generated if none was supplied) str: device id (generated if none was supplied)
""" """
if device_id is not None: if device_id is not None:
yield self.store.store_device( new_device = yield self.store.store_device(
user_id=user_id, user_id=user_id,
device_id=device_id, device_id=device_id,
initial_device_display_name=initial_device_display_name, initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
) )
if new_device:
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id) defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few # if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash. # times in case of a clash.
attempts = 0 attempts = 0
while attempts < 5: while attempts < 5:
try: device_id = stringutils.random_string(10).upper()
device_id = stringutils.random_string(10).upper() new_device = yield self.store.store_device(
yield self.store.store_device( user_id=user_id,
user_id=user_id, device_id=device_id,
device_id=device_id, initial_device_display_name=initial_device_display_name,
initial_device_display_name=initial_device_display_name, )
ignore_if_known=False, if new_device:
) yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id) defer.returnValue(device_id)
except errors.StoreError: attempts += 1
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.") raise errors.StoreError(500, "Couldn't generate a device ID.")
@ -147,6 +164,8 @@ class DeviceHandler(BaseHandler):
user_id=user_id, device_id=device_id user_id=user_id, device_id=device_id
) )
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks @defer.inlineCallbacks
def update_device(self, user_id, device_id, content): def update_device(self, user_id, device_id, content):
""" Update the given device """ Update the given device
@ -166,12 +185,110 @@ class DeviceHandler(BaseHandler):
device_id, device_id,
new_display_name=content.get("display_name") new_display_name=content.get("display_name")
) )
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError, e: except errors.StoreError, e:
if e.code == 404: if e.code == 404:
raise errors.NotFoundError() raise errors.NotFoundError()
else: else:
raise raise
@defer.inlineCallbacks
def notify_device_update(self, user_id, device_ids):
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
hosts = set()
if self.hs.is_mine_id(user_id):
for room_id in room_ids:
users = yield self.state.get_current_user_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in users)
hosts.discard(self.server_name)
position = yield self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
)
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
if hosts:
logger.info("Sending device list update notif to: %r", hosts)
for host in hosts:
self.federation_sender.send_device_messages(host)
@defer.inlineCallbacks
def _incoming_device_list_update(self, origin, edu_content):
user_id = edu_content["user_id"]
device_id = edu_content["device_id"]
stream_id = edu_content["stream_id"]
prev_ids = edu_content.get("prev_id", [])
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got device list update edu for %r from %r", user_id, origin)
return
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
with (yield self._remote_edue_linearizer.queue(user_id)):
# If the prev id matches whats in our cache table, then we don't need
# to resync the users device list, otherwise we do.
resync = True
if len(prev_ids) == 1:
extremity = yield self.store.get_device_list_last_stream_id_for_remote(
user_id
)
logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
if str(extremity) == str(prev_ids[0]):
resync = False
if resync:
# Fetch all devices for the user.
result = yield self.federation.query_user_devices(origin, user_id)
stream_id = result["stream_id"]
devices = result["devices"]
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
)
device_ids = [device["device_id"] for device in devices]
yield self.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
# change (becuase of the single prev_id matching the current cache)
content = dict(edu_content)
for key in ("user_id", "device_id", "stream_id", "prev_ids"):
content.pop(key, None)
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
)
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
defer.returnValue({
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
})
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
user_id = user.to_string()
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
def _update_device_from_client_ips(device, client_ips): def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {}) ip = client_ips.get((device["user_id"], device["device_id"]), {})

View file

@ -73,10 +73,9 @@ class E2eKeysHandler(object):
if self.is_mine_id(user_id): if self.is_mine_id(user_id):
local_query[user_id] = device_ids local_query[user_id] = device_ids
else: else:
domain = get_domain_from_id(user_id) remote_queries[user_id] = device_ids
remote_queries.setdefault(domain, {})[user_id] = device_ids
# do the queries # Firt get local devices.
failures = {} failures = {}
results = {} results = {}
if local_query: if local_query:
@ -85,9 +84,42 @@ class E2eKeysHandler(object):
if user_id in local_query: if user_id in local_query:
results[user_id] = keys results[user_id] = keys
# Now attempt to get any remote devices from our local cache.
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
for user_id, device_ids in remote_queries.iteritems():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
query_list.append((user_id, None))
user_ids_not_in_cache, remote_results = (
yield self.store.get_user_devices_from_cache(
query_list
)
)
for user_id, devices in remote_results.iteritems():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.iteritems():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
@defer.inlineCallbacks @defer.inlineCallbacks
def do_remote_query(destination): def do_remote_query(destination):
destination_query = remote_queries[destination] destination_query = remote_queries_not_in_cache[destination]
try: try:
limiter = yield get_retry_limiter( limiter = yield get_retry_limiter(
destination, self.clock, self.store destination, self.clock, self.store
@ -119,7 +151,7 @@ class E2eKeysHandler(object):
yield preserve_context_over_deferred(defer.gatherResults([ yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination) preserve_fn(do_remote_query)(destination)
for destination in remote_queries for destination in remote_queries_not_in_cache
])) ]))
defer.returnValue({ defer.returnValue({
@ -259,6 +291,7 @@ class E2eKeysHandler(object):
user_id, device_id, time_now, user_id, device_id, time_now,
encode_canonical_json(device_keys) encode_canonical_json(device_keys)
) )
yield self.device_handler.notify_device_update(user_id, [device_id])
one_time_keys = keys.get("one_time_keys", None) one_time_keys = keys.get("one_time_keys", None)
if one_time_keys: if one_time_keys:

View file

@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"invited", # InvitedSyncResult for each invited room. "invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room. "archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device. "to_device", # List of direct messages for the device.
"device_lists", # List of user_ids whose devices have chanegd
])): ])):
__slots__ = [] __slots__ = []
@ -544,6 +545,10 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder) yield self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = yield self._generate_sync_entry_for_device_list(
sync_result_builder
)
defer.returnValue(SyncResult( defer.returnValue(SyncResult(
presence=sync_result_builder.presence, presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data, account_data=sync_result_builder.account_data,
@ -551,9 +556,32 @@ class SyncHandler(object):
invited=sync_result_builder.invited, invited=sync_result_builder.invited,
archived=sync_result_builder.archived, archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device, to_device=sync_result_builder.to_device,
device_lists=device_lists,
next_batch=sync_result_builder.now_token, next_batch=sync_result_builder.now_token,
)) ))
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and since_token.device_list_key:
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(r.room_id for r in rooms)
user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(
since_token.device_list_key
)
for other_user_id in changed:
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
defer.returnValue(user_ids_changed)
else:
defer.returnValue([])
@defer.inlineCallbacks @defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder): def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates """Generates the portion of the sync response. Populates

View file

@ -46,6 +46,7 @@ STREAM_NAMES = (
("to_device",), ("to_device",),
("public_rooms",), ("public_rooms",),
("federation",), ("federation",),
("device_lists",),
) )
@ -140,6 +141,7 @@ class ReplicationResource(Resource):
caches_token = self.store.get_cache_stream_token() caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id() public_rooms_token = self.store.get_current_public_room_stream_id()
federation_token = self.federation_sender.get_current_token() federation_token = self.federation_sender.get_current_token()
device_list_token = self.store.get_device_stream_token()
defer.returnValue(_ReplicationToken( defer.returnValue(_ReplicationToken(
room_stream_token, room_stream_token,
@ -155,6 +157,7 @@ class ReplicationResource(Resource):
int(stream_token.to_device_key), int(stream_token.to_device_key),
int(public_rooms_token), int(public_rooms_token),
int(federation_token), int(federation_token),
int(device_list_token),
)) ))
@request_handler() @request_handler()
@ -214,6 +217,7 @@ class ReplicationResource(Resource):
yield self.caches(writer, current_token, limit, request_streams) yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams) yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams) yield self.public_rooms(writer, current_token, limit, request_streams)
yield self.device_lists(writer, current_token, limit, request_streams)
self.federation(writer, current_token, limit, request_streams, federation_ack) self.federation(writer, current_token, limit, request_streams, federation_ack)
self.streams(writer, current_token, request_streams) self.streams(writer, current_token, request_streams)
@ -495,6 +499,20 @@ class ReplicationResource(Resource):
"position", "type", "content", "position", "type", "content",
), position=upto_token) ), position=upto_token)
@defer.inlineCallbacks
def device_lists(self, writer, current_token, limit, request_streams):
current_position = current_token.device_lists
device_lists = request_streams.get("device_lists")
if device_lists is not None and device_lists != current_position:
changes = yield self.store.get_all_device_list_changes_for_remotes(
device_lists,
)
writer.write_header_and_rows("device_lists", changes, (
"position", "user_id", "destination",
), position=current_position)
class _Writer(object): class _Writer(object):
"""Writes the streams as a JSON object as the response to the request""" """Writes the streams as a JSON object as the response to the request"""
@ -527,7 +545,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", ( class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill", "events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms", "push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
"federation", "federation", "device_lists",
))): ))):
__slots__ = [] __slots__ = []

View file

@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
db_conn, "device_lists_stream", "stream_id",
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
get_device_stream_token = DataStore.get_device_stream_token.__func__
get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
get_devices_by_remote = DataStore.get_devices_by_remote.__func__
_get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
_get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
_mark_as_sent_devices_by_remote_txn = (
DataStore._mark_as_sent_devices_by_remote_txn.__func__
)
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("device_lists")
if stream:
self._device_list_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
destination = row[2]
self._device_list_stream_cache.entity_has_changed(
user_id, stream_id
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, stream_id
)
return super(SlavedDeviceStore, self).process_replication(result)

View file

@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet):
) )
archived = self.encode_archived( archived = self.encode_archived(
sync_result.archived, time_now, requester.access_token_id, filter.event_fields sync_result.archived, time_now, requester.access_token_id,
filter.event_fields,
) )
response_content = { response_content = {
"account_data": {"events": sync_result.account_data}, "account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device}, "to_device": {"events": sync_result.to_device},
"device_lists": {
"changed": list(sync_result.device_lists),
},
"presence": self.encode_presence( "presence": self.encode_presence(
sync_result.presence, time_now sync_result.presence, time_now
), ),

View file

@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._public_room_id_gen = StreamIdGenerator( self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id" db_conn, "public_room_list_stream", "stream_id"
) )
self._device_list_id_gen = StreamIdGenerator(
db_conn, "device_lists_stream", "stream_id",
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id") self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id") self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
@ -210,6 +213,14 @@ class DataStore(RoomMemberStore, RoomStore,
prefilled_cache=device_outbox_prefill, prefilled_cache=device_outbox_prefill,
) )
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
cur = LoggingTransaction( cur = LoggingTransaction(
db_conn.cursor(), db_conn.cursor(),
name="_find_stream_orderings_for_times_txn", name="_find_stream_orderings_for_times_txn",

View file

@ -387,6 +387,10 @@ class SQLBaseStore(object):
Args: Args:
table : string giving the table name table : string giving the table name
values : dict of new column names and values for them values : dict of new column names and values for them
Returns:
bool: Whether the row was inserted or not. Only useful when
`or_ignore` is True
""" """
try: try:
yield self.runInteraction( yield self.runInteraction(
@ -398,6 +402,8 @@ class SQLBaseStore(object):
# a cursor after we receive an error from the db. # a cursor after we receive an error from the db.
if not or_ignore: if not or_ignore:
raise raise
defer.returnValue(False)
defer.returnValue(True)
@staticmethod @staticmethod
def _simple_insert_txn(txn, table, values): def _simple_insert_txn(txn, table, values):

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import ujson as json
from twisted.internet import defer from twisted.internet import defer
@ -23,27 +24,29 @@ logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore): class DeviceStore(SQLBaseStore):
def __init__(self, hs):
super(DeviceStore, self).__init__(hs)
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
@defer.inlineCallbacks @defer.inlineCallbacks
def store_device(self, user_id, device_id, def store_device(self, user_id, device_id,
initial_device_display_name, initial_device_display_name):
ignore_if_known=True):
"""Ensure the given device is known; add it to the store if not """Ensure the given device is known; add it to the store if not
Args: Args:
user_id (str): id of user associated with the device user_id (str): id of user associated with the device
device_id (str): id of device device_id (str): id of device
initial_device_display_name (str): initial displayname of the initial_device_display_name (str): initial displayname of the
device device. Ignored if device exists.
ignore_if_known (bool): ignore integrity errors which mean the
device is already known
Returns: Returns:
defer.Deferred defer.Deferred: boolean whether the device was inserted or an
Raises: existing device existed with that ID.
StoreError: if ignore_if_known is False and the device was already
known
""" """
try: try:
yield self._simple_insert( inserted = yield self._simple_insert(
"devices", "devices",
values={ values={
"user_id": user_id, "user_id": user_id,
@ -51,8 +54,9 @@ class DeviceStore(SQLBaseStore):
"display_name": initial_device_display_name "display_name": initial_device_display_name
}, },
desc="store_device", desc="store_device",
or_ignore=ignore_if_known, or_ignore=True,
) )
defer.returnValue(inserted)
except Exception as e: except Exception as e:
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)" logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
" display_name=%s(%r) failed: %s", " display_name=%s(%r) failed: %s",
@ -139,3 +143,432 @@ class DeviceStore(SQLBaseStore):
) )
defer.returnValue({d["device_id"]: d for d in devices}) defer.returnValue({d["device_id"]: d for d in devices})
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_remote_extremity",
allow_none=True,
)
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
return self._simple_delete(
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
)
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id):
"""Updates a single user's device in the cache.
"""
return self.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id, device_id, content, stream_id,
)
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"content": json.dumps(content),
}
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def update_remote_device_list_cache(self, user_id, devices, stream_id):
"""Replace the cache of the remote user's devices.
"""
return self.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id, devices, stream_id,
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
stream_id):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
)
self._simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
{
"user_id": user_id,
"device_id": content["device_id"],
"content": json.dumps(content),
}
for content in devices
]
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(now_stream_id, [ { updates }, .. ])
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
rows = txn.fetchall()
if not rows:
return (now_stream_id, [])
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in rows}
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_pokes
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in devices.iteritems():
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in user_devices.iteritems():
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return (now_stream_id, results)
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
return self.runInteraction(
"get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
query_list,
)
def _get_user_devices_from_cache_txn(self, txn, query_list):
user_ids = {user_id for user_id, _ in query_list}
user_ids_in_cache = set()
for user_id in user_ids:
stream_ids = self._simple_select_onecol_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
retcol="stream_id",
)
if stream_ids:
user_ids_in_cache.add(user_id)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
content = self._simple_select_one_onecol_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
)
results.setdefault(user_id, {})[device_id] = json.loads(content)
else:
devices = self._simple_select_list_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
)
results[user_id] = {
device["device_id"]: json.loads(device["content"])
for device in devices
}
user_ids_in_cache.discard(user_id)
return user_ids_not_in_cache, results
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.iteritems():
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id < (
SELECT coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
)
"""
txn.execute(sql, (destination, destination, stream_id,))
sql = """
UPDATE device_lists_outbound_pokes SET sent = ?
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (True, destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row["user_id"] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
sql = """
SELECT stream_id, user_id, destination FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE stream_id > ?
"""
return self._execute(
"get_users_and_hosts_device_list", None,
sql, from_key,
)
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_device_change_to_streams", self._add_device_change_txn,
user_id, device_ids, hosts, stream_id,
)
defer.returnValue(stream_id)
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id, stream_id,
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host, stream_id,
)
self._simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
{
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
}
for device_id in device_ids
]
)
self._simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
{
"destination": destination,
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
"sent": False,
"ts": now,
}
for destination in hosts
for device_id in device_ids
]
)
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per
(destination, user_id) tuple to ensure that the prev_ids remain correct
if the server does come back.
"""
now = self._clock.time_msec()
def _prune_txn(txn):
select_sql = """
SELECT destination, user_id, max(stream_id) as stream_id
FROM device_lists_outbound_pokes
GROUP BY destination, user_id
"""
txn.execute(select_sql)
rows = txn.fetchall()
delete_sql = """
DELETE FROM device_lists_outbound_pokes
WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
"""
txn.executemany(
delete_sql,
(
(now, row["destination"], row["user_id"], row["stream_id"])
for row in rows
)
)
return self.runInteraction(
"_prune_old_outbound_device_pokes", _prune_txn
)

View file

@ -12,9 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections from twisted.internet import defer
import twisted.internet.defer
from ._base import SQLBaseStore from ._base import SQLBaseStore
@ -33,10 +31,12 @@ class EndToEndKeyStore(SQLBaseStore):
} }
) )
def get_e2e_device_keys(self, query_list): def get_e2e_device_keys(self, query_list, include_all_devices=False):
"""Fetch a list of device keys. """Fetch a list of device keys.
Args: Args:
query_list(list): List of pairs of user_ids and device_ids. query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
Returns: Returns:
Dict mapping from user-id to dict mapping from device_id to Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name". dict containing "key_json", "device_display_name".
@ -45,41 +45,42 @@ class EndToEndKeyStore(SQLBaseStore):
return {} return {}
return self.runInteraction( return self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list "get_e2e_device_keys", self._get_e2e_device_keys_txn,
query_list, include_all_devices,
) )
def _get_e2e_device_keys_txn(self, txn, query_list): def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
query_clauses = [] query_clauses = []
query_params = [] query_params = []
for (user_id, device_id) in query_list: for (user_id, device_id) in query_list:
query_clause = "k.user_id = ?" query_clause = "user_id = ?"
query_params.append(user_id) query_params.append(user_id)
if device_id: if device_id:
query_clause += " AND k.device_id = ?" query_clause += " AND device_id = ?"
query_params.append(device_id) query_params.append(device_id)
query_clauses.append(query_clause) query_clauses.append(query_clause)
sql = ( sql = (
"SELECT k.user_id, k.device_id, " "SELECT user_id, device_id, "
" d.display_name AS device_display_name, " " d.display_name AS device_display_name, "
" k.key_json" " k.key_json"
" FROM e2e_device_keys_json k" " FROM devices d"
" LEFT JOIN devices d ON d.user_id = k.user_id" " %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
" AND d.device_id = k.device_id"
" WHERE %s" " WHERE %s"
) % ( ) % (
"LEFT" if include_all_devices else "INNER",
" OR ".join("(" + q + ")" for q in query_clauses) " OR ".join("(" + q + ")" for q in query_clauses)
) )
txn.execute(sql, query_params) txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn) rows = self.cursor_to_dict(txn)
result = collections.defaultdict(dict) result = {}
for row in rows: for row in rows:
result[row["user_id"]][row["device_id"]] = row result.setdefault(row["user_id"], {})[row["device_id"]] = row
return result return result
@ -152,7 +153,7 @@ class EndToEndKeyStore(SQLBaseStore):
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys "claim_e2e_one_time_keys", _claim_e2e_one_time_keys
) )
@twisted.internet.defer.inlineCallbacks @defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id): def delete_e2e_keys_by_device(self, user_id, device_id):
yield self._simple_delete( yield self._simple_delete(
table="e2e_device_keys_json", table="e2e_device_keys_json",

View file

@ -0,0 +1,59 @@
/* Copyright 2017 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Cache of remote devices.
CREATE TABLE device_lists_remote_cache (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
-- The last update we got for a user. Empty if we're not receiving updates for
-- that user.
CREATE TABLE device_lists_remote_extremeties (
user_id TEXT NOT NULL,
stream_id TEXT NOT NULL
);
CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
-- Stream of device lists updates. Includes both local and remotes
CREATE TABLE device_lists_stream (
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL
);
CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id);
-- The stream of updates to send to other servers. We keep at least one row
-- per user that was sent so that the prev_id for any new updates can be
-- calculated
CREATE TABLE device_lists_outbound_pokes (
destination TEXT NOT NULL,
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
sent BOOLEAN NOT NULL,
ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers
);
CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id);
CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id);

View file

@ -44,6 +44,7 @@ class EventSources(object):
def get_current_token(self): def get_current_token(self):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -63,6 +64,7 @@ class EventSources(object):
), ),
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key,
) )
defer.returnValue(token) defer.returnValue(token)
@ -70,6 +72,7 @@ class EventSources(object):
def get_current_token_for_room(self, room_id): def get_current_token_for_room(self, room_id):
push_rules_key, _ = self.store.get_push_rules_stream_token() push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token() to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
token = StreamToken( token = StreamToken(
room_key=( room_key=(
@ -89,5 +92,6 @@ class EventSources(object):
), ),
push_rules_key=push_rules_key, push_rules_key=push_rules_key,
to_device_key=to_device_key, to_device_key=to_device_key,
device_list_key=device_list_key,
) )
defer.returnValue(token) defer.returnValue(token)

View file

@ -158,6 +158,7 @@ class StreamToken(
"account_data_key", "account_data_key",
"push_rules_key", "push_rules_key",
"to_device_key", "to_device_key",
"device_list_key",
)) ))
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
@ -195,6 +196,7 @@ class StreamToken(
or (int(other.account_data_key) < int(self.account_data_key)) or (int(other.account_data_key) < int(self.account_data_key))
or (int(other.push_rules_key) < int(self.push_rules_key)) or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key)) or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key))
) )
def copy_and_advance(self, key, new_value): def copy_and_advance(self, key, new_value):

View file

@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def setUp(self): def setUp(self):
hs = yield utils.setup_test_homeserver(handlers=None) hs = yield utils.setup_test_homeserver()
self.handler = synapse.handlers.device.DeviceHandler(hs) self.handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.clock = hs.get_clock() self.clock = hs.get_clock()
@defer.inlineCallbacks @defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self): def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered( res = yield self.handler.check_device_registered(
user_id="boris", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="display name" initial_device_display_name="display name"
) )
self.assertEqual(res, "fco") self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("boris", "fco") dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_device_is_preserved_if_exists(self): def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered( res1 = yield self.handler.check_device_registered(
user_id="boris", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="display name" initial_device_display_name="display name"
) )
self.assertEqual(res1, "fco") self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered( res2 = yield self.handler.check_device_registered(
user_id="boris", user_id="@boris:foo",
device_id="fco", device_id="fco",
initial_device_display_name="new display name" initial_device_display_name="new display name"
) )
self.assertEqual(res2, "fco") self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("boris", "fco") dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name") self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks @defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self): def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered( device_id = yield self.handler.check_device_registered(
user_id="theresa", user_id="@theresa:foo",
device_id=None, device_id=None,
initial_device_display_name="display" initial_device_display_name="display"
) )
dev = yield self.handler.store.get_device("theresa", device_id) dev = yield self.handler.store.get_device("@theresa:foo", device_id)
self.assertEqual(dev["display_name"], "display") self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks @defer.inlineCallbacks

View file

@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock(spec=[
"make_query", "make_query",
"register_edu_handler",
]) ])
self.query_handlers = {} self.query_handlers = {}

View file

@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.mock_federation = Mock(spec=[ self.mock_federation = Mock(spec=[
"make_query", "make_query",
"register_edu_handler",
]) ])
self.query_handlers = {} self.query_handlers = {}

View file

@ -75,6 +75,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"get_received_txn_response", "get_received_txn_response",
"set_received_txn_response", "set_received_txn_response",
"get_destination_retry_timings", "get_destination_retry_timings",
"get_devices_by_remote",
]), ]),
state_handler=self.state_handler, state_handler=self.state_handler,
handlers=None, handlers=None,
@ -99,6 +100,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
defer.succeed(retry_timings_res) defer.succeed(retry_timings_res)
) )
self.datastore.get_devices_by_remote.return_value = (0, [])
def get_received_txn_response(*args): def get_received_txn_response(*args):
return defer.succeed(None) return defer.succeed(None)
self.datastore.get_received_txn_response = get_received_txn_response self.datastore.get_received_txn_response = get_received_txn_response

View file

@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_topo_token_is_accepted(self): def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0" token = "t1-0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))
@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self): def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0" token = "s0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get( (code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" % "/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token)) (self.room_id, token))

View file

@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
event_cache_size=1, event_cache_size=1,
password_providers=[], password_providers=[],
) )
hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.as_token = "token1" self.as_token = "token1"
self.as_url = "some_url" self.as_url = "some_url"
@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
event_cache_size=1, event_cache_size=1,
password_providers=[], password_providers=[],
) )
hs = yield setup_test_homeserver(config=config, federation_sender=Mock()) hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.db_pool = hs.get_db_pool() self.db_pool = hs.get_db_pool()
self.as_list = [ self.as_list = [
@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock() federation_sender=Mock(),
replication_layer=Mock(),
) )
ApplicationServiceStore(hs) ApplicationServiceStore(hs)
@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock() federation_sender=Mock(),
replication_layer=Mock(),
) )
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:
@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver( hs = yield setup_test_homeserver(
config=config, config=config,
datastore=Mock(), datastore=Mock(),
federation_sender=Mock() federation_sender=Mock(),
replication_layer=Mock(),
) )
with self.assertRaises(ConfigError) as cm: with self.assertRaises(ConfigError) as cm:

View file

@ -35,6 +35,10 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
now = 1470174257070 now = 1470174257070
json = '{ "key": "value" }' json = '{ "key": "value" }'
yield self.store.store_device(
"user", "device", None
)
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(
"user", "device", now, json) "user", "device", now, json)
@ -71,6 +75,19 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def test_multiple_devices(self): def test_multiple_devices(self):
now = 1470174257070 now = 1470174257070
yield self.store.store_device(
"user1", "device1", None
)
yield self.store.store_device(
"user1", "device2", None
)
yield self.store.store_device(
"user2", "device1", None
)
yield self.store.store_device(
"user2", "device2", None
)
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(
"user1", "device1", now, 'json11') "user1", "device1", now, 'json11')
yield self.store.set_e2e_device_keys( yield self.store.set_e2e_device_keys(