diff --git a/synapse/storage/deviceinbox.py b/synapse/storage/deviceinbox.py index 68116b0394..57202a5bda 100644 --- a/synapse/storage/deviceinbox.py +++ b/synapse/storage/deviceinbox.py @@ -37,9 +37,21 @@ class DeviceInboxStore(SQLBaseStore): inserted. """ - def select_devices_txn(txn, user_id, devices): - if not devices: - return [] + with self._device_inbox_id_gen.get_next() as stream_id: + yield self.runInteraction( + "add_messages_to_device_inbox", + self._add_messages_to_device_inbox_txn, + stream_id, + messages_by_user_then_device, + ) + + defer.returnValue(self._device_inbox_id_gen.get_current_token()) + + def _add_messages_to_device_inbox_txn(self, txn, stream_id, + messages_by_user_then_device): + local_users_and_devices = set() + for user_id, messages_by_device in messages_by_user_then_device.items(): + devices = messages_by_device.keys() sql = ( "SELECT user_id, device_id FROM devices" " WHERE user_id = ? AND device_id IN (" @@ -48,41 +60,24 @@ class DeviceInboxStore(SQLBaseStore): ) # TODO: Maybe this needs to be done in batches if there are # too many local devices for a given user. - args = [user_id] + devices - txn.execute(sql, args) - return [tuple(row) for row in txn.fetchall()] + txn.execute(sql, [user_id] + devices) + local_users_and_devices.update(map(tuple, txn.fetchall())) - def add_messages_to_device_inbox_txn(txn, stream_id): - local_users_and_devices = set() - for user_id, messages_by_device in messages_by_user_then_device.items(): - local_users_and_devices.update( - select_devices_txn(txn, user_id, messages_by_device.keys()) - ) + sql = ( + "INSERT INTO device_inbox" + " (user_id, device_id, stream_id, message_json)" + " VALUES (?,?,?,?)" + ) + rows = [] + for user_id, messages_by_device in messages_by_user_then_device.items(): + for device_id, message in messages_by_device.items(): + message_json = ujson.dumps(message) + # Only insert into the local inbox if the device exists on + # this server + if (user_id, device_id) in local_users_and_devices: + rows.append((user_id, device_id, stream_id, message_json)) - sql = ( - "INSERT INTO device_inbox" - " (user_id, device_id, stream_id, message_json)" - " VALUES (?,?,?,?)" - ) - rows = [] - for user_id, messages_by_device in messages_by_user_then_device.items(): - for device_id, message in messages_by_device.items(): - message_json = ujson.dumps(message) - # Only insert into the local inbox if the device exists on - # this server - if (user_id, device_id) in local_users_and_devices: - rows.append((user_id, device_id, stream_id, message_json)) - - txn.executemany(sql, rows) - - with self._device_inbox_id_gen.get_next() as stream_id: - yield self.runInteraction( - "add_messages_to_device_inbox", - add_messages_to_device_inbox_txn, - stream_id - ) - - defer.returnValue(self._device_inbox_id_gen.get_current_token()) + txn.executemany(sql, rows) def get_new_messages_for_device( self, user_id, device_id, last_stream_id, current_stream_id, limit=100