Merge branch 'madlittlemods/msc3575-sliding-sync-0.0.1' into madlittlemods/msc3575-sliding-sync-filtering

Conflicts:
	tests/handlers/test_sliding_sync.py
Eric Eastwood 2024-06-05 14:44:30 -05:00
commit 5078d36bd3
20 changed files with 808 additions and 108 deletions

changelog.d/17254.bugfix
View file

@@ -0,0 +1 @@
Fix searching for users with their exact localpart whose ID includes a hyphen.

View file

@@ -0,0 +1 @@
Improve ratelimiting in Synapse (#17256).

View file

@@ -1946,6 +1946,24 @@ Example configuration:
max_image_pixels: 35M
```
---
### `remote_media_download_burst_count`
Remote media downloads are ratelimited using a [leaky bucket algorithm](https://en.wikipedia.org/wiki/Leaky_bucket), where a given "bucket" is keyed to the IP address of the requester. This option sets the size of the bucket in bytes, against which the size in bytes of each download is counted: if the bucket is full, i.e. the given number of bytes have already been downloaded, further downloads are denied until the bucket drains. Defaults to 500MiB. See also `remote_media_download_per_second`, which determines the rate at which the "bucket" is emptied and thus regains capacity to authorize new requests.
Example configuration:
```yaml
remote_media_download_burst_count: 200M
```
---
### `remote_media_download_per_second`
Works in conjunction with `remote_media_download_burst_count` to ratelimit remote media downloads - this option determines the rate at which the "bucket" (see above) leaks, in bytes per second. As requests to download remote media are made, the size of those requests in bytes is added to the bucket; once the bucket reaches its capacity, no more requests are allowed until enough bytes have "drained" from it. This setting determines that drain rate, so the larger the number, the faster the bucket leaks and the more bytes can be downloaded over a given period. Defaults to 87KiB per second. See also `remote_media_download_burst_count`.
Example configuration:
```yaml
remote_media_download_per_second: 40K
```
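Taken together, the two settings describe a classic leaky bucket: `remote_media_download_burst_count` is the bucket's capacity and `remote_media_download_per_second` is its drain rate. A minimal illustrative model of that interaction follows; it is a sketch, not Synapse's actual `Ratelimiter` class, and all names in it are made up:

```python
import time

class LeakyBucket:
    """Illustrative leaky-bucket model: `capacity` plays the role of
    remote_media_download_burst_count and `leak_rate` the role of
    remote_media_download_per_second (both in bytes)."""

    def __init__(self, capacity: int, leak_rate: float) -> None:
        self.capacity = capacity      # e.g. 500 * 1024 * 1024 (500MiB)
        self.leak_rate = leak_rate    # e.g. 87 * 1024 (87KiB/s)
        self.level = 0.0              # bytes currently "in" the bucket
        self.last = time.monotonic()

    def allow(self, nbytes: int) -> bool:
        now = time.monotonic()
        # Drain the bucket at leak_rate bytes/second since the last check.
        self.level = max(0.0, self.level - (now - self.last) * self.leak_rate)
        self.last = now
        if self.level + nbytes > self.capacity:
            return False  # bucket full: deny until enough has drained
        self.level += nbytes
        return True
```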
---
### `prevent_media_downloads_from`
A list of domains to never download media from. Media from these

View file

@@ -50,7 +50,7 @@ class Membership:
KNOCK: Final = "knock"
LEAVE: Final = "leave"
BAN: Final = "ban"
-LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN)
+LIST: Final = {INVITE, JOIN, KNOCK, LEAVE, BAN}
class PresenceState:
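The tuple-to-set change above is what makes membership set arithmetic possible; the sliding sync filter later in this commit relies on it. A small self-contained sketch of the idea (`wanted` is an illustrative name, not part of Synapse):

```python
from typing import Final

INVITE: Final = "invite"
JOIN: Final = "join"
KNOCK: Final = "knock"
LEAVE: Final = "leave"
BAN: Final = "ban"
# A set (unlike the old tuple) supports difference and intersection.
LIST: Final = {INVITE, JOIN, KNOCK, LEAVE, BAN}

# Everything that's *still* relevant to the user, as used by
# filter_membership_for_sync() further down in this commit.
wanted = LIST - {LEAVE}
assert wanted == {"invite", "join", "knock", "ban"}
```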

View file

@@ -218,3 +218,13 @@ class RatelimitConfig(Config):
"rc_media_create",
defaults={"per_second": 10, "burst_count": 50},
)
self.remote_media_downloads = RatelimitSettings(
key="rc_remote_media_downloads",
per_second=self.parse_size(
config.get("remote_media_download_per_second", "87K")
),
burst_count=self.parse_size(
config.get("remote_media_download_burst_count", "500M")
),
)

View file

@@ -56,6 +56,7 @@ from synapse.api.errors import (
SynapseError,
UnsupportedRoomVersionError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
@@ -1877,6 +1878,8 @@ class FederationClient(FederationBase):
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
try:
return await self.transport_layer.download_media_v3(
@@ -1885,6 +1888,8 @@ class FederationClient(FederationBase):
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
@@ -1905,6 +1910,8 @@ class FederationClient(FederationBase):
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)

View file

@@ -43,6 +43,7 @@ import ijson
from synapse.api.constants import Direction, Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
@@ -819,6 +820,8 @@ class TransportLayerClient:
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
@@ -834,6 +837,8 @@ class TransportLayerClient:
"allow_remote": "false",
"timeout_ms": str(max_timeout_ms),
},
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
async def download_media_v3(
@@ -843,6 +848,8 @@ class TransportLayerClient:
output_stream: BinaryIO,
max_size: int,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
@@ -862,6 +869,8 @@ class TransportLayerClient:
"allow_redirect": "true",
},
follow_redirects=True,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)

View file

@@ -42,17 +42,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
-# Everything except `Membership.LEAVE` because we want everything that's *still*
-# relevant to the user. There are a few more things to include in the sync response
-# (kicks, newly_left) but those are handled separately.
-MEMBERSHIP_TO_DISPLAY_IN_SYNC = (
-    Membership.INVITE,
-    Membership.JOIN,
-    Membership.KNOCK,
-    Membership.BAN,
-)
def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> bool:
"""
Returns True if the membership event should be included in the sync response,
@@ -65,7 +54,10 @@ def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> bool:
"""
return (
-    membership in MEMBERSHIP_TO_DISPLAY_IN_SYNC
+    # Everything except `Membership.LEAVE` because we want everything that's *still*
+    # relevant to the user. There are a few more things to include in the sync response
+    # (newly_left) but those are handled separately.
+    membership in (Membership.LIST - {Membership.LEAVE})
# Include kicks
or (membership == Membership.LEAVE and sender != user_id)
)
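To make the predicate's behavior concrete, here is a standalone toy version with the membership strings inlined (illustrative only; the real function takes keyword-only arguments and uses the `Membership` constants):

```python
# Minimal standalone version of the predicate above, for illustration;
# mirrors filter_membership_for_sync() but inlines the membership strings.
def keep_for_sync(membership: str, user_id: str, sender: str) -> bool:
    return (
        membership in ({"invite", "join", "knock", "leave", "ban"} - {"leave"})
        # Include kicks (leave events sent by someone else)
        or (membership == "leave" and sender != user_id)
    )

assert keep_for_sync("ban", "@alice:test", "@bob:test")          # still relevant: kept
assert keep_for_sync("leave", "@alice:test", "@bob:test")        # a kick: kept
assert not keep_for_sync("leave", "@alice:test", "@alice:test")  # own leave: dropped
```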
@@ -233,7 +225,6 @@ class SlidingSyncResult:
class SlidingSyncHandler:
def __init__(self, hs: "HomeServer"):
-self.hs_config = hs.config
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self.auth_blocking = hs.get_auth_blocking()
@@ -385,13 +376,23 @@ class SlidingSyncHandler:
Fetch room IDs that should be listed for this user in the sync response (the
full room list that will be filtered, sorted, and sliced).
-We're looking for rooms that the user has not left (`invite`, `knock`, `join`,
-and `ban`), or kicks (`leave` where the `sender` is different from the
-`state_key`), or newly_left rooms that are > `from_token` and <= `to_token`.
+We're looking for rooms where the user has the following state in the token
+range (> `from_token` and <= `to_token`):
+
+- `invite`, `join`, `knock`, `ban` membership events
+- Kicks (`leave` membership events where `sender` is different from the
+  `user_id`/`state_key`)
+- `newly_left` (rooms that were left during the given token range)
+- In order for bans/kicks to not show up in sync, you need to `/forget` those
+  rooms. This doesn't modify the event itself though and only adds the
+  `forgotten` flag to the `room_memberships` table in Synapse. There isn't a way
+  to tell when a room was forgotten at the moment, so we can't factor it into the
+  from/to range.
"""
user_id = user.to_string()
# First grab a current snapshot of rooms for the user
# (also handles forgotten rooms)
room_for_user_list = await self.store.get_rooms_for_local_user_where_membership_is(
user_id=user_id,
# We want to fetch any kind of membership (joined and left rooms) in order
@@ -441,10 +442,7 @@ class SlidingSyncHandler:
# Then assemble the `RoomStreamToken`
membership_snapshot_token = RoomStreamToken(
# Minimum position in the `instance_map`
-stream=min(
-    stream_ordering
-    for stream_ordering in instance_to_max_stream_ordering_map.values()
-),
+stream=min(instance_to_max_stream_ordering_map.values()),
instance_map=immutabledict(instance_to_max_stream_ordering_map),
)
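The token built here pairs the minimum stream position across writers with the full per-writer map; the sharding test later in this commit describes its serialized shape as `m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`. A toy illustration of the assembly (plain dict instead of `immutabledict`; the `render` helper is invented for this example and only approximates the format the test docstring describes):

```python
instance_to_max_stream_ordering_map = {"worker1": 5, "worker2": 3, "worker3": 7}

# Minimum position across all writers becomes the token's `stream` field...
stream = min(instance_to_max_stream_ordering_map.values())
assert stream == 3

# ...and writers ahead of that minimum are spelled out in the instance map.
def render(stream: int, instance_map: dict) -> str:
    parts = [
        f"{name}.{pos}"
        for name, pos in sorted(instance_map.items())
        if pos > stream
    ]
    return "m" + str(stream) + "".join("~" + p for p in parts)

assert render(stream, instance_to_max_stream_ordering_map) == "m3~worker1.5~worker3.7"
```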
@@ -454,26 +452,13 @@ class SlidingSyncHandler:
if membership_snapshot_token.is_before_or_eq(to_token.room_key):
return sync_room_id_set
-# We assume the `from_token` is before or at-least equal to the `to_token`
-assert from_token is None or from_token.room_key.is_before_or_eq(
-    to_token.room_key
-), f"{from_token.room_key if from_token else None} < {to_token.room_key}"
-# We assume the `from_token`/`to_token` is before the `membership_snapshot_token`
-assert from_token is None or from_token.room_key.is_before_or_eq(
-    membership_snapshot_token
-), f"{from_token.room_key if from_token else None} < {membership_snapshot_token}"
-assert to_token.room_key.is_before_or_eq(
-    membership_snapshot_token
-), f"{to_token.room_key} < {membership_snapshot_token}"
# Since we fetched the user's room list at some point in time after the from/to
# tokens, we need to revert/rewind some membership changes to match the point in
# time of the `to_token`. In particular, we need to make these fixups:
#
-# - 1) Add back newly_left rooms (> `from_token` and <= `to_token`)
-# - 2a) Remove rooms that the user joined after the `to_token`
-# - 2b) Add back rooms that the user left after the `to_token`
+# - 1a) Add back rooms that the user left after the `to_token`
+# - 1b) Remove rooms that the user joined after the `to_token`
+# - 2) Add back newly_left rooms (> `from_token` and <= `to_token`)
#
# Below, we're doing two separate lookups for membership changes. We could
# request everything for both fixups in one range, [`from_token.room_key`,
@@ -484,41 +469,7 @@ class SlidingSyncHandler:
# 1) -----------------------------------------------------
-# 1) Fetch membership changes that fall in the range from `from_token` up to `to_token`
-membership_change_events_in_from_to_range = []
-if from_token:
-    membership_change_events_in_from_to_range = (
-        await self.store.get_membership_changes_for_user(
-            user_id,
-            from_key=from_token.room_key,
-            to_key=to_token.room_key,
-            excluded_rooms=self.rooms_to_exclude_globally,
-        )
-    )
-# 1) Assemble a list of the last membership events in some given ranges. Someone
-# could have left and joined multiple times during the given range but we only
-# care about end-result so we grab the last one.
-last_membership_change_by_room_id_in_from_to_range: Dict[str, EventBase] = {}
-for event in membership_change_events_in_from_to_range:
-    assert event.internal_metadata.stream_ordering
-    last_membership_change_by_room_id_in_from_to_range[event.room_id] = event
-# 1) Fixup
-for (
-    last_membership_change_in_from_to_range
-) in last_membership_change_by_room_id_in_from_to_range.values():
-    room_id = last_membership_change_in_from_to_range.room_id
-    # 1) Add back newly_left rooms (> `from_token` and <= `to_token`). We
-    # include newly_left rooms because the last event that the user should see
-    # is their own leave event
-    if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
-        sync_room_id_set.add(room_id)
-# 2) -----------------------------------------------------
-# 2) Fetch membership changes that fall in the range from `to_token` up to
+# 1) Fetch membership changes that fall in the range from `to_token` up to
# `membership_snapshot_token`
membership_change_events_after_to_token = (
await self.store.get_membership_changes_for_user(
@@ -529,7 +480,7 @@ class SlidingSyncHandler:
)
)
-# 2) Assemble a list of the last membership events in some given ranges. Someone
+# 1) Assemble a list of the last membership events in some given ranges. Someone
# could have left and joined multiple times during the given range but we only
# care about end-result so we grab the last one.
last_membership_change_by_room_id_after_to_token: Dict[str, EventBase] = {}
@@ -537,15 +488,13 @@ class SlidingSyncHandler:
# backward to the previous membership that would apply to the from/to range.
first_membership_change_by_room_id_after_to_token: Dict[str, EventBase] = {}
for event in membership_change_events_after_to_token:
-assert event.internal_metadata.stream_ordering
last_membership_change_by_room_id_after_to_token[event.room_id] = event
# Only set if we haven't already set it
first_membership_change_by_room_id_after_to_token.setdefault(
event.room_id, event
)
-# 2) Fixup
+# 1) Fixup
for (
last_membership_change_after_to_token
) in last_membership_change_by_room_id_after_to_token.values():
@@ -602,7 +551,7 @@ class SlidingSyncHandler:
sender=last_membership_change_after_to_token.sender,
)
-# 2a) Add back rooms that the user left after the `to_token`
+# 1a) Add back rooms that the user left after the `to_token`
#
# For example, if the last membership event after the `to_token` is a leave
# event, then the room was excluded from `sync_room_id_set` when we first
@@ -613,7 +562,7 @@ class SlidingSyncHandler:
and should_prev_membership_be_included
):
sync_room_id_set.add(room_id)
-# 2b) Remove rooms that the user joined (hasn't left) after the `to_token`
+# 1b) Remove rooms that the user joined (hasn't left) after the `to_token`
#
# For example, if the last membership event after the `to_token` is a "join"
# event, then the room was included in `sync_room_id_set` when we first crafted
@@ -625,6 +574,41 @@ class SlidingSyncHandler:
):
sync_room_id_set.discard(room_id)
# 2) -----------------------------------------------------
# We fix up newly_left rooms after the first fixup because it may have removed
# some left rooms that the following code will add back as newly_left
# 2) Fetch membership changes that fall in the range from `from_token` up to `to_token`
membership_change_events_in_from_to_range = []
if from_token:
membership_change_events_in_from_to_range = (
await self.store.get_membership_changes_for_user(
user_id,
from_key=from_token.room_key,
to_key=to_token.room_key,
excluded_rooms=self.rooms_to_exclude_globally,
)
)
# 2) Assemble a list of the last membership events in some given ranges. Someone
# could have left and joined multiple times during the given range but we only
# care about end-result so we grab the last one.
last_membership_change_by_room_id_in_from_to_range: Dict[str, EventBase] = {}
for event in membership_change_events_in_from_to_range:
last_membership_change_by_room_id_in_from_to_range[event.room_id] = event
# 2) Fixup
for (
last_membership_change_in_from_to_range
) in last_membership_change_by_room_id_in_from_to_range.values():
room_id = last_membership_change_in_from_to_range.room_id
# 2) Add back newly_left rooms (> `from_token` and <= `to_token`). We
# include newly_left rooms because the last event that the user should see
# is their own leave event
if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
sync_room_id_set.add(room_id)
return sync_room_id_set
async def filter_rooms(
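Seen end to end, the reshuffled function is a snapshot-then-rewind pass: take the current room set, undo changes newer than `to_token` (fixups 1a/1b), then re-add in-range leaves (fixup 2). A compressed, illustrative sketch under simplified types, not the real implementation (which also consults the pre-`to_token` membership before fixup 1a):

```python
from typing import Dict, List, Set

# Toy records: room_id -> memberships in stream order (latest last).
Changes = Dict[str, List[str]]

def rooms_for_sync(
    snapshot: Set[str],       # rooms the user is currently in/invited to
    after_to_token: Changes,  # membership changes newer than to_token
    in_range: Changes,        # changes in (from_token, to_token]
) -> Set[str]:
    rooms = set(snapshot)
    # 1) Rewind changes that happened after to_token.
    for room_id, memberships in after_to_token.items():
        if memberships[-1] == "leave":
            rooms.add(room_id)      # 1a) add back rooms left after to_token
        elif memberships[-1] == "join":
            rooms.discard(room_id)  # 1b) drop rooms joined after to_token
    # 2) Then add back newly_left rooms from inside the token range, so the
    #    client still sees its own leave event.
    for room_id, memberships in in_range.items():
        if memberships[-1] == "leave":
            rooms.add(room_id)
    return rooms

# Ordering matters: fixup 1b may discard a room that fixup 2 then re-adds
# as newly_left, which is why this commit moved the newly_left pass last.
assert rooms_for_sync(
    snapshot={"!a:x", "!c:x"},          # user re-joined !c after to_token
    after_to_token={"!c:x": ["join"]},  # ...so 1b rewinds that join
    in_range={"!c:x": ["leave"]},       # ...but it was newly_left in range
) == {"!a:x", "!c:x"}
```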

View file

@@ -57,7 +57,7 @@ from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
import synapse.metrics
import synapse.util.retryutils
@@ -68,6 +68,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
@@ -1411,9 +1412,11 @@ class MatrixFederationHttpClient:
destination: str,
path: str,
output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: str,
max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
-max_size: Optional[int] = None,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> Tuple[int, Dict[bytes, List[bytes]]]:
@@ -1422,6 +1425,10 @@ class MatrixFederationHttpClient:
destination: The remote server to send the HTTP request to.
path: The HTTP path to GET.
output_stream: File to write the response body to.
download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
requester IP
ip_address: IP address of the requester
max_size: maximum allowable size in bytes of the file
args: Optional dictionary used to create the query string.
ignore_backoff: true to ignore the historical backoff data
and try the request anyway.
@@ -1441,11 +1448,27 @@ class MatrixFederationHttpClient:
federation whitelist
RequestSendFailed: If there were problems connecting to the
remote, due to e.g. DNS failures, connection timeouts etc.
SynapseError: If the requested file exceeds ratelimits
"""
request = MatrixFederationRequest(
method="GET", destination=destination, path=path, query=args
)
# check for a minimum balance of 1MiB in ratelimiter before initiating request
send_req, _ = await download_ratelimiter.can_do_action(
requester=None, key=ip_address, n_actions=1048576, update=False
)
if not send_req:
msg = "Requested file size exceeds ratelimits"
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
response = await self._send_request(
request,
retry_on_dns_fail=retry_on_dns_fail,
@@ -1455,12 +1478,36 @@ class MatrixFederationHttpClient:
headers = dict(response.headers.getAllRawHeaders())
expected_size = response.length
# if we don't get an expected length then use the max length
if expected_size == UNKNOWN_LENGTH:
expected_size = max_size
logger.debug(
f"File size unknown, assuming file is max allowable size: {max_size}"
)
read_body, _ = await download_ratelimiter.can_do_action(
requester=None,
key=ip_address,
n_actions=expected_size,
)
if not read_body:
msg = "Requested file size exceeds ratelimits"
logger.warning(
"{%s} [%s] %s",
request.txn_id,
request.destination,
msg,
)
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
try:
-d = read_body_with_max_size(response, output_stream, max_size)
+# add a byte of headroom to max size as function errs at >=
+d = read_body_with_max_size(response, output_stream, expected_size + 1)
d.addTimeout(self.default_timeout_seconds, self.reactor)
length = await make_deferred_yieldable(d)
except BodyExceededMaxSize:
-msg = "Requested file is too large > %r bytes" % (max_size,)
+msg = "Requested file is too large > %r bytes" % (expected_size,)
logger.warning(
"{%s} [%s] %s",
request.txn_id,
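The two `can_do_action` calls above form a two-phase pattern: a read-only check (`update=False`) for 1MiB of headroom before the connection is opened, then an actual spend of the expected body size once the length is known. A self-contained toy version of that flow (the `Bucket` class and `check_download` are stand-ins invented for this sketch, not Synapse's `Ratelimiter`):

```python
from dataclasses import dataclass

@dataclass
class Bucket:
    capacity: int          # bytes allowed in the burst window
    level: int = 0         # bytes already spent

    def can_do(self, n: int, update: bool = True) -> bool:
        if self.level + n > self.capacity:
            return False
        if update:
            self.level += n
        return True

MIN_BALANCE = 1024 * 1024  # 1MiB headroom required before connecting

def check_download(bucket: Bucket, expected_size: int) -> None:
    # Phase 1: peek only (update=False), so a denied request costs nothing.
    if not bucket.can_do(MIN_BALANCE, update=False):
        raise PermissionError("ratelimited: no headroom to even start")
    # Phase 2: once the expected length is known, spend it against the bucket.
    if not bucket.can_do(expected_size):
        raise PermissionError("ratelimited: expected body too large")

bucket = Bucket(capacity=500 * 1024 * 1024)
check_download(bucket, expected_size=30 * 1024 * 1024)  # ok, spends 30MiB
```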

View file

@@ -42,6 +42,7 @@ from synapse.api.errors import (
SynapseError,
cs_error,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.repository import ThumbnailRequirement
from synapse.http.server import respond_with_json
from synapse.http.site import SynapseRequest
@@ -111,6 +112,12 @@ class MediaRepository:
)
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
self.download_ratelimiter = Ratelimiter(
store=hs.get_storage_controllers().main,
clock=hs.get_clock(),
cfg=hs.config.ratelimiting.remote_media_downloads,
)
# List of StorageProviders where we should search for media and
# potentially upload to.
storage_providers = []
@@ -464,6 +471,7 @@ class MediaRepository:
media_id: str,
name: Optional[str],
max_timeout_ms: int,
ip_address: str,
) -> None:
"""Respond to requests for remote media.
@@ -475,6 +483,7 @@ class MediaRepository:
the filename in the Content-Disposition header of the response.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
ip_address: the IP address of the requester
Returns:
Resolves once a response has successfully been written to request
@@ -500,7 +509,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
-server_name, media_id, max_timeout_ms
+server_name,
+media_id,
+max_timeout_ms,
+self.download_ratelimiter,
+ip_address,
)
# We deliberately stream the file outside the lock
@@ -517,7 +530,7 @@ class MediaRepository:
respond_404(request)
async def get_remote_media_info(
-self, server_name: str, media_id: str, max_timeout_ms: int
+self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str
) -> RemoteMedia:
"""Gets the media info associated with the remote file, downloading
if necessary.
@@ -527,6 +540,7 @@ class MediaRepository:
media_id: The media ID of the content (as defined by the remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
ip_address: IP address of the requester
Returns:
The media info of the file
@@ -542,7 +556,11 @@ class MediaRepository:
key = (server_name, media_id)
async with self.remote_media_linearizer.queue(key):
responder, media_info = await self._get_remote_media_impl(
-server_name, media_id, max_timeout_ms
+server_name,
+media_id,
+max_timeout_ms,
+self.download_ratelimiter,
+ip_address,
)
# Ensure we actually use the responder so that it releases resources
@@ -553,7 +571,12 @@ class MediaRepository:
return media_info
async def _get_remote_media_impl(
-self, server_name: str, media_id: str, max_timeout_ms: int
+self,
+server_name: str,
+media_id: str,
+max_timeout_ms: int,
+download_ratelimiter: Ratelimiter,
+ip_address: str,
) -> Tuple[Optional[Responder], RemoteMedia]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -564,6 +587,9 @@ class MediaRepository:
remote server).
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP.
ip_address: the IP address of the requester
Returns:
A tuple of responder and the media info of the file.
@@ -596,7 +622,7 @@ class MediaRepository:
try:
media_info = await self._download_remote_file(
-server_name, media_id, max_timeout_ms
+server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
)
except SynapseError:
raise
@@ -630,6 +656,8 @@ class MediaRepository:
server_name: str,
media_id: str,
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> RemoteMedia:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
@@ -641,6 +669,9 @@ class MediaRepository:
locally generated.
max_timeout_ms: the maximum number of milliseconds to wait for the
media to be uploaded.
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
requester IP
ip_address: the IP address of the requester
Returns:
The media info of the file.
@@ -658,6 +689,8 @@ class MediaRepository:
output_stream=f,
max_size=self.max_upload_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except RequestSendFailed as e:
logger.warning(

View file

@@ -359,9 +359,10 @@ class ThumbnailProvider:
desired_method: str,
desired_type: str,
max_timeout_ms: int,
ip_address: str,
) -> None:
media_info = await self.media_repo.get_remote_media_info(
-server_name, media_id, max_timeout_ms
+server_name, media_id, max_timeout_ms, ip_address
)
if not media_info:
respond_404(request)
@@ -422,12 +423,13 @@ class ThumbnailProvider:
method: str,
m_type: str,
max_timeout_ms: int,
ip_address: str,
) -> None:
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead of
# downloading the remote file and generating our own thumbnails.
media_info = await self.media_repo.get_remote_media_info(
-server_name, media_id, max_timeout_ms
+server_name, media_id, max_timeout_ms, ip_address
)
if not media_info:
return

View file

@@ -174,6 +174,7 @@ class UnstableThumbnailResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
remote_resp_function = (
self.thumbnailer.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
@@ -188,6 +189,7 @@ class UnstableThumbnailResource(RestServlet):
method,
m_type,
max_timeout_ms,
ip_address,
)
self.media_repo.mark_recently_accessed(server_name, media_id)

View file

@@ -292,6 +292,9 @@ class RoomStateEventRestServlet(RestServlet):
try:
if event_type == EventTypes.Member:
membership = content.get("membership", None)
if not isinstance(membership, str):
raise SynapseError(400, "Invalid membership (must be a string)")
event_id, _ = await self.room_member_handler.update_membership(
requester,
target=UserID.from_string(state_key),

View file

@@ -97,6 +97,12 @@ class DownloadResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
await self.media_repo.get_remote_media(
-request, server_name, media_id, file_name, max_timeout_ms
+request,
+server_name,
+media_id,
+file_name,
+max_timeout_ms,
+ip_address,
)

View file

@@ -104,6 +104,7 @@ class ThumbnailResource(RestServlet):
respond_404(request)
return
ip_address = request.getClientAddress().host
remote_resp_function = (
self.thumbnail_provider.select_or_generate_remote_thumbnail
if self.dynamic_thumbnails
@@ -118,5 +119,6 @@ class ThumbnailResource(RestServlet):
method,
m_type,
max_timeout_ms,
ip_address,
)
self.media_repo.mark_recently_accessed(server_name, media_id)

View file

@@ -1281,7 +1281,7 @@ def _parse_words_with_regex(search_term: str) -> List[str]:
Break down search term into words, when we don't have ICU available.
See: `_parse_words`
"""
-return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
+return re.findall(r"([\w-]+)", search_term, re.UNICODE)
def _parse_words_with_icu(search_term: str) -> List[str]:
@@ -1303,15 +1303,69 @@ def _parse_words_with_icu(search_term: str) -> List[str]:
if j < 0:
break
-result = search_term[i:j]
-# libicu considers spaces and punctuation between words as words, but we don't
-# want to include those in results as they would result in syntax errors in SQL
-# queries (e.g. "foo bar" would result in the search query including "foo & &
-# bar").
-if len(re.findall(r"([\w\-]+)", result, re.UNICODE)):
-    results.append(result)
-i = j
-return results
+# We want to make sure that we split on `@` and `:` specifically, as
+# they occur in user IDs.
+for result in re.split(r"[@:]+", search_term[i:j]):
+    results.append(result.strip())
+i = j
+
+# libicu will break up words that have punctuation in them, but to handle
+# cases where user IDs have '-', '.' and '_' in them we want to *not* break
+# those into words and instead allow the DB to tokenise them how it wants.
+#
+# In particular, user-71 in postgres gets tokenised to "user, -71", and this
+# will not match a query for "user, 71".
+new_results: List[str] = []
+i = 0
+while i < len(results):
+    curr = results[i]
+    prev = None
+    next = None
+    if i > 0:
+        prev = results[i - 1]
+    if i + 1 < len(results):
+        next = results[i + 1]
+    i += 1
+    # libicu considers spaces and punctuation between words as words, but we don't
+    # want to include those in results as they would result in syntax errors in SQL
+    # queries (e.g. "foo bar" would result in the search query including "foo & &
+    # bar").
+    if not curr:
+        continue
+    if curr in ["-", ".", "_"]:
+        prefix = ""
+        suffix = ""
+        # Check if the next item is a word, and if so use it as the suffix.
+        # We check for if it's a word as we don't want to concatenate
+        # multiple punctuation marks.
+        if next is not None and re.match(r"\w", next):
+            suffix = next
+            i += 1  # We're using next, so we skip it in the outer loop.
+        else:
+            # We want to avoid creating terms like "user-", as we should
+            # strip trailing punctuation.
+            continue
+        if prev and re.match(r"\w", prev) and new_results:
+            prefix = new_results[-1]
+            new_results.pop()
+        # We might not have a prefix here, but that's fine as we want to
+        # ensure that we don't strip preceding punctuation e.g. '-71'
+        # shouldn't be converted to '71'.
+        new_results.append(f"{prefix}{curr}{suffix}")
+        continue
+    elif not re.match(r"\w", curr):
+        # Ignore other punctuation
+        continue
+    new_results.append(curr)
+return new_results
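A worked example helps here: for the search term `user-71`, libicu-style tokenisation yields something like `["user", "-", "71"]`, and the loop above glues the pieces back together. A self-contained toy reproduction of that rejoin pass (the `rejoin` name and the hand-written token lists are illustrative; real input comes from the ICU break iterator):

```python
import re
from typing import List, Optional

def rejoin(tokens: List[str]) -> List[str]:
    out: List[str] = []
    i = 0
    while i < len(tokens):
        curr = tokens[i]
        prev: Optional[str] = tokens[i - 1] if i > 0 else None
        nxt: Optional[str] = tokens[i + 1] if i + 1 < len(tokens) else None
        i += 1
        if not curr:
            continue
        if curr in ["-", ".", "_"]:
            if nxt is not None and re.match(r"\w", nxt):
                suffix = nxt
                i += 1  # consume the suffix token
            else:
                continue  # drop trailing punctuation like "user-"
            prefix = ""
            if prev and re.match(r"\w", prev) and out:
                prefix = out.pop()
            out.append(f"{prefix}{curr}{suffix}")
            continue
        elif not re.match(r"\w", curr):
            continue  # ignore other punctuation
        out.append(curr)
    return out

assert rejoin(["user", "-", "71"]) == ["user-71"]
assert rejoin(["-", "71"]) == ["-71"]     # leading punctuation is kept
assert rejoin(["user", "-"]) == ["user"]  # trailing punctuation is dropped
```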

View file

@@ -18,6 +18,7 @@
#
#
import logging
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
@@ -26,9 +27,11 @@ from synapse.api.room_versions import RoomVersions
from synapse.rest import admin
from synapse.rest.client import knock, login, room
from synapse.server import HomeServer
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import JsonDict, UserID
from synapse.util import Clock
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.unittest import HomeserverTestCase
logger = logging.getLogger(__name__)
@@ -309,7 +312,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
def test_only_newly_left_rooms_show_up(self) -> None:
"""
Test that newly_left rooms still show up in the sync response but rooms that
-were left before the `from_token` don't show up. See condition "1)" comments in
+were left before the `from_token` don't show up. See condition "2)" comments in
the `get_sync_room_ids_for_user` method.
"""
user1_id = self.register_user("user1", "pass")
@@ -340,7 +343,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
def test_no_joins_after_to_token(self) -> None:
"""
-Rooms we join after the `to_token` should *not* show up. See condition "2b)"
+Rooms we join after the `to_token` should *not* show up. See condition "1b)"
comments in the `get_sync_room_ids_for_user()` method.
"""
user1_id = self.register_user("user1", "pass")
@@ -369,7 +372,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
def test_join_during_range_and_left_room_after_to_token(self) -> None:
"""
Room still shows up if we left the room but were joined during the
-from_token/to_token. See condition "2a)" comments in the
+from_token/to_token. See condition "1a)" comments in the
`get_sync_room_ids_for_user()` method.
"""
user1_id = self.register_user("user1", "pass")
@@ -399,7 +402,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
def test_join_before_range_and_left_room_after_to_token(self) -> None:
"""
Room still shows up if we left the room but were joined before the `from_token`
-so it should show up. See condition "2a)" comments in the
+so it should show up. See condition "1a)" comments in the
`get_sync_room_ids_for_user()` method.
"""
user1_id = self.register_user("user1", "pass")
@@ -426,7 +429,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
def test_kicked_before_range_and_left_after_to_token(self) -> None:
"""
Room still shows up if we left the room but were kicked before the `from_token`
-so it should show up. See condition "2a)" comments in the
+so it should show up. See condition "1a)" comments in the
`get_sync_room_ids_for_user()` method.
"""
user1_id = self.register_user("user1", "pass")
@@ -474,8 +477,8 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None:
"""
Newly left room should show up. But we're also testing that joining and leaving
-after the `to_token` doesn't mess with the results. See condition "2a)" comments
-in the `get_sync_room_ids_for_user()` method.
+after the `to_token` doesn't mess with the results. See condition "2)" and "1a)"
+comments in the `get_sync_room_ids_for_user()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
@@ -508,10 +511,46 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
# Room should still show up because it's newly_left during the from/to range
self.assertEqual(room_id_results, {room_id1})
def test_newly_left_during_range_and_join_after_to_token(self) -> None:
"""
Newly left room should show up. But we're also testing that joining after the
`to_token` doesn't mess with the results. See condition "2)" and "1b)" comments
in the `get_sync_room_ids_for_user()` method.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
before_room1_token = self.event_sources.get_current_token()
# We create the room with user2 so the room isn't left with no members when we
# leave and can still re-join.
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
# Join and leave the room during the from/to range
self.helper.join(room_id1, user1_id, tok=user1_tok)
self.helper.leave(room_id1, user1_id, tok=user1_tok)
after_room1_token = self.event_sources.get_current_token()
# Join the room after we already have our tokens
self.helper.join(room_id1, user1_id, tok=user1_tok)
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=before_room1_token,
to_token=after_room1_token,
)
)
# Room should still show up because it's newly_left during the from/to range
self.assertEqual(room_id_results, {room_id1})
def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
"""
Old left room shouldn't show up. But we're also testing that joining and leaving
-after the `to_token` doesn't mess with the results. See condition "2a)" comments
+after the `to_token` doesn't mess with the results. See condition "1a)" comments
in the `get_sync_room_ids_for_user()` method.
"""
user1_id = self.register_user("user1", "pass")
@@ -546,7 +585,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
def test_leave_before_range_and_join_after_to_token(self) -> None:
"""
Old left room shouldn't show up. But we're also testing that joining after the
-`to_token` doesn't mess with the results. See condition "2b)" comments in the
+`to_token` doesn't mess with the results. See condition "1b)" comments in the
`get_sync_room_ids_for_user()` method.
"""
user1_id = self.register_user("user1", "pass")
@@ -664,7 +703,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
) -> None:
"""
Make it look like we joined after the token range but we were invited before the
-from/to range so the room should still show up. See condition "2a)" comments in
+from/to range so the room should still show up. See condition "1a)" comments in
the `get_sync_room_ids_for_user()` method.
"""
user1_id = self.register_user("user1", "pass")
@@ -759,6 +798,224 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
)
class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase):
"""
Tests Sliding Sync handler `get_sync_room_ids_for_user()` to make sure it works with
sharded event stream_writers enabled
"""
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
login.register_servlets,
]
def default_config(self) -> dict:
config = super().default_config()
# Enable sliding sync
config["experimental_features"] = {"msc3575_enabled": True}
# Enable shared event stream_writers
config["stream_writers"] = {"events": ["worker1", "worker2", "worker3"]}
config["instance_map"] = {
"main": {"host": "testserv", "port": 8765},
"worker1": {"host": "testserv", "port": 1001},
"worker2": {"host": "testserv", "port": 1002},
"worker3": {"host": "testserv", "port": 1003},
}
return config
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
self.store = self.hs.get_datastores().main
self.event_sources = hs.get_event_sources()
def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
"""
Create a room with a specific room_id. We use this so that we have a
consistent room_id across test runs that hashes to the same value and will be
sharded to a known worker in the tests.
"""
# We control the room ID generation by patching out the
# `_generate_room_id` method
with patch(
"synapse.handlers.room.RoomCreationHandler._generate_room_id"
) as mock:
mock.side_effect = lambda: room_id
self.helper.create_room_as(user_id, tok=tok)
def test_sharded_event_persisters(self) -> None:
"""
This test should catch bugs that would come from flawed stream position
(`stream_ordering`) comparisons or making `RoomStreamToken`'s naively. To
compare event positions properly, you need to consider both the `instance_name`
and `stream_ordering` together.
The test creates three event persister workers and a room that is sharded to
each worker. On worker2, we make the event stream position stuck so that it lags
behind the other workers and we start getting `RoomStreamToken`s that have an
`instance_map` component (i.e. `m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`).
We then send some events to advance the stream positions of worker1 and worker3
but worker2 is lagging behind because it's stuck. We are specifically testing
that `get_sync_room_ids_for_user(from_token=xxx, to_token=xxx)` should work
correctly in these adverse conditions.
"""
user1_id = self.register_user("user1", "pass")
user1_tok = self.login(user1_id, "pass")
user2_id = self.register_user("user2", "pass")
user2_tok = self.login(user2_id, "pass")
self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "worker1"},
)
worker_hs2 = self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "worker2"},
)
self.make_worker_hs(
"synapse.app.generic_worker",
{"worker_name": "worker3"},
)
# Specially crafted room IDs that get persisted on different workers.
#
# Sharded to worker1
room_id1 = "!fooo:test"
# Sharded to worker2
room_id2 = "!bar:test"
# Sharded to worker3
room_id3 = "!quux:test"
# Create rooms on the different workers.
self._create_room(room_id1, user2_id, user2_tok)
self._create_room(room_id2, user2_id, user2_tok)
self._create_room(room_id3, user2_id, user2_tok)
join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
# Leave room2
self.helper.leave(room_id2, user1_id, tok=user1_tok)
join_response3 = self.helper.join(room_id3, user1_id, tok=user1_tok)
# Leave room3
self.helper.leave(room_id3, user1_id, tok=user1_tok)
# Ensure that the events were sharded to different workers.
pos1 = self.get_success(
self.store.get_position_for_event(join_response1["event_id"])
)
self.assertEqual(pos1.instance_name, "worker1")
pos2 = self.get_success(
self.store.get_position_for_event(join_response2["event_id"])
)
self.assertEqual(pos2.instance_name, "worker2")
pos3 = self.get_success(
self.store.get_position_for_event(join_response3["event_id"])
)
self.assertEqual(pos3.instance_name, "worker3")
before_stuck_activity_token = self.event_sources.get_current_token()
# We now gut wrench into the events stream `MultiWriterIdGenerator` on worker2 to
# mimic it getting stuck persisting an event. This ensures that when we send an
# event on worker1/worker3 we end up in a state where worker2 events stream
# position lags that on worker1/worker3, resulting in a RoomStreamToken with a
# non-empty `instance_map` component.
#
# Worker2's event stream position will not advance until we call `__aexit__`
# below.
worker_store2 = worker_hs2.get_datastores().main
assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
actx = worker_store2._stream_id_gen.get_next()
self.get_success(actx.__aenter__())
# For room_id1/worker1: leave and join the room to advance the stream position
# and generate membership changes.
self.helper.leave(room_id1, user1_id, tok=user1_tok)
self.helper.join(room_id1, user1_id, tok=user1_tok)
# For room_id2/worker2: which is currently stuck, join the room.
join_on_worker2_response = self.helper.join(room_id2, user1_id, tok=user1_tok)
# For room_id3/worker3: leave and join the room to advance the stream position
# and generate membership changes.
self.helper.leave(room_id3, user1_id, tok=user1_tok)
join_on_worker3_response = self.helper.join(room_id3, user1_id, tok=user1_tok)
# Get a token while things are stuck after our activity
stuck_activity_token = self.event_sources.get_current_token()
logger.info("stuck_activity_token %s", stuck_activity_token)
# Let's make sure we're working with a token that has an `instance_map`
self.assertNotEqual(len(stuck_activity_token.room_key.instance_map), 0)
# Just double check that the join event on worker2 (that is stuck) happened
# after the position recorded for worker2 in the token but before the max
# position in the token. This is crucial for the behavior we're trying to test.
join_on_worker2_pos = self.get_success(
self.store.get_position_for_event(join_on_worker2_response["event_id"])
)
logger.info("join_on_worker2_pos %s", join_on_worker2_pos)
# Ensure the join technically came after our token
self.assertGreater(
join_on_worker2_pos.stream,
stuck_activity_token.room_key.get_stream_pos_for_instance("worker2"),
)
# But less than the max stream position of some other worker
self.assertLess(
join_on_worker2_pos.stream,
# max
stuck_activity_token.room_key.get_max_stream_pos(),
)
# Just double check that the join event on worker3 happened after the min stream
# value in the token but still within the position recorded for worker3. This is
# crucial for the behavior we're trying to test.
join_on_worker3_pos = self.get_success(
self.store.get_position_for_event(join_on_worker3_response["event_id"])
)
logger.info("join_on_worker3_pos %s", join_on_worker3_pos)
# Ensure the join came after the min but still encapsulated by the token
self.assertGreaterEqual(
join_on_worker3_pos.stream,
# min
stuck_activity_token.room_key.stream,
)
self.assertLessEqual(
join_on_worker3_pos.stream,
stuck_activity_token.room_key.get_stream_pos_for_instance("worker3"),
)
# We finish the fake persisting an event we started above and advance worker2's
# event stream position (unstuck worker2).
self.get_success(actx.__aexit__(None, None, None))
# The function under test
room_id_results = self.get_success(
self.sliding_sync_handler.get_sync_room_ids_for_user(
UserID.from_string(user1_id),
from_token=before_stuck_activity_token,
to_token=stuck_activity_token,
)
)
self.assertEqual(
room_id_results,
{
room_id1,
# room_id2 shouldn't show up because we left before the from/to range
# and the join event during the range happened while worker2 was stuck.
# This means that from the perspective of the master, where the
# `stuck_activity_token` is generated, the stream position for worker2
# wasn't advanced to the join yet. Looking at the `instance_map`, the
# join technically comes after `stuck_activity_token`.
#
# room_id2,
room_id3,
},
)
class FilterRoomsTestCase(HomeserverTestCase):
"""
Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms

View file

@@ -1061,6 +1061,45 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
{alice: ProfileInfo(display_name=None, avatar_url=MXC_DUMMY)},
)
def test_search_punctuation(self) -> None:
"""Test that you can search for a user that includes punctuation"""
searching_user = self.register_user("searcher", "password")
searching_user_tok = self.login("searcher", "password")
room_id = self.helper.create_room_as(
searching_user,
room_version=RoomVersions.V1.identifier,
tok=searching_user_tok,
)
# We want to test searching for users of the form e.g. "user-1", with
# various punctuation. We also test both where the prefix is numeric and
# alphanumeric, as e.g. postgres tokenises "user-1" as "user" and "-1".
i = 1
for char in ["-", ".", "_"]:
for use_numeric in [False, True]:
if use_numeric:
prefix1 = f"{i}"
prefix2 = f"{i+1}"
else:
prefix1 = f"a{i}"
prefix2 = f"a{i+1}"
local_user_1 = self.register_user(f"user{char}{prefix1}", "password")
local_user_2 = self.register_user(f"user{char}{prefix2}", "password")
self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
results = self.get_success(
self.handler.search_users(searching_user, local_user_1, 20)
)["results"]
received_user_id_ordering = [result["user_id"] for result in results]
self.assertSequenceEqual(received_user_id_ordering[:1], [local_user_1])
i += 2
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
servlets = [

View file

@ -25,7 +25,7 @@ import tempfile
from binascii import unhexlify from binascii import unhexlify
from io import BytesIO from io import BytesIO
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
from unittest.mock import Mock from unittest.mock import MagicMock, Mock, patch
from urllib import parse from urllib import parse
import attr import attr
@ -37,9 +37,12 @@ from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.python.failure import Failure
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.http_headers import Headers
from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
from twisted.web.resource import Resource

from synapse.api.errors import Codes, HttpResponseException
from synapse.api.ratelimiting import Ratelimiter
from synapse.events import EventBase
from synapse.http.types import QueryParams
from synapse.logging.context import make_deferred_yieldable
@ -59,6 +62,7 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
from tests.test_utils import SMALL_PNG
from tests.unittest import override_config
from tests.utils import default_config
@ -251,9 +255,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
        destination: str,
        path: str,
        output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: Any,
max_size: int,
        args: Optional[QueryParams] = None,
        retry_on_dns_fail: bool = True,
max_size: Optional[int] = None,
        ignore_backoff: bool = False,
        follow_redirects: bool = False,
    ) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
@ -878,3 +884,218 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
            tok=self.tok,
            expect_code=400,
        )
class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
os.mkdir(self.storage_path)
os.mkdir(self.media_store_path)
config["media_store_path"] = self.media_store_path
provider_config = {
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
"store_local": True,
"store_synchronous": False,
"store_remote": True,
"config": {"directory": self.storage_path},
}
config["media_storage_providers"] = [provider_config]
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.repo = hs.get_media_repository()
self.client = hs.get_federation_http_client()
self.store = hs.get_datastores().main
def create_resource_dict(self) -> Dict[str, Resource]:
# We need to manually set the resource tree to include media, the
# default only does `/_matrix/client` APIs.
return {"/_matrix/media": self.hs.get_media_repository_resource()}
    # Mock out actually reading the file body: these stand-ins resolve
    # immediately with the number of bytes nominally read, which is what
    # the ratelimiter is charged.
def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred:
d: Deferred = defer.Deferred()
d.callback(31457280)
return d
def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred:
d: Deferred = defer.Deferred()
d.callback(52428800)
return d
@patch(
"synapse.http.matrixfederationclient.read_body_with_max_size",
read_body_with_max_size_30MiB,
)
def test_download_ratelimit_default(self) -> None:
"""
Test remote media download ratelimiting against default configuration - 500MB bucket
and 87kb/second drain rate
"""
# mock out actually sending the request, returns a 30MiB response
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
resp = MagicMock(spec=IResponse)
resp.code = 200
resp.length = 31457280
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
resp.phrase = b"OK"
return resp
self.client._send_request = _send_request # type: ignore
# first request should go through
channel = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
shorthand=False,
)
assert channel.code == 200
        # the next 15 should also go through (16 x 30MiB fits in the 500MiB bucket)
for i in range(15):
channel2 = self.make_request(
"GET",
f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
shorthand=False,
)
assert channel2.code == 200
        # the 17th will hit the ratelimit (17 x 30MiB exceeds the 500MiB bucket)
channel3 = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
shorthand=False,
)
assert channel3.code == 429
# however, a request from a different IP will go through
channel4 = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
shorthand=False,
client_ip="187.233.230.159",
)
assert channel4.code == 200
        # At 87KiB/s it should take about 2 minutes for enough to drain from the bucket
        # that another 30MiB download is authorized. The last download was blocked with
        # the bucket at 503,316,480 bytes. The next download will be authorized once the
        # bucket drops to 492,830,720 (524,288,000 total capacity - 31,457,280 download
        # size), so 503,316,480 - 492,830,720 = 10,485,760 bytes need to drain, which
        # takes roughly 2 minutes (10,485,760 / 89,088 / 60).
self.reactor.pump([2.0 * 60.0])
# enough has drained and next request goes through
channel5 = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb",
shorthand=False,
)
assert channel5.code == 200
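As a back-of-envelope check of the drain arithmetic in the comment above (plain leaky-bucket maths, not Synapse internals):

```python
CAPACITY = 500 * 1024 * 1024      # 524,288,000 bytes (500MiB burst)
DRAIN_RATE = 87 * 1024            # 89,088 bytes per second
DOWNLOAD = 30 * 1024 * 1024       # 31,457,280 bytes per mocked response

level = 16 * DOWNLOAD                    # 503,316,480: bucket after 16 downloads
headroom_needed = CAPACITY - DOWNLOAD    # 492,830,720: level at which a 17th fits
to_drain = level - headroom_needed       # 10,485,760 bytes must leak out
seconds = to_drain / DRAIN_RATE          # ~117.7s, i.e. roughly 2 minutes
assert 115 < seconds < 120
```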
@override_config(
{
"remote_media_download_per_second": "50M",
"remote_media_download_burst_count": "50M",
}
)
@patch(
"synapse.http.matrixfederationclient.read_body_with_max_size",
read_body_with_max_size_50MiB,
)
def test_download_rate_limit_config(self) -> None:
"""
Test that download rate limit config options are correctly picked up and applied
"""
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
resp = MagicMock(spec=IResponse)
resp.code = 200
resp.length = 52428800
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
resp.phrase = b"OK"
return resp
self.client._send_request = _send_request # type: ignore
# first request should go through
channel = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
shorthand=False,
)
assert channel.code == 200
# immediate second request should fail
channel = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1",
shorthand=False,
)
assert channel.code == 429
# advance half a second
self.reactor.pump([0.5])
# request still fails
channel = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2",
shorthand=False,
)
assert channel.code == 429
# advance another half second
self.reactor.pump([0.5])
# enough has drained from bucket and request is successful
channel = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3",
shorthand=False,
)
assert channel.code == 200
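The timing in this test follows directly from the overridden settings (simple arithmetic under the leaky-bucket model, not a claim about Synapse internals):

```python
RATE = 50 * 1024 * 1024    # bytes per second ("50M" per_second)
BURST = 50 * 1024 * 1024   # bucket capacity in bytes ("50M" burst)
FILL = 52428800            # the mocked 50MiB response

assert FILL == BURST        # one download fills the bucket exactly
assert FILL / RATE == 1.0   # so the bucket fully drains in one second,
                            # matching the two 0.5s pumps above
```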
@patch(
"synapse.http.matrixfederationclient.read_body_with_max_size",
read_body_with_max_size_30MiB,
)
def test_download_ratelimit_max_size_sub(self) -> None:
"""
Test that if no content-length is provided, the default max size is applied instead
"""
# mock out actually sending the request
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
resp = MagicMock(spec=IResponse)
resp.code = 200
resp.length = UNKNOWN_LENGTH
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
resp.phrase = b"OK"
return resp
self.client._send_request = _send_request # type: ignore
        # ten requests should go through, each charged at the max size
        # (500MiB burst / 50MiB max size)
for i in range(10):
channel2 = self.make_request(
"GET",
f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
shorthand=False,
)
assert channel2.code == 200
# eleventh will hit ratelimit
channel3 = self.make_request(
"GET",
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
shorthand=False,
)
assert channel3.code == 429
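A rough check of the counts in this test, under the assumption (per the docstring above) that responses with no content-length are charged at the default max size of 50MiB against the 500MiB bucket:

```python
BURST = 500 * 1024 * 1024   # default burst bucket in bytes
CHARGE = 50 * 1024 * 1024   # assumed per-request charge when length is unknown

assert BURST // CHARGE == 10  # ten requests fit; the eleventh gets a 429
```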


@ -711,6 +711,10 @@ class UserDirectoryICUTestCase(HomeserverTestCase):
            ),
        )
self.assertEqual(_parse_words_with_icu("user-1"), ["user-1"])
self.assertEqual(_parse_words_with_icu("user-ab"), ["user-ab"])
self.assertEqual(_parse_words_with_icu("user.--1"), ["user", "-1"])
    def test_regex_word_boundary_punctuation(self) -> None:
        """
        Tests the behaviour of punctuation with the non-ICU tokeniser