Merge branch 'develop' into auth

This commit is contained in:
Daniel Wagner-Hall 2015-08-19 15:20:09 +01:00
commit f9e7493ac2
9 changed files with 128 additions and 25 deletions

View file

@ -352,6 +352,8 @@ class Auth(object):
if not user_id: if not user_id:
raise KeyError raise KeyError
request.authenticated_entity = user_id
defer.returnValue( defer.returnValue(
(UserID.from_string(user_id), ClientInfo("", "")) (UserID.from_string(user_id), ClientInfo("", ""))
) )
@ -425,6 +427,7 @@ class Auth(object):
"Unrecognised access token.", "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN errcode=Codes.UNKNOWN_TOKEN
) )
request.authenticated_entity = service.sender
defer.returnValue(service) defer.returnValue(service)
except KeyError: except KeyError:
raise AuthError( raise AuthError(

View file

@ -283,14 +283,15 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem storing the token. StoreError if there was a problem storing the token.
LoginError if there was an authentication problem. LoginError if there was an authentication problem.
""" """
self._check_password(user_id, password) yield self._check_password(user_id, password)
reg_handler = self.hs.get_handlers().registration_handler reg_handler = self.hs.get_handlers().registration_handler
access_token = reg_handler.generate_token(user_id) access_token = reg_handler.generate_token(user_id)
logger.info("Adding token %s for user %s", access_token, user_id) logger.info("Logging in user %s", user_id)
yield self.store.add_access_token_to_user(user_id, access_token) yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue(access_token) defer.returnValue(access_token)
@defer.inlineCallbacks
def _check_password(self, user_id, password): def _check_password(self, user_id, password):
"""Checks that user_id has passed password, raises LoginError if not.""" """Checks that user_id has passed password, raises LoginError if not."""
user_info = yield self.store.get_user_by_id(user_id=user_id) user_info = yield self.store.get_user_by_id(user_id=user_id)

View file

@ -171,7 +171,6 @@ class ReceiptEventSource(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit): def get_new_events_for_user(self, user, from_key, limit):
defer.returnValue(([], from_key))
from_key = int(from_key) from_key = int(from_key)
to_key = yield self.get_current_key() to_key = yield self.get_current_key()
@ -194,7 +193,6 @@ class ReceiptEventSource(object):
@defer.inlineCallbacks @defer.inlineCallbacks
def get_pagination_rows(self, user, config, key): def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key) to_key = int(config.from_key)
defer.returnValue(([], to_key))
if config.to_key: if config.to_key:
from_key = int(config.to_key) from_key = int(config.to_key)

View file

@ -158,18 +158,40 @@ def runUntilCurrentTimer(func):
@functools.wraps(func) @functools.wraps(func)
def f(*args, **kwargs): def f(*args, **kwargs):
pending_calls = len(reactor.getDelayedCalls()) now = reactor.seconds()
num_pending = 0
# _newTimedCalls is one long list of *all* pending calls. Below loop
# is based off of impl of reactor.runUntilCurrent
for delayed_call in reactor._newTimedCalls:
if delayed_call.time > now:
break
if delayed_call.delayed_time > 0:
continue
num_pending += 1
num_pending += len(reactor.threadCallQueue)
start = time.time() * 1000 start = time.time() * 1000
ret = func(*args, **kwargs) ret = func(*args, **kwargs)
end = time.time() * 1000 end = time.time() * 1000
tick_time.inc_by(end - start) tick_time.inc_by(end - start)
pending_calls_metric.inc_by(pending_calls) pending_calls_metric.inc_by(num_pending)
return ret return ret
return f return f
if hasattr(reactor, "runUntilCurrent"): try:
# Ensure the reactor has all the attributes we expect
reactor.runUntilCurrent
reactor._newTimedCalls
reactor.threadCallQueue
# runUntilCurrent is called when we have pending calls. It is called once # runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling. # per iteratation after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent) reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
except AttributeError:
pass

View file

@ -54,7 +54,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database # Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts. # schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 21 SCHEMA_VERSION = 22
dir_path = os.path.abspath(os.path.dirname(__file__)) dir_path = os.path.abspath(os.path.dirname(__file__))

View file

@ -39,7 +39,7 @@ class PresenceStore(SQLBaseStore):
desc="has_presence_state", desc="has_presence_state",
) )
@cached() @cached(max_entries=2000)
def get_presence_state(self, user_localpart): def get_presence_state(self, user_localpart):
return self._simple_select_one( return self._simple_select_one(
table="presence", table="presence",

View file

@ -14,12 +14,11 @@
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cachedInlineCallbacks from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
from synapse.util.caches import cache_counter, caches_by_name
from twisted.internet import defer from twisted.internet import defer
from synapse.util import unwrapFirstError
from blist import sorteddict from blist import sorteddict
import logging import logging
import ujson as json import ujson as json
@ -54,19 +53,13 @@ class ReceiptsStore(SQLBaseStore):
self, room_ids, from_key self, room_ids, from_key
) )
results = yield defer.gatherResults( results = yield self._get_linearized_receipts_for_rooms(
[ room_ids, to_key, from_key=from_key
self.get_linearized_receipts_for_room( )
room_id, to_key, from_key=from_key
)
for room_id in room_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue([ev for res in results for ev in res]) defer.returnValue([ev for res in results.values() for ev in res])
@defer.inlineCallbacks @cachedInlineCallbacks(num_args=3, max_entries=5000)
def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None): def get_linearized_receipts_for_room(self, room_id, to_key, from_key=None):
"""Get receipts for a single room for sending to clients. """Get receipts for a single room for sending to clients.
@ -126,6 +119,66 @@ class ReceiptsStore(SQLBaseStore):
"content": content, "content": content,
}]) }])
@cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
defer.returnValue({})
def f(txn):
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
) % (
",".join(["?"] * len(room_ids))
)
args = list(room_ids)
args.extend([from_key, to_key])
txn.execute(sql, args)
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id <= ?"
) % (
",".join(["?"] * len(room_ids))
)
args = list(room_ids)
args.append(to_key)
txn.execute(sql, args)
return self.cursor_to_dict(txn)
txn_results = yield self.runInteraction(
"_get_linearized_receipts_for_rooms", f
)
results = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
# receipts by room, event and type.
room_event = results.setdefault(row["room_id"], {
"type": "m.receipt",
"room_id": row["room_id"],
"content": {},
})
# The content is of the form:
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = json.loads(row["data"])
results = {
room_id: [results[room_id]] if room_id in results else []
for room_id in room_ids
}
defer.returnValue(results)
def get_max_receipt_stream_id(self): def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token(self) return self._receipts_id_gen.get_max_token(self)
@ -305,6 +358,8 @@ class _RoomStreamChangeCache(object):
self._room_to_key = {} self._room_to_key = {}
self._cache = sorteddict() self._cache = sorteddict()
self._earliest_key = None self._earliest_key = None
self.name = "ReceiptsRoomChangeCache"
caches_by_name[self.name] = self._cache
@defer.inlineCallbacks @defer.inlineCallbacks
def get_rooms_changed(self, store, room_ids, key): def get_rooms_changed(self, store, room_ids, key):
@ -318,8 +373,11 @@ class _RoomStreamChangeCache(object):
result = set( result = set(
self._cache[k] for k in keys[i:] self._cache[k] for k in keys[i:]
).intersection(room_ids) ).intersection(room_ids)
cache_counter.inc_hits(self.name)
else: else:
result = room_ids result = room_ids
cache_counter.inc_misses(self.name)
defer.returnValue(result) defer.returnValue(result)

View file

@ -0,0 +1,18 @@
/* Copyright 2015 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX receipts_linearized_room_stream ON receipts_linearized(
room_id, stream_id
);

View file

@ -398,6 +398,7 @@ class StateStore(SQLBaseStore):
# for them again. # for them again.
state_dict = {key: None for key in types} state_dict = {key: None for key in types}
state_dict.update(results[group]) state_dict.update(results[group])
results[group] = state_dict
else: else:
state_dict = results[group] state_dict = results[group]
@ -412,9 +413,11 @@ class StateStore(SQLBaseStore):
full=(types is None), full=(types is None),
) )
# We replace here to remove all the entries with None values. # Remove all the entries with None values. The None values were just
# used for bookkeeping in the cache.
for group, state_dict in results.items():
results[group] = { results[group] = {
key: value for key, value in state_dict.items() if value key: event for key, event in state_dict.items() if event
} }
defer.returnValue(results) defer.returnValue(results)