diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py index 65c6673a87..d18f6b6cfd 100644 --- a/synapse/federation/transaction_queue.py +++ b/synapse/federation/transaction_queue.py @@ -306,62 +306,74 @@ class TransactionQueue(object): yield run_on_reactor() while True: - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_presence = self.pending_presence_by_dest.pop(destination, {}) - pending_failures = self.pending_failures_by_dest.pop(destination, []) + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_presence = self.pending_presence_by_dest.pop(destination, {}) + pending_failures = self.pending_failures_by_dest.pop(destination, []) - pending_edus.extend( - self.pending_edus_keyed_by_dest.pop(destination, {}).values() + pending_edus.extend( + self.pending_edus_keyed_by_dest.pop(destination, {}).values() + ) + + limiter = yield get_retry_limiter( + destination, + self.clock, + self.store, + ) + + device_message_edus, device_stream_id, dev_list_id = ( + yield self._get_new_device_messages(destination) + ) + + pending_edus.extend(device_message_edus) + if pending_presence: + pending_edus.append( + Edu( + origin=self.server_name, + destination=destination, + edu_type="m.presence", + content={ + "push": [ + format_user_presence_state( + presence, self.clock.time_msec() + ) + for presence in pending_presence.values() + ] + }, + ) ) - limiter = yield get_retry_limiter( - destination, - self.clock, - self.store, - ) + if pending_pdus: + logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) - device_message_edus, device_stream_id = ( - yield self._get_new_device_messages(destination) + if not pending_pdus and not pending_edus and not pending_failures: + logger.debug("TX [%s] Nothing to send", destination) + self.last_device_stream_id_by_dest[destination] = ( + device_stream_id ) + return - pending_edus.extend(device_message_edus) - if pending_presence: - pending_edus.append( - Edu( - origin=self.server_name, - destination=destination, - edu_type="m.presence", - content={ - "push": [ - format_user_presence_state( - presence, self.clock.time_msec() - ) - for presence in pending_presence.values() - ] - }, - ) + success = yield self._send_new_transaction( + destination, pending_pdus, pending_edus, pending_failures, + limiter=limiter, + ) + if success: + # Remove the acknowledged device messages from the database + # Only bother if we actually sent some device messages + if device_message_edus: + yield self.store.delete_device_msgs_for_remote( + destination, device_stream_id + ) + logger.info("Marking as sent %r %r", destination, dev_list_id) + yield self.store.mark_as_sent_devices_by_remote( + destination, dev_list_id ) - if pending_pdus: - logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) - - if not pending_pdus and not pending_edus and not pending_failures: - logger.debug("TX [%s] Nothing to send", destination) - self.last_device_stream_id_by_dest[destination] = ( - device_stream_id - ) - return - - success = yield self._send_new_transaction( - destination, pending_pdus, pending_edus, pending_failures, - device_stream_id, - includes_device_messages=bool(device_message_edus), - limiter=limiter, - ) - if not success: - break + self.last_device_stream_id_by_dest[destination] = device_stream_id + self.last_device_list_stream_id_by_dest[destination] = dev_list_id + else: + break except NotRetryingDestination: logger.debug( "TX [%s] not ready for retry yet - " @@ -374,8 +386,6 @@ class TransactionQueue(object): @defer.inlineCallbacks def _get_new_device_messages(self, destination): - # TODO: Send appropriate device list messages - last_device_stream_id = self.last_device_stream_id_by_dest.get(destination, 0) to_device_stream_id = self.store.get_to_device_stream_token() contents, stream_id = yield self.store.get_new_device_msgs_for_remote( @@ -404,13 +414,12 @@ class TransactionQueue(object): ) for content in results ) - defer.returnValue((edus, stream_id)) + defer.returnValue((edus, stream_id, now_stream_id)) @measure_func("_send_new_transaction") @defer.inlineCallbacks def _send_new_transaction(self, destination, pending_pdus, pending_edus, - pending_failures, device_stream_id, - includes_device_messages, limiter): + pending_failures, limiter): # Sort based on the order field pending_pdus.sort(key=lambda t: t[1]) @@ -521,14 +530,6 @@ class TransactionQueue(object): "Failed to send event %s to %s", p.event_id, destination ) success = False - else: - # Remove the acknowledged device messages from the database - # Only bother if we actually sent some device messages - if includes_device_messages: - yield self.store.delete_device_msgs_for_remote( - destination, device_stream_id - ) - self.last_device_stream_id_by_dest[destination] = device_stream_id except RuntimeError as e: # We capture this here as there as nothing actually listens # for this finishing functions deferred. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index d92780b642..ba4c48d590 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -29,6 +29,7 @@ class DeviceHandler(BaseHandler): super(DeviceHandler, self).__init__(hs) self.state = hs.get_state_handler() + self.federation = hs.get_federation_sender() @defer.inlineCallbacks def check_device_registered(self, user_id, device_id, diff --git a/synapse/storage/devices.py b/synapse/storage/devices.py index b594f501f9..9628e2ff75 100644 --- a/synapse/storage/devices.py +++ b/synapse/storage/devices.py @@ -141,11 +141,11 @@ class DeviceStore(SQLBaseStore): def get_devices_by_remote(self, destination, from_stream_id): now_stream_id = self._device_list_id_gen.get_current_token() - has_changed = self._device_list_stream_cache.has_entity_changed( + has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) ) if not has_changed: - defer.returnValue((now_stream_id, [])) + return (now_stream_id, []) return self.runInteraction( "get_devices_by_remote", self._get_devices_by_remote_txn, @@ -165,7 +165,7 @@ class DeviceStore(SQLBaseStore): rows = txn.fetchall() if not rows: - return now_stream_id, [] + return (now_stream_id, []) # maps (user_id, device_id) -> stream_id query_map = {(r[0], r[1]): r[2] for r in rows} @@ -189,7 +189,7 @@ class DeviceStore(SQLBaseStore): result = { "user_id": user_id, "device_id": device_id, - "prev_id": prev_id, + "prev_id": [prev_id] if prev_id else [], "stream_id": stream_id, } @@ -202,9 +202,9 @@ class DeviceStore(SQLBaseStore): if device_display_name: result["device_display_name"] = device_display_name - results.setdefault(user_id, {})[device_id] = result + results.append(result) - return now_stream_id, results + return (now_stream_id, results) def mark_as_sent_devices_by_remote(self, destination, stream_id): return self.runInteraction( @@ -212,19 +212,6 @@ class DeviceStore(SQLBaseStore): destination, stream_id, ) - @defer.inlineCallbacks - def get_user_whose_devices_changed(self, from_key): - from_key = int(from_key) - changed = self._device_list_stream_cache.get_all_entities_changed(from_key) - if changed is not None: - defer.returnValue(set(changed)) - - sql = """ - SELECT user_id FROM device_lists_stream WHERE stream_id > ? - """ - rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) - defer.returnValue(set(row["user_id"] for row in rows)) - def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id): sql = """ DELETE FROM device_lists_outbound_pokes @@ -239,7 +226,20 @@ class DeviceStore(SQLBaseStore): UPDATE device_lists_outbound_pokes SET sent = ? WHERE destination = ? AND stream_id <= ? """ - txn.execute(sql, (destination, True,)) + txn.execute(sql, (True, destination, stream_id,)) + + @defer.inlineCallbacks + def get_user_whose_devices_changed(self, from_key): + from_key = int(from_key) + changed = self._device_list_stream_cache.get_all_entities_changed(from_key) + if changed is not None: + defer.returnValue(set(changed)) + + sql = """ + SELECT user_id FROM device_lists_stream WHERE stream_id > ? + """ + rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key) + defer.returnValue(set(row["user_id"] for row in rows)) @defer.inlineCallbacks def add_device_change_to_streams(self, user_id, device_id, hosts):