Make auth & transactions more testable (#3499)

This commit is contained in:
Amber Brown 2018-07-14 07:34:49 +10:00 committed by GitHub
parent 2aba1f549c
commit 33b60c01b5
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 96 additions and 97 deletions

0
changelog.d/3499.misc Normal file
View file

View file

@ -193,7 +193,7 @@ class Auth(object):
synapse.types.create_requester(user_id, app_service=app_service)
)
access_token = get_access_token_from_request(
access_token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
@ -239,7 +239,7 @@ class Auth(object):
@defer.inlineCallbacks
def _get_appservice_user_id(self, request):
app_service = self.store.get_app_service_by_token(
get_access_token_from_request(
self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
)
@ -513,7 +513,7 @@ class Auth(object):
def get_appservice_by_req(self, request):
try:
token = get_access_token_from_request(
token = self.get_access_token_from_request(
request, self.TOKEN_NOT_FOUND_HTTP_STATUS
)
service = self.store.get_app_service_by_token(token)
@ -673,8 +673,8 @@ class Auth(object):
" edit its room list entry"
)
def has_access_token(request):
@staticmethod
def has_access_token(request):
"""Checks if the request has an access_token.
Returns:
@ -684,8 +684,8 @@ def has_access_token(request):
auth_headers = request.requestHeaders.getRawHeaders(b"Authorization")
return bool(query_params) or bool(auth_headers)
def get_access_token_from_request(request, token_not_found_http_status=401):
@staticmethod
def get_access_token_from_request(request, token_not_found_http_status=401):
"""Extracts the access_token from the request.
Args:

View file

@ -17,14 +17,28 @@
to ensure idempotency when performing PUTs using the REST API."""
import logging
from synapse.api.auth import get_access_token_from_request
from synapse.util.async import ObservableDeferred
from synapse.util.logcontext import make_deferred_yieldable, run_in_background
logger = logging.getLogger(__name__)
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
def get_transaction_key(request):
class HttpTransactionCache(object):
def __init__(self, hs):
self.hs = hs
self.auth = self.hs.get_auth()
self.clock = self.hs.get_clock()
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def _get_transaction_key(self, request):
"""A helper function which returns a transaction key that can be used
with TransactionCache for idempotent requests.
@ -38,24 +52,9 @@ def get_transaction_key(request):
Returns:
str: A transaction key
"""
token = get_access_token_from_request(request)
token = self.auth.get_access_token_from_request(request)
return request.path + "/" + token
CLEANUP_PERIOD_MS = 1000 * 60 * 30 # 30 mins
class HttpTransactionCache(object):
def __init__(self, clock):
self.clock = clock
self.transactions = {
# $txn_key: (ObservableDeferred<(res_code, res_json_body)>, timestamp)
}
# Try to clean entries every 30 mins. This means entries will exist
# for at *LEAST* 30 mins, and at *MOST* 60 mins.
self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS)
def fetch_or_execute_request(self, request, fn, *args, **kwargs):
"""A helper function for fetch_or_execute which extracts
a transaction key from the given request.
@ -64,7 +63,7 @@ class HttpTransactionCache(object):
fetch_or_execute
"""
return self.fetch_or_execute(
get_transaction_key(request), fn, *args, **kwargs
self._get_transaction_key(request), fn, *args, **kwargs
)
def fetch_or_execute(self, txn_key, fn, *args, **kwargs):

View file

@ -62,4 +62,4 @@ class ClientV1RestServlet(RestServlet):
self.hs = hs
self.builder_factory = hs.get_event_builder_factory()
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
self.txns = HttpTransactionCache(hs)

View file

@ -17,7 +17,6 @@ import logging
from twisted.internet import defer
from synapse.api.auth import get_access_token_from_request
from synapse.api.errors import AuthError
from .base import ClientV1RestServlet, client_path_patterns
@ -51,7 +50,7 @@ class LogoutRestServlet(ClientV1RestServlet):
if requester.device_id is None:
# the acccess token wasn't associated with a device.
# Just delete the access token
access_token = get_access_token_from_request(request)
access_token = self._auth.get_access_token_from_request(request)
yield self._auth_handler.delete_access_token(access_token)
else:
yield self._device_handler.delete_device(

View file

@ -23,7 +23,6 @@ from six import string_types
from twisted.internet import defer
import synapse.util.stringutils as stringutils
from synapse.api.auth import get_access_token_from_request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import parse_json_object_from_request
@ -67,6 +66,7 @@ class RegisterRestServlet(ClientV1RestServlet):
# TODO: persistent storage
self.sessions = {}
self.enable_registration = hs.config.enable_registration
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self.handlers = hs.get_handlers()
@ -310,7 +310,7 @@ class RegisterRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def _do_app_service(self, request, register_json, session):
as_token = get_access_token_from_request(request)
as_token = self.auth.get_access_token_from_request(request)
if "user" not in register_json:
raise SynapseError(400, "Expected 'user' key.")
@ -400,7 +400,7 @@ class CreateUserRestServlet(ClientV1RestServlet):
def on_POST(self, request):
user_json = parse_json_object_from_request(request)
access_token = get_access_token_from_request(request)
access_token = self.auth.get_access_token_from_request(request)
app_service = self.store.get_app_service_by_token(
access_token
)

View file

@ -20,7 +20,6 @@ from six.moves import http_client
from twisted.internet import defer
from synapse.api.auth import has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError
from synapse.http.servlet import (
@ -130,7 +129,7 @@ class PasswordRestServlet(RestServlet):
#
# In the second case, we require a password to confirm their identity.
if has_access_token(request):
if self.auth.has_access_token(request):
requester = yield self.auth.get_user_by_req(request)
params = yield self.auth_handler.validate_user_via_ui_auth(
requester, body, self.hs.get_ip_from_request(request),

View file

@ -24,7 +24,6 @@ from twisted.internet import defer
import synapse
import synapse.types
from synapse.api.auth import get_access_token_from_request, has_access_token
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, SynapseError, UnrecognizedRequestError
from synapse.http.servlet import (
@ -224,7 +223,7 @@ class RegisterRestServlet(RestServlet):
desired_username = body['username']
appservice = None
if has_access_token(request):
if self.auth.has_access_token(request):
appservice = yield self.auth.get_appservice_by_req(request)
# fork off as soon as possible for ASes and shared secret auth which
@ -242,7 +241,7 @@ class RegisterRestServlet(RestServlet):
# because the IRC bridges rely on being able to register stupid
# IDs.
access_token = get_access_token_from_request(request)
access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, string_types):
result = yield self._do_appservice_registration(

View file

@ -40,7 +40,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
super(SendToDeviceRestServlet, self).__init__()
self.hs = hs
self.auth = hs.get_auth()
self.txns = HttpTransactionCache(hs.get_clock())
self.txns = HttpTransactionCache(hs)
self.device_message_handler = hs.get_device_message_handler()
def on_PUT(self, request, message_type, txn_id):

View file

@ -14,7 +14,10 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self):
self.clock = MockClock()
self.cache = HttpTransactionCache(self.clock)
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo"