synapse.api.auth.Auth cleanup: make permission-related methods use Requester instead of the UserID (#13024)

Part of #13019

This changes all the permission-related methods to rely on the Requester instead of the UserID. This is a first step towards enabling scoped access tokens at some point, since I expect the Requester to have scope-related informations in it.

It also changes methods which figure out the user/device/appservice out of the access token to return a Requester instead of something else. This avoids having store-related objects in the methods signatures.
This commit is contained in:
Quentin Gliech 2022-08-22 15:17:59 +02:00 committed by GitHub
parent 94375f7a91
commit 3dd175b628
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 202 additions and 207 deletions

1
changelog.d/13024.misc Normal file
View file

@ -0,0 +1 @@
Refactor methods in `synapse.api.auth.Auth` to use `Requester` objects everywhere instead of user IDs.

View file

@ -37,8 +37,7 @@ from synapse.logging.opentracing import (
start_active_span, start_active_span,
trace, trace,
) )
from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import Requester, create_requester
from synapse.types import Requester, UserID, create_requester
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -70,14 +69,14 @@ class Auth:
async def check_user_in_room( async def check_user_in_room(
self, self,
room_id: str, room_id: str,
user_id: str, requester: Requester,
allow_departed_users: bool = False, allow_departed_users: bool = False,
) -> Tuple[str, Optional[str]]: ) -> Tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point. """Check if the user is in the room, or was at some point.
Args: Args:
room_id: The room to check. room_id: The room to check.
user_id: The user to check. requester: The user making the request, according to the access token.
current_state: Optional map of the current state of the room. current_state: Optional map of the current state of the room.
If provided then that map is used to check whether they are a If provided then that map is used to check whether they are a
@ -94,6 +93,7 @@ class Auth:
membership event ID of the user. membership event ID of the user.
""" """
user_id = requester.user.to_string()
( (
membership, membership,
member_event_id, member_event_id,
@ -182,96 +182,69 @@ class Auth:
access_token = self.get_access_token_from_request(request) access_token = self.get_access_token_from_request(request)
( # First check if it could be a request from an appservice
user_id, requester = await self._get_appservice_user(request)
device_id, if not requester:
app_service, # If not, it should be from a regular user
) = await self._get_appservice_user_id_and_device_id(request) requester = await self.get_user_by_access_token(
if user_id and app_service: access_token, allow_expired=allow_expired
if ip_addr and self._track_appservice_user_ips:
await self.store.insert_client_ip(
user_id=user_id,
access_token=access_token,
ip=ip_addr,
user_agent=user_agent,
device_id="dummy-device"
if device_id is None
else device_id, # stubbed
)
requester = create_requester(
user_id, app_service=app_service, device_id=device_id
) )
request.requester = user_id # Deny the request if the user account has expired.
return requester # This check is only done for regular users, not appservice ones.
if not allow_expired:
if await self._account_validity_handler.is_user_expired(
requester.user.to_string()
):
# Raise the error if either an account validity module has determined
# the account has expired, or the legacy account validity
# implementation is enabled and determined the account has expired
raise AuthError(
403,
"User account has expired",
errcode=Codes.EXPIRED_ACCOUNT,
)
user_info = await self.get_user_by_access_token( if ip_addr and (
access_token, allow_expired=allow_expired not requester.app_service or self._track_appservice_user_ips
) ):
token_id = user_info.token_id # XXX(quenting): I'm 95% confident that we could skip setting the
is_guest = user_info.is_guest # device_id to "dummy-device" for appservices, and that the only impact
shadow_banned = user_info.shadow_banned # would be some rows which whould not deduplicate in the 'user_ips'
# table during the transition
# Deny the request if the user account has expired. recorded_device_id = (
if not allow_expired: "dummy-device"
if await self._account_validity_handler.is_user_expired( if requester.device_id is None and requester.app_service is not None
user_info.user_id else requester.device_id
): )
# Raise the error if either an account validity module has determined
# the account has expired, or the legacy account validity
# implementation is enabled and determined the account has expired
raise AuthError(
403,
"User account has expired",
errcode=Codes.EXPIRED_ACCOUNT,
)
device_id = user_info.device_id
if access_token and ip_addr:
await self.store.insert_client_ip( await self.store.insert_client_ip(
user_id=user_info.token_owner, user_id=requester.authenticated_entity,
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
user_agent=user_agent, user_agent=user_agent,
device_id=device_id, device_id=recorded_device_id,
) )
# Track also the puppeted user client IP if enabled and the user is puppeting # Track also the puppeted user client IP if enabled and the user is puppeting
if ( if (
user_info.user_id != user_info.token_owner requester.user.to_string() != requester.authenticated_entity
and self._track_puppeted_user_ips and self._track_puppeted_user_ips
): ):
await self.store.insert_client_ip( await self.store.insert_client_ip(
user_id=user_info.user_id, user_id=requester.user.to_string(),
access_token=access_token, access_token=access_token,
ip=ip_addr, ip=ip_addr,
user_agent=user_agent, user_agent=user_agent,
device_id=device_id, device_id=requester.device_id,
) )
if is_guest and not allow_guest: if requester.is_guest and not allow_guest:
raise AuthError( raise AuthError(
403, 403,
"Guest access not allowed", "Guest access not allowed",
errcode=Codes.GUEST_ACCESS_FORBIDDEN, errcode=Codes.GUEST_ACCESS_FORBIDDEN,
) )
# Mark the token as used. This is used to invalidate old refresh
# tokens after some time.
if not user_info.token_used and token_id is not None:
await self.store.mark_access_token_as_used(token_id)
requester = create_requester(
user_info.user_id,
token_id,
is_guest,
shadow_banned,
device_id,
app_service=app_service,
authenticated_entity=user_info.token_owner,
)
request.requester = requester request.requester = requester
return requester return requester
except KeyError: except KeyError:
@ -308,9 +281,7 @@ class Auth:
403, "Application service has not registered this user (%s)" % user_id 403, "Application service has not registered this user (%s)" % user_id
) )
async def _get_appservice_user_id_and_device_id( async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
self, request: Request
) -> Tuple[Optional[str], Optional[str], Optional[ApplicationService]]:
""" """
Given a request, reads the request parameters to determine: Given a request, reads the request parameters to determine:
- whether it's an application service that's making this request - whether it's an application service that's making this request
@ -325,15 +296,13 @@ class Auth:
Must use `org.matrix.msc3202.device_id` in place of `device_id` for now. Must use `org.matrix.msc3202.device_id` in place of `device_id` for now.
Returns: Returns:
3-tuple of the application service `Requester` of that request
(user ID?, device ID?, application service?)
Postconditions: Postconditions:
- If an application service is returned, so is a user ID - The `app_service` field in the returned `Requester` is set
- A user ID is never returned without an application service - The `user_id` field in the returned `Requester` is either the application
- A device ID is never returned without a user ID or an application service service sender or the controlled user set by the `user_id` URI parameter
- The returned application service, if present, is permitted to control the - The returned application service is permitted to control the returned user ID.
returned user ID.
- The returned device ID, if present, has been checked to be a valid device ID - The returned device ID, if present, has been checked to be a valid device ID
for the returned user ID. for the returned user ID.
""" """
@ -343,12 +312,12 @@ class Auth:
self.get_access_token_from_request(request) self.get_access_token_from_request(request)
) )
if app_service is None: if app_service is None:
return None, None, None return None
if app_service.ip_range_whitelist: if app_service.ip_range_whitelist:
ip_address = IPAddress(request.getClientAddress().host) ip_address = IPAddress(request.getClientAddress().host)
if ip_address not in app_service.ip_range_whitelist: if ip_address not in app_service.ip_range_whitelist:
return None, None, None return None
# This will always be set by the time Twisted calls us. # This will always be set by the time Twisted calls us.
assert request.args is not None assert request.args is not None
@ -382,13 +351,15 @@ class Auth:
Codes.EXCLUSIVE, Codes.EXCLUSIVE,
) )
return effective_user_id, effective_device_id, app_service return create_requester(
effective_user_id, app_service=app_service, device_id=effective_device_id
)
async def get_user_by_access_token( async def get_user_by_access_token(
self, self,
token: str, token: str,
allow_expired: bool = False, allow_expired: bool = False,
) -> TokenLookupResult: ) -> Requester:
"""Validate access token and get user_id from it """Validate access token and get user_id from it
Args: Args:
@ -405,9 +376,9 @@ class Auth:
# First look in the database to see if the access token is present # First look in the database to see if the access token is present
# as an opaque token. # as an opaque token.
r = await self.store.get_user_by_access_token(token) user_info = await self.store.get_user_by_access_token(token)
if r: if user_info:
valid_until_ms = r.valid_until_ms valid_until_ms = user_info.valid_until_ms
if ( if (
not allow_expired not allow_expired
and valid_until_ms is not None and valid_until_ms is not None
@ -419,7 +390,20 @@ class Auth:
msg="Access token has expired", soft_logout=True msg="Access token has expired", soft_logout=True
) )
return r # Mark the token as used. This is used to invalidate old refresh
# tokens after some time.
await self.store.mark_access_token_as_used(user_info.token_id)
requester = create_requester(
user_id=user_info.user_id,
access_token_id=user_info.token_id,
is_guest=user_info.is_guest,
shadow_banned=user_info.shadow_banned,
device_id=user_info.device_id,
authenticated_entity=user_info.token_owner,
)
return requester
# If the token isn't found in the database, then it could still be a # If the token isn't found in the database, then it could still be a
# macaroon for a guest, so we check that here. # macaroon for a guest, so we check that here.
@ -445,11 +429,12 @@ class Auth:
"Guest access token used for regular user" "Guest access token used for regular user"
) )
return TokenLookupResult( return create_requester(
user_id=user_id, user_id=user_id,
is_guest=True, is_guest=True,
# all guests get the same device id # all guests get the same device id
device_id=GUEST_DEVICE_ID, device_id=GUEST_DEVICE_ID,
authenticated_entity=user_id,
) )
except ( except (
pymacaroons.exceptions.MacaroonException, pymacaroons.exceptions.MacaroonException,
@ -472,32 +457,33 @@ class Auth:
request.requester = create_requester(service.sender, app_service=service) request.requester = create_requester(service.sender, app_service=service)
return service return service
async def is_server_admin(self, user: UserID) -> bool: async def is_server_admin(self, requester: Requester) -> bool:
"""Check if the given user is a local server admin. """Check if the given user is a local server admin.
Args: Args:
user: user to check requester: The user making the request, according to the access token.
Returns: Returns:
True if the user is an admin True if the user is an admin
""" """
return await self.store.is_server_admin(user) return await self.store.is_server_admin(requester.user)
async def check_can_change_room_list(self, room_id: str, user: UserID) -> bool: async def check_can_change_room_list(
self, room_id: str, requester: Requester
) -> bool:
"""Determine whether the user is allowed to edit the room's entry in the """Determine whether the user is allowed to edit the room's entry in the
published room list. published room list.
Args: Args:
room_id room_id: The room to check.
user requester: The user making the request, according to the access token.
""" """
is_admin = await self.is_server_admin(user) is_admin = await self.is_server_admin(requester)
if is_admin: if is_admin:
return True return True
user_id = user.to_string() await self.check_user_in_room(room_id, requester)
await self.check_user_in_room(room_id, user_id)
# We currently require the user is a "moderator" in the room. We do this # We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the # by checking if they would (theoretically) be able to change the
@ -516,7 +502,9 @@ class Auth:
send_level = event_auth.get_send_level( send_level = event_auth.get_send_level(
EventTypes.CanonicalAlias, "", power_level_event EventTypes.CanonicalAlias, "", power_level_event
) )
user_level = event_auth.get_user_power_level(user_id, auth_events) user_level = event_auth.get_user_power_level(
requester.user.to_string(), auth_events
)
return user_level >= send_level return user_level >= send_level
@ -574,16 +562,16 @@ class Auth:
@trace @trace
async def check_user_in_room_or_world_readable( async def check_user_in_room_or_world_readable(
self, room_id: str, user_id: str, allow_departed_users: bool = False self, room_id: str, requester: Requester, allow_departed_users: bool = False
) -> Tuple[str, Optional[str]]: ) -> Tuple[str, Optional[str]]:
"""Checks that the user is or was in the room or the room is world """Checks that the user is or was in the room or the room is world
readable. If it isn't then an exception is raised. readable. If it isn't then an exception is raised.
Args: Args:
room_id: room to check room_id: The room to check.
user_id: user to check requester: The user making the request, according to the access token.
allow_departed_users: if True, accept users that were previously allow_departed_users: If True, accept users that were previously
members but have now departed members but have now departed.
Returns: Returns:
Resolves to the current membership of the user in the room and the Resolves to the current membership of the user in the room and the
@ -598,7 +586,7 @@ class Auth:
# * The user is a guest user, and has joined the room # * The user is a guest user, and has joined the room
# else it will throw. # else it will throw.
return await self.check_user_in_room( return await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users room_id, requester, allow_departed_users=allow_departed_users
) )
except AuthError: except AuthError:
visibility = await self._storage_controllers.state.get_current_state_event( visibility = await self._storage_controllers.state.get_current_state_event(
@ -613,6 +601,6 @@ class Auth:
raise UnstableSpecAuthError( raise UnstableSpecAuthError(
403, 403,
"User %s not in room %s, and room previews are disabled" "User %s not in room %s, and room previews are disabled"
% (user_id, room_id), % (requester.user, room_id),
errcode=Codes.NOT_JOINED, errcode=Codes.NOT_JOINED,
) )

View file

@ -280,7 +280,7 @@ class AuthHandler:
that it isn't stolen by re-authenticating them. that it isn't stolen by re-authenticating them.
Args: Args:
requester: The user, as given by the access token requester: The user making the request, according to the access token.
request: The request sent by the client. request: The request sent by the client.
@ -1435,20 +1435,25 @@ class AuthHandler:
access_token: access token to be deleted access_token: access token to be deleted
""" """
user_info = await self.auth.get_user_by_access_token(access_token) token = await self.store.get_user_by_access_token(access_token)
if not token:
# At this point, the token should already have been fetched once by
# the caller, so this should not happen, unless of a race condition
# between two delete requests
raise SynapseError(HTTPStatus.UNAUTHORIZED, "Unrecognised access token")
await self.store.delete_access_token(access_token) await self.store.delete_access_token(access_token)
# see if any modules want to know about this # see if any modules want to know about this
await self.password_auth_provider.on_logged_out( await self.password_auth_provider.on_logged_out(
user_id=user_info.user_id, user_id=token.user_id,
device_id=user_info.device_id, device_id=token.device_id,
access_token=access_token, access_token=access_token,
) )
# delete pushers associated with this access token # delete pushers associated with this access token
if user_info.token_id is not None: if token.token_id is not None:
await self.hs.get_pusherpool().remove_pushers_by_access_token( await self.hs.get_pusherpool().remove_pushers_by_access_token(
user_info.user_id, (user_info.token_id,) token.user_id, (token.token_id,)
) )
async def delete_access_tokens_for_user( async def delete_access_tokens_for_user(

View file

@ -30,7 +30,7 @@ from synapse.api.errors import (
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.module_api import NOT_SPAM from synapse.module_api import NOT_SPAM
from synapse.storage.databases.main.directory import RoomAliasMapping from synapse.storage.databases.main.directory import RoomAliasMapping
from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id from synapse.types import JsonDict, Requester, RoomAlias, get_domain_from_id
if TYPE_CHECKING: if TYPE_CHECKING:
from synapse.server import HomeServer from synapse.server import HomeServer
@ -133,7 +133,7 @@ class DirectoryHandler:
else: else:
# Server admins are not subject to the same constraints as normal # Server admins are not subject to the same constraints as normal
# users when creating an alias (e.g. being in the room). # users when creating an alias (e.g. being in the room).
is_admin = await self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester)
if (self.require_membership and check_membership) and not is_admin: if (self.require_membership and check_membership) and not is_admin:
rooms_for_user = await self.store.get_rooms_for_user(user_id) rooms_for_user = await self.store.get_rooms_for_user(user_id)
@ -197,7 +197,7 @@ class DirectoryHandler:
user_id = requester.user.to_string() user_id = requester.user.to_string()
try: try:
can_delete = await self._user_can_delete_alias(room_alias, user_id) can_delete = await self._user_can_delete_alias(room_alias, requester)
except StoreError as e: except StoreError as e:
if e.code == 404: if e.code == 404:
raise NotFoundError("Unknown room alias") raise NotFoundError("Unknown room alias")
@ -400,7 +400,9 @@ class DirectoryHandler:
# either no interested services, or no service with an exclusive lock # either no interested services, or no service with an exclusive lock
return True return True
async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool: async def _user_can_delete_alias(
self, alias: RoomAlias, requester: Requester
) -> bool:
"""Determine whether a user can delete an alias. """Determine whether a user can delete an alias.
One of the following must be true: One of the following must be true:
@ -413,7 +415,7 @@ class DirectoryHandler:
""" """
creator = await self.store.get_room_alias_creator(alias.to_string()) creator = await self.store.get_room_alias_creator(alias.to_string())
if creator == user_id: if creator == requester.user.to_string():
return True return True
# Resolve the alias to the corresponding room. # Resolve the alias to the corresponding room.
@ -422,9 +424,7 @@ class DirectoryHandler:
if not room_id: if not room_id:
return False return False
return await self.auth.check_can_change_room_list( return await self.auth.check_can_change_room_list(room_id, requester)
room_id, UserID.from_string(user_id)
)
async def edit_published_room_list( async def edit_published_room_list(
self, requester: Requester, room_id: str, visibility: str self, requester: Requester, room_id: str, visibility: str
@ -463,7 +463,7 @@ class DirectoryHandler:
raise SynapseError(400, "Unknown room") raise SynapseError(400, "Unknown room")
can_change_room_list = await self.auth.check_can_change_room_list( can_change_room_list = await self.auth.check_can_change_room_list(
room_id, requester.user room_id, requester
) )
if not can_change_room_list: if not can_change_room_list:
raise AuthError( raise AuthError(
@ -528,10 +528,8 @@ class DirectoryHandler:
Get a list of the aliases that currently point to this room on this server Get a list of the aliases that currently point to this room on this server
""" """
# allow access to server admins and current members of the room # allow access to server admins and current members of the room
is_admin = await self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester)
if not is_admin: if not is_admin:
await self.auth.check_user_in_room_or_world_readable( await self.auth.check_user_in_room_or_world_readable(room_id, requester)
room_id, requester.user.to_string()
)
return await self.store.get_aliases_for_room(room_id) return await self.store.get_aliases_for_room(room_id)

View file

@ -309,18 +309,18 @@ class InitialSyncHandler:
if blocked: if blocked:
raise SynapseError(403, "This room has been blocked on this server") raise SynapseError(403, "This room has been blocked on this server")
user_id = requester.user.to_string()
( (
membership, membership,
member_event_id, member_event_id,
) = await self.auth.check_user_in_room_or_world_readable( ) = await self.auth.check_user_in_room_or_world_readable(
room_id, room_id,
user_id, requester,
allow_departed_users=True, allow_departed_users=True,
) )
is_peeking = member_event_id is None is_peeking = member_event_id is None
user_id = requester.user.to_string()
if membership == Membership.JOIN: if membership == Membership.JOIN:
result = await self._room_initial_sync_joined( result = await self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking user_id, room_id, pagin_config, membership, is_peeking

View file

@ -104,7 +104,7 @@ class MessageHandler:
async def get_room_data( async def get_room_data(
self, self,
user_id: str, requester: Requester,
room_id: str, room_id: str,
event_type: str, event_type: str,
state_key: str, state_key: str,
@ -112,7 +112,7 @@ class MessageHandler:
"""Get data from a room. """Get data from a room.
Args: Args:
user_id requester: The user who did the request.
room_id room_id
event_type event_type
state_key state_key
@ -125,7 +125,7 @@ class MessageHandler:
membership, membership,
membership_event_id, membership_event_id,
) = await self.auth.check_user_in_room_or_world_readable( ) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True room_id, requester, allow_departed_users=True
) )
if membership == Membership.JOIN: if membership == Membership.JOIN:
@ -161,11 +161,10 @@ class MessageHandler:
async def get_state_events( async def get_state_events(
self, self,
user_id: str, requester: Requester,
room_id: str, room_id: str,
state_filter: Optional[StateFilter] = None, state_filter: Optional[StateFilter] = None,
at_token: Optional[StreamToken] = None, at_token: Optional[StreamToken] = None,
is_guest: bool = False,
) -> List[dict]: ) -> List[dict]:
"""Retrieve all state events for a given room. If the user is """Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has joined to the room then return the current state. If the user has
@ -174,14 +173,13 @@ class MessageHandler:
visible. visible.
Args: Args:
user_id: The user requesting state events. requester: The user requesting state events.
room_id: The room ID to get all state events from. room_id: The room ID to get all state events from.
state_filter: The state filter used to fetch state from the database. state_filter: The state filter used to fetch state from the database.
at_token: the stream token of the at which we are requesting at_token: the stream token of the at which we are requesting
the stats. If the user is not allowed to view the state as of that the stats. If the user is not allowed to view the state as of that
stream token, we raise a 403 SynapseError. If None, returns the current stream token, we raise a 403 SynapseError. If None, returns the current
state based on the current_state_events table. state based on the current_state_events table.
is_guest: whether this user is a guest
Returns: Returns:
A list of dicts representing state events. [{}, {}, {}] A list of dicts representing state events. [{}, {}, {}]
Raises: Raises:
@ -191,6 +189,7 @@ class MessageHandler:
members of this room. members of this room.
""" """
state_filter = state_filter or StateFilter.all() state_filter = state_filter or StateFilter.all()
user_id = requester.user.to_string()
if at_token: if at_token:
last_event_id = ( last_event_id = (
@ -223,7 +222,7 @@ class MessageHandler:
membership, membership,
membership_event_id, membership_event_id,
) = await self.auth.check_user_in_room_or_world_readable( ) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True room_id, requester, allow_departed_users=True
) )
if membership == Membership.JOIN: if membership == Membership.JOIN:
@ -317,12 +316,11 @@ class MessageHandler:
Returns: Returns:
A dict of user_id to profile info A dict of user_id to profile info
""" """
user_id = requester.user.to_string()
if not requester.app_service: if not requester.app_service:
# We check AS auth after fetching the room membership, as it # We check AS auth after fetching the room membership, as it
# requires us to pull out all joined members anyway. # requires us to pull out all joined members anyway.
membership, _ = await self.auth.check_user_in_room_or_world_readable( membership, _ = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True room_id, requester, allow_departed_users=True
) )
if membership != Membership.JOIN: if membership != Membership.JOIN:
raise SynapseError( raise SynapseError(
@ -340,7 +338,10 @@ class MessageHandler:
# If this is an AS, double check that they are allowed to see the members. # If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there # This can either be because the AS user is in the room or because there
# is a user in the room that the AS is "interested in" # is a user in the room that the AS is "interested in"
if requester.app_service and user_id not in users_with_profile: if (
requester.app_service
and requester.user.to_string() not in users_with_profile
):
for uid in users_with_profile: for uid in users_with_profile:
if requester.app_service.is_interested_in_user(uid): if requester.app_service.is_interested_in_user(uid):
break break

View file

@ -464,7 +464,7 @@ class PaginationHandler:
membership, membership,
member_event_id, member_event_id,
) = await self.auth.check_user_in_room_or_world_readable( ) = await self.auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True room_id, requester, allow_departed_users=True
) )
if pagin_config.direction == "b": if pagin_config.direction == "b":

View file

@ -29,7 +29,13 @@ from synapse.api.constants import (
JoinRules, JoinRules,
LoginType, LoginType,
) )
from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError from synapse.api.errors import (
AuthError,
Codes,
ConsentNotGivenError,
InvalidClientTokenError,
SynapseError,
)
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.config.server import is_threepid_reserved from synapse.config.server import is_threepid_reserved
from synapse.http.servlet import assert_params_in_dict from synapse.http.servlet import assert_params_in_dict
@ -180,10 +186,7 @@ class RegistrationHandler:
) )
if guest_access_token: if guest_access_token:
user_data = await self.auth.get_user_by_access_token(guest_access_token) user_data = await self.auth.get_user_by_access_token(guest_access_token)
if ( if not user_data.is_guest or user_data.user.localpart != localpart:
not user_data.is_guest
or UserID.from_string(user_data.user_id).localpart != localpart
):
raise AuthError( raise AuthError(
403, 403,
"Cannot register taken user ID without valid guest " "Cannot register taken user ID without valid guest "
@ -618,7 +621,7 @@ class RegistrationHandler:
user_id = user.to_string() user_id = user.to_string()
service = self.store.get_app_service_by_token(as_token) service = self.store.get_app_service_by_token(as_token)
if not service: if not service:
raise AuthError(403, "Invalid application service token.") raise InvalidClientTokenError()
if not service.is_interested_in_user(user_id): if not service.is_interested_in_user(user_id):
raise SynapseError( raise SynapseError(
400, 400,

View file

@ -103,7 +103,7 @@ class RelationsHandler:
# TODO Properly handle a user leaving a room. # TODO Properly handle a user leaving a room.
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable( (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True room_id, requester, allow_departed_users=True
) )
# This gets the original event and checks that a) the event exists and # This gets the original event and checks that a) the event exists and

View file

@ -721,7 +721,7 @@ class RoomCreationHandler:
# allow the server notices mxid to create rooms # allow the server notices mxid to create rooms
is_requester_admin = True is_requester_admin = True
else: else:
is_requester_admin = await self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester)
# Let the third party rules modify the room creation config if needed, or abort # Let the third party rules modify the room creation config if needed, or abort
# the room creation entirely with an exception. # the room creation entirely with an exception.
@ -1279,7 +1279,7 @@ class RoomContextHandler:
""" """
user = requester.user user = requester.user
if use_admin_priviledge: if use_admin_priviledge:
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
before_limit = math.floor(limit / 2.0) before_limit = math.floor(limit / 2.0)
after_limit = limit - before_limit after_limit = limit - before_limit

View file

@ -179,7 +179,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
"""Try and join a room that this server is not in """Try and join a room that this server is not in
Args: Args:
requester requester: The user making the request, according to the access token.
remote_room_hosts: List of servers that can be used to join via. remote_room_hosts: List of servers that can be used to join via.
room_id: Room that we are trying to join room_id: Room that we are trying to join
user: User who is trying to join user: User who is trying to join
@ -744,7 +744,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
is_requester_admin = True is_requester_admin = True
else: else:
is_requester_admin = await self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester)
if not is_requester_admin: if not is_requester_admin:
if self.config.server.block_non_admin_invites: if self.config.server.block_non_admin_invites:
@ -868,7 +868,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
bypass_spam_checker = True bypass_spam_checker = True
else: else:
bypass_spam_checker = await self.auth.is_server_admin(requester.user) bypass_spam_checker = await self.auth.is_server_admin(requester)
inviter = await self._get_inviter(target.to_string(), room_id) inviter = await self._get_inviter(target.to_string(), room_id)
if ( if (
@ -1410,7 +1410,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ShadowBanError if the requester has been shadow-banned. ShadowBanError if the requester has been shadow-banned.
""" """
if self.config.server.block_non_admin_invites: if self.config.server.block_non_admin_invites:
is_requester_admin = await self.auth.is_server_admin(requester.user) is_requester_admin = await self.auth.is_server_admin(requester)
if not is_requester_admin: if not is_requester_admin:
raise SynapseError( raise SynapseError(
403, "Invites have been disabled on this server", Codes.FORBIDDEN 403, "Invites have been disabled on this server", Codes.FORBIDDEN
@ -1693,7 +1693,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
check_complexity check_complexity
and self.hs.config.server.limit_remote_rooms.admins_can_join and self.hs.config.server.limit_remote_rooms.admins_can_join
): ):
check_complexity = not await self.auth.is_server_admin(user) check_complexity = not await self.store.is_server_admin(user)
if check_complexity: if check_complexity:
# Fetch the room complexity # Fetch the room complexity

View file

@ -253,12 +253,11 @@ class TypingWriterHandler(FollowerTypingHandler):
self, target_user: UserID, requester: Requester, room_id: str, timeout: int self, target_user: UserID, requester: Requester, room_id: str, timeout: int
) -> None: ) -> None:
target_user_id = target_user.to_string() target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
if not self.is_mine_id(target_user_id): if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this homeserver") raise SynapseError(400, "User is not hosted on this homeserver")
if target_user_id != auth_user_id: if target_user != requester.user:
raise AuthError(400, "Cannot set another user's typing state") raise AuthError(400, "Cannot set another user's typing state")
if requester.shadow_banned: if requester.shadow_banned:
@ -266,7 +265,7 @@ class TypingWriterHandler(FollowerTypingHandler):
await self.clock.sleep(random.randint(1, 10)) await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError() raise ShadowBanError()
await self.auth.check_user_in_room(room_id, target_user_id) await self.auth.check_user_in_room(room_id, requester)
logger.debug("%s has started typing in %s", target_user_id, room_id) logger.debug("%s has started typing in %s", target_user_id, room_id)
@ -289,12 +288,11 @@ class TypingWriterHandler(FollowerTypingHandler):
self, target_user: UserID, requester: Requester, room_id: str self, target_user: UserID, requester: Requester, room_id: str
) -> None: ) -> None:
target_user_id = target_user.to_string() target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
if not self.is_mine_id(target_user_id): if not self.is_mine_id(target_user_id):
raise SynapseError(400, "User is not hosted on this homeserver") raise SynapseError(400, "User is not hosted on this homeserver")
if target_user_id != auth_user_id: if target_user != requester.user:
raise AuthError(400, "Cannot set another user's typing state") raise AuthError(400, "Cannot set another user's typing state")
if requester.shadow_banned: if requester.shadow_banned:
@ -302,7 +300,7 @@ class TypingWriterHandler(FollowerTypingHandler):
await self.clock.sleep(random.randint(1, 10)) await self.clock.sleep(random.randint(1, 10))
raise ShadowBanError() raise ShadowBanError()
await self.auth.check_user_in_room(room_id, target_user_id) await self.auth.check_user_in_room(room_id, requester)
logger.debug("%s has stopped typing in %s", target_user_id, room_id) logger.debug("%s has stopped typing in %s", target_user_id, room_id)

View file

@ -226,7 +226,7 @@ class SynapseRequest(Request):
# If this is a request where the target user doesn't match the user who # If this is a request where the target user doesn't match the user who
# authenticated (e.g. and admin is puppetting a user) then we return both. # authenticated (e.g. and admin is puppetting a user) then we return both.
if self._requester.user.to_string() != authenticated_entity: if requester != authenticated_entity:
return requester, authenticated_entity return requester, authenticated_entity
return requester, None return requester, None

View file

@ -19,7 +19,7 @@ from typing import Iterable, Pattern
from synapse.api.auth import Auth from synapse.api.auth import Auth
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
from synapse.types import UserID from synapse.types import Requester
def admin_patterns(path_regex: str, version: str = "v1") -> Iterable[Pattern]: def admin_patterns(path_regex: str, version: str = "v1") -> Iterable[Pattern]:
@ -48,19 +48,19 @@ async def assert_requester_is_admin(auth: Auth, request: SynapseRequest) -> None
AuthError if the requester is not a server admin AuthError if the requester is not a server admin
""" """
requester = await auth.get_user_by_req(request) requester = await auth.get_user_by_req(request)
await assert_user_is_admin(auth, requester.user) await assert_user_is_admin(auth, requester)
async def assert_user_is_admin(auth: Auth, user_id: UserID) -> None: async def assert_user_is_admin(auth: Auth, requester: Requester) -> None:
"""Verify that the given user is an admin user """Verify that the given user is an admin user
Args: Args:
auth: Auth singleton auth: Auth singleton
user_id: user to check requester: The user making the request, according to the access token.
Raises: Raises:
AuthError if the user is not a server admin AuthError if the user is not a server admin
""" """
is_admin = await auth.is_server_admin(user_id) is_admin = await auth.is_server_admin(requester)
if not is_admin: if not is_admin:
raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin") raise AuthError(HTTPStatus.FORBIDDEN, "You are not a server admin")

View file

@ -54,7 +54,7 @@ class QuarantineMediaInRoom(RestServlet):
self, request: SynapseRequest, room_id: str self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
logging.info("Quarantining room: %s", room_id) logging.info("Quarantining room: %s", room_id)
@ -81,7 +81,7 @@ class QuarantineMediaByUser(RestServlet):
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
logging.info("Quarantining media by user: %s", user_id) logging.info("Quarantining media by user: %s", user_id)
@ -110,7 +110,7 @@ class QuarantineMediaByID(RestServlet):
self, request: SynapseRequest, server_name: str, media_id: str self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
logging.info("Quarantining media by ID: %s/%s", server_name, media_id) logging.info("Quarantining media by ID: %s/%s", server_name, media_id)

View file

@ -75,7 +75,7 @@ class RoomRestV2Servlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request) requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user) await assert_user_is_admin(self._auth, requester)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -327,7 +327,7 @@ class RoomRestServlet(RestServlet):
pagination_handler: "PaginationHandler", pagination_handler: "PaginationHandler",
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await auth.get_user_by_req(request) requester = await auth.get_user_by_req(request)
await assert_user_is_admin(auth, requester.user) await assert_user_is_admin(auth, requester)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -461,7 +461,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
assert request.args is not None assert request.args is not None
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -551,7 +551,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
self, request: SynapseRequest, room_identifier: str self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
content = parse_json_object_from_request(request, allow_empty_body=True) content = parse_json_object_from_request(request, allow_empty_body=True)
room_id, _ = await self.resolve_room_id(room_identifier) room_id, _ = await self.resolve_room_id(room_identifier)
@ -742,7 +742,7 @@ class RoomEventContextServlet(RestServlet):
self, request: SynapseRequest, room_id: str, event_id: str self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
limit = parse_integer(request, "limit", default=10) limit = parse_integer(request, "limit", default=10)
@ -834,7 +834,7 @@ class BlockRoomRestServlet(RestServlet):
self, request: SynapseRequest, room_id: str self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request) requester = await self._auth.get_user_by_req(request)
await assert_user_is_admin(self._auth, requester.user) await assert_user_is_admin(self._auth, requester)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)

View file

@ -183,7 +183,7 @@ class UserRestServletV2(RestServlet):
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
body = parse_json_object_from_request(request) body = parse_json_object_from_request(request)
@ -575,10 +575,9 @@ class WhoisRestServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
auth_user = requester.user
if target_user != auth_user: if target_user != requester.user:
await assert_user_is_admin(self.auth, auth_user) await assert_user_is_admin(self.auth, requester)
if not self.is_mine(target_user): if not self.is_mine(target_user):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user") raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only whois a local user")
@ -601,7 +600,7 @@ class DeactivateAccountRestServlet(RestServlet):
self, request: SynapseRequest, target_user_id: str self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
if not self.is_mine(UserID.from_string(target_user_id)): if not self.is_mine(UserID.from_string(target_user_id)):
raise SynapseError( raise SynapseError(
@ -693,7 +692,7 @@ class ResetPasswordRestServlet(RestServlet):
This needs user to have administrator access in Synapse. This needs user to have administrator access in Synapse.
""" """
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
UserID.from_string(target_user_id) UserID.from_string(target_user_id)
@ -807,7 +806,7 @@ class UserAdminServlet(RestServlet):
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
auth_user = requester.user auth_user = requester.user
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -921,7 +920,7 @@ class UserTokenRestServlet(RestServlet):
self, request: SynapseRequest, user_id: str self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester)
auth_user = requester.user auth_user = requester.user
if not self.is_mine_id(user_id): if not self.is_mine_id(user_id):

View file

@ -66,7 +66,7 @@ class ProfileDisplaynameRestServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=True) requester = await self.auth.get_user_by_req(request, allow_guest=True)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -123,7 +123,7 @@ class ProfileAvatarURLRestServlet(RestServlet):
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
user = UserID.from_string(user_id) user = UserID.from_string(user_id)
is_admin = await self.auth.is_server_admin(requester.user) is_admin = await self.auth.is_server_admin(requester)
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
try: try:

View file

@ -484,9 +484,6 @@ class RegisterRestServlet(RestServlet):
"Appservice token must be provided when using a type of m.login.application_service", "Appservice token must be provided when using a type of m.login.application_service",
) )
# Verify the AS
self.auth.get_appservice_by_req(request)
# Set the desired user according to the AS API (which uses the # Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll # 'user' key not 'username'). Since this is a new addition, we'll
# fallback to 'username' if they gave one. # fallback to 'username' if they gave one.

View file

@ -229,7 +229,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
msg_handler = self.message_handler msg_handler = self.message_handler
data = await msg_handler.get_room_data( data = await msg_handler.get_room_data(
user_id=requester.user.to_string(), requester=requester,
room_id=room_id, room_id=room_id,
event_type=event_type, event_type=event_type,
state_key=state_key, state_key=state_key,
@ -574,7 +574,7 @@ class RoomMemberListRestServlet(RestServlet):
events = await handler.get_state_events( events = await handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=requester.user.to_string(), requester=requester,
at_token=at_token, at_token=at_token,
state_filter=StateFilter.from_types([(EventTypes.Member, None)]), state_filter=StateFilter.from_types([(EventTypes.Member, None)]),
) )
@ -696,8 +696,7 @@ class RoomStateRestServlet(RestServlet):
# Get all the current state for this room # Get all the current state for this room
events = await self.message_handler.get_state_events( events = await self.message_handler.get_state_events(
room_id=room_id, room_id=room_id,
user_id=requester.user.to_string(), requester=requester,
is_guest=requester.is_guest,
) )
return 200, events return 200, events
@ -755,7 +754,7 @@ class RoomEventServlet(RestServlet):
== "true" == "true"
) )
if include_unredacted_content and not await self.auth.is_server_admin( if include_unredacted_content and not await self.auth.is_server_admin(
requester.user requester
): ):
power_level_event = ( power_level_event = (
await self._storage_controllers.state.get_current_state_event( await self._storage_controllers.state.get_current_state_event(
@ -1260,9 +1259,7 @@ class TimestampLookupRestServlet(RestServlet):
self, request: SynapseRequest, room_id: str self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self._auth.get_user_by_req(request) requester = await self._auth.get_user_by_req(request)
await self._auth.check_user_in_room_or_world_readable( await self._auth.check_user_in_room_or_world_readable(room_id, requester)
room_id, requester.user.to_string()
)
timestamp = parse_integer(request, "ts", required=True) timestamp = parse_integer(request, "ts", required=True)
direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"]) direction = parse_string(request, "dir", default="f", allowed_values=["f", "b"])

View file

@ -244,7 +244,7 @@ class ServerNoticesManager:
assert self.server_notices_mxid is not None assert self.server_notices_mxid is not None
notice_user_data_in_room = await self._message_handler.get_room_data( notice_user_data_in_room = await self._message_handler.get_room_data(
self.server_notices_mxid, create_requester(self.server_notices_mxid),
room_id, room_id,
EventTypes.Member, EventTypes.Member,
self.server_notices_mxid, self.server_notices_mxid,

View file

@ -69,9 +69,9 @@ class TokenLookupResult:
""" """
user_id: str user_id: str
token_id: int
is_guest: bool = False is_guest: bool = False
shadow_banned: bool = False shadow_banned: bool = False
token_id: Optional[int] = None
device_id: Optional[str] = None device_id: Optional[str] = None
valid_until_ms: Optional[int] = None valid_until_ms: Optional[int] = None
token_owner: str = attr.ib() token_owner: str = attr.ib()

View file

@ -284,10 +284,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
TokenLookupResult( TokenLookupResult(
user_id="@baldrick:matrix.org", user_id="@baldrick:matrix.org",
device_id="device", device_id="device",
token_id=5,
token_owner="@admin:matrix.org", token_owner="@admin:matrix.org",
token_used=True,
) )
) )
self.store.insert_client_ip = simple_async_mock(None) self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
@ -301,10 +304,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
TokenLookupResult( TokenLookupResult(
user_id="@baldrick:matrix.org", user_id="@baldrick:matrix.org",
device_id="device", device_id="device",
token_id=5,
token_owner="@admin:matrix.org", token_owner="@admin:matrix.org",
token_used=True,
) )
) )
self.store.insert_client_ip = simple_async_mock(None) self.store.insert_client_ip = simple_async_mock(None)
self.store.mark_access_token_as_used = simple_async_mock(None)
request = Mock(args={}) request = Mock(args={})
request.getClientAddress.return_value.host = "127.0.0.1" request.getClientAddress.return_value.host = "127.0.0.1"
request.args[b"access_token"] = [self.test_token] request.args[b"access_token"] = [self.test_token]
@ -347,7 +353,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
serialized = macaroon.serialize() serialized = macaroon.serialize()
user_info = self.get_success(self.auth.get_user_by_access_token(serialized)) user_info = self.get_success(self.auth.get_user_by_access_token(serialized))
self.assertEqual(user_id, user_info.user_id) self.assertEqual(user_id, user_info.user.to_string())
self.assertTrue(user_info.is_guest) self.assertTrue(user_info.is_guest)
self.store.get_user_by_id.assert_called_with(user_id) self.store.get_user_by_id.assert_called_with(user_id)

View file

@ -25,7 +25,7 @@ from synapse.api.constants import EduTypes
from synapse.api.errors import AuthError from synapse.api.errors import AuthError
from synapse.federation.transport.server import TransportLayerServer from synapse.federation.transport.server import TransportLayerServer
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict, UserID, create_requester from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -117,8 +117,10 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
self.room_members = [] self.room_members = []
async def check_user_in_room(room_id: str, user_id: str) -> None: async def check_user_in_room(room_id: str, requester: Requester) -> None:
if user_id not in [u.to_string() for u in self.room_members]: if requester.user.to_string() not in [
u.to_string() for u in self.room_members
]:
raise AuthError(401, "User is not in the room") raise AuthError(401, "User is not in the room")
return None return None

View file

@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import JsonDict from synapse.types import JsonDict, create_requester
from synapse.util import Clock from synapse.util import Clock
from synapse.visibility import filter_events_for_client from synapse.visibility import filter_events_for_client
@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
message_handler = self.hs.get_message_handler() message_handler = self.hs.get_message_handler()
create_event = self.get_success( create_event = self.get_success(
message_handler.get_room_data( message_handler.get_room_data(
self.user_id, room_id, EventTypes.Create, state_key="" create_requester(self.user_id), room_id, EventTypes.Create, state_key=""
) )
) )

View file

@ -26,7 +26,7 @@ from synapse.rest.client import (
room_upgrade_rest_servlet, room_upgrade_rest_servlet,
) )
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID from synapse.types import UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
from tests import unittest from tests import unittest
@ -275,7 +275,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler() message_handler = self.hs.get_message_handler()
event = self.get_success( event = self.get_success(
message_handler.get_room_data( message_handler.get_room_data(
self.banned_user_id, create_requester(self.banned_user_id),
room_id, room_id,
"m.room.member", "m.room.member",
self.banned_user_id, self.banned_user_id,
@ -310,7 +310,7 @@ class ProfileTestCase(_ShadowBannedBase):
message_handler = self.hs.get_message_handler() message_handler = self.hs.get_message_handler()
event = self.get_success( event = self.get_success(
message_handler.get_room_data( message_handler.get_room_data(
self.banned_user_id, create_requester(self.banned_user_id),
room_id, room_id,
"m.room.member", "m.room.member",
self.banned_user_id, self.banned_user_id,