Don't bother with a timeout for one time keys on the server.

This commit is contained in:
Mark Haines 2015-07-09 14:04:03 +01:00
parent 8fb79eeea4
commit bf0d59ed30
3 changed files with 13 additions and 33 deletions

View file

@ -50,7 +50,6 @@ class KeyUploadServlet(RestServlet):
"one_time_keys": { "one_time_keys": {
"<algorithm>:<key_id>": "<key_base64>" "<algorithm>:<key_id>": "<key_base64>"
}, },
"one_time_keys_valid_for": <millisecond duration>,
} }
""" """
PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)") PATTERN = client_v2_pattern("/keys/upload/(?P<device_id>[^/]*)")
@ -87,13 +86,10 @@ class KeyUploadServlet(RestServlet):
) )
one_time_keys = body.get("one_time_keys", None) one_time_keys = body.get("one_time_keys", None)
one_time_keys_valid_for = body.get("one_time_keys_valid_for", None)
if one_time_keys: if one_time_keys:
valid_until = int(one_time_keys_valid_for) + time_now
logger.info( logger.info(
"Adding %d one_time_keys for device %r for user %r at %d" "Adding %d one_time_keys for device %r for user %r at %d",
" valid_until %d", len(one_time_keys), device_id, user_id, time_now
len(one_time_keys), device_id, user_id, time_now, valid_until
) )
key_list = [] key_list = []
for key_id, key_json in one_time_keys.items(): for key_id, key_json in one_time_keys.items():
@ -103,23 +99,18 @@ class KeyUploadServlet(RestServlet):
)) ))
yield self.store.add_e2e_one_time_keys( yield self.store.add_e2e_one_time_keys(
user_id, device_id, time_now, valid_until, key_list user_id, device_id, time_now, key_list
) )
result = yield self.store.count_e2e_one_time_keys( result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
user_id, device_id, time_now
)
defer.returnValue((200, {"one_time_key_counts": result})) defer.returnValue((200, {"one_time_key_counts": result}))
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, device_id): def on_GET(self, request, device_id):
auth_user, client_info = yield self.auth.get_user_by_req(request) auth_user, client_info = yield self.auth.get_user_by_req(request)
user_id = auth_user.to_string() user_id = auth_user.to_string()
time_now = self.clock.time_msec()
result = yield self.store.count_e2e_one_time_keys( result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
user_id, device_id, time_now
)
defer.returnValue((200, {"one_time_key_counts": result})) defer.returnValue((200, {"one_time_key_counts": result}))
@ -249,9 +240,8 @@ class OneTimeKeyServlet(RestServlet):
@defer.inlineCallbacks @defer.inlineCallbacks
def on_GET(self, request, user_id, device_id, algorithm): def on_GET(self, request, user_id, device_id, algorithm):
yield self.auth.get_user_by_req(request) yield self.auth.get_user_by_req(request)
time_now = self.clock.time_msec()
results = yield self.store.take_e2e_one_time_keys( results = yield self.store.take_e2e_one_time_keys(
[(user_id, device_id, algorithm)], time_now [(user_id, device_id, algorithm)]
) )
defer.returnValue(self.json_result(request, results)) defer.returnValue(self.json_result(request, results))
@ -266,8 +256,7 @@ class OneTimeKeyServlet(RestServlet):
for user_id, device_keys in body.get("one_time_keys", {}).items(): for user_id, device_keys in body.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items(): for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm)) query.append((user_id, device_id, algorithm))
time_now = self.clock.time_msec() results = yield self.store.take_e2e_one_time_keys(query)
results = yield self.store.take_e2e_one_time_keys(query, time_now)
defer.returnValue(self.json_result(request, results)) defer.returnValue(self.json_result(request, results))
def json_result(self, request, results): def json_result(self, request, results):

View file

@ -55,14 +55,8 @@ class EndToEndKeyStore(SQLBaseStore):
return result return result
return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys)
def add_e2e_one_time_keys(self, user_id, device_id, time_now, valid_until, def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
key_list):
def _add_e2e_one_time_keys(txn): def _add_e2e_one_time_keys(txn):
sql = (
"DELETE FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND valid_until_ms < ?"
)
txn.execute(sql, (user_id, device_id, time_now))
for (algorithm, key_id, json_bytes) in key_list: for (algorithm, key_id, json_bytes) in key_list:
self._simple_upsert_txn( self._simple_upsert_txn(
txn, table="e2e_one_time_keys_json", txn, table="e2e_one_time_keys_json",
@ -74,7 +68,6 @@ class EndToEndKeyStore(SQLBaseStore):
}, },
values={ values={
"ts_added_ms": time_now, "ts_added_ms": time_now,
"valid_until_ms": valid_until,
"key_json": json_bytes, "key_json": json_bytes,
} }
) )
@ -82,7 +75,7 @@ class EndToEndKeyStore(SQLBaseStore):
"add_e2e_one_time_keys", _add_e2e_one_time_keys "add_e2e_one_time_keys", _add_e2e_one_time_keys
) )
def count_e2e_one_time_keys(self, user_id, device_id, time_now): def count_e2e_one_time_keys(self, user_id, device_id):
""" Count the number of one time keys the server has for a device """ Count the number of one time keys the server has for a device
Returns: Returns:
Dict mapping from algorithm to number of keys for that algorithm. Dict mapping from algorithm to number of keys for that algorithm.
@ -90,10 +83,10 @@ class EndToEndKeyStore(SQLBaseStore):
def _count_e2e_one_time_keys(txn): def _count_e2e_one_time_keys(txn):
sql = ( sql = (
"SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json" "SELECT algorithm, COUNT(key_id) FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND valid_until_ms >= ?" " WHERE user_id = ? AND device_id = ?"
" GROUP BY algorithm" " GROUP BY algorithm"
) )
txn.execute(sql, (user_id, device_id, time_now)) txn.execute(sql, (user_id, device_id))
result = {} result = {}
for algorithm, key_count in txn.fetchall(): for algorithm, key_count in txn.fetchall():
result[algorithm] = key_count result[algorithm] = key_count
@ -102,13 +95,12 @@ class EndToEndKeyStore(SQLBaseStore):
"count_e2e_one_time_keys", _count_e2e_one_time_keys "count_e2e_one_time_keys", _count_e2e_one_time_keys
) )
def take_e2e_one_time_keys(self, query_list, time_now): def take_e2e_one_time_keys(self, query_list):
"""Take a list of one time keys out of the database""" """Take a list of one time keys out of the database"""
def _take_e2e_one_time_keys(txn): def _take_e2e_one_time_keys(txn):
sql = ( sql = (
"SELECT key_id, key_json FROM e2e_one_time_keys_json" "SELECT key_id, key_json FROM e2e_one_time_keys_json"
" WHERE user_id = ? AND device_id = ? AND algorithm = ?" " WHERE user_id = ? AND device_id = ? AND algorithm = ?"
" AND valid_until_ms > ?"
" LIMIT 1" " LIMIT 1"
) )
result = {} result = {}
@ -116,7 +108,7 @@ class EndToEndKeyStore(SQLBaseStore):
for user_id, device_id, algorithm in query_list: for user_id, device_id, algorithm in query_list:
user_result = result.setdefault(user_id, {}) user_result = result.setdefault(user_id, {})
device_result = user_result.setdefault(device_id, {}) device_result = user_result.setdefault(device_id, {})
txn.execute(sql, (user_id, device_id, algorithm, time_now)) txn.execute(sql, (user_id, device_id, algorithm))
for key_id, key_json in txn.fetchall(): for key_id, key_json in txn.fetchall():
device_result[algorithm + ":" + key_id] = key_json device_result[algorithm + ":" + key_id] = key_json
delete.append((user_id, device_id, algorithm, key_id)) delete.append((user_id, device_id, algorithm, key_id))

View file

@ -29,7 +29,6 @@ CREATE TABLE IF NOT EXISTS e2e_one_time_keys_json (
algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for. algorithm TEXT NOT NULL, -- Which algorithm this one-time key is for.
key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads. key_id TEXT NOT NULL, -- An id for suppressing duplicate uploads.
ts_added_ms BIGINT NOT NULL, -- When this key was uploaded. ts_added_ms BIGINT NOT NULL, -- When this key was uploaded.
valid_until_ms BIGINT NOT NULL, -- When this key is valid until.
key_json TEXT NOT NULL, -- The key as a JSON blob. key_json TEXT NOT NULL, -- The key as a JSON blob.
CONSTRAINT uniqueness UNIQUE (user_id, device_id, algorithm, key_id) CONSTRAINT uniqueness UNIQUE (user_id, device_id, algorithm, key_id)
); );