mirror of
https://github.com/element-hq/synapse
synced 2024-10-01 21:32:40 +00:00
Merge branch 'madlittlemods/msc3575-sliding-sync-0.0.1' into madlittlemods/msc3575-sliding-sync-filtering
Conflicts: tests/handlers/test_sliding_sync.py
This commit is contained in:
commit
5078d36bd3
20 changed files with 808 additions and 108 deletions
1
changelog.d/17254.bugfix
Normal file
1
changelog.d/17254.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix searching for users with their exact localpart whose ID includes a hyphen.
|
1
changelog.d/17256.feature
Normal file
1
changelog.d/17256.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Improve ratelimiting in Synapse (#17256).
|
|
@ -1946,6 +1946,24 @@ Example configuration:
|
||||||
max_image_pixels: 35M
|
max_image_pixels: 35M
|
||||||
```
|
```
|
||||||
---
|
---
|
||||||
|
### `remote_media_download_burst_count`
|
||||||
|
|
||||||
|
Remote media downloads are ratelimited using a [leaky bucket algorithm](https://en.wikipedia.org/wiki/Leaky_bucket), where a given "bucket" is keyed to the IP address of the requester when requesting remote media downloads. This configuration option sets the size of the bucket against which the size in bytes of downloads are penalized - if the bucket is full, ie a given number of bytes have already been downloaded, further downloads will be denied until the bucket drains. Defaults to 500MiB. See also `remote_media_download_per_second` which determines the rate at which the "bucket" is emptied and thus has available space to authorize new requests.
|
||||||
|
|
||||||
|
Example configuration:
|
||||||
|
```yaml
|
||||||
|
remote_media_download_burst_count: 200M
|
||||||
|
```
|
||||||
|
---
|
||||||
|
### `remote_media_download_per_second`
|
||||||
|
|
||||||
|
Works in conjunction with `remote_media_download_burst_count` to ratelimit remote media downloads - this configuration option determines the rate at which the "bucket" (see above) leaks in bytes per second. As requests are made to download remote media, the size of those requests in bytes is added to the bucket, and once the bucket has reached it's capacity, no more requests will be allowed until a number of bytes has "drained" from the bucket. This setting determines the rate at which bytes drain from the bucket, with the practical effect that the larger the number, the faster the bucket leaks, allowing for more bytes downloaded over a shorter period of time. Defaults to 87KiB per second. See also `remote_media_download_burst_count`.
|
||||||
|
|
||||||
|
Example configuration:
|
||||||
|
```yaml
|
||||||
|
remote_media_download_per_second: 40K
|
||||||
|
```
|
||||||
|
---
|
||||||
### `prevent_media_downloads_from`
|
### `prevent_media_downloads_from`
|
||||||
|
|
||||||
A list of domains to never download media from. Media from these
|
A list of domains to never download media from. Media from these
|
||||||
|
|
|
@ -50,7 +50,7 @@ class Membership:
|
||||||
KNOCK: Final = "knock"
|
KNOCK: Final = "knock"
|
||||||
LEAVE: Final = "leave"
|
LEAVE: Final = "leave"
|
||||||
BAN: Final = "ban"
|
BAN: Final = "ban"
|
||||||
LIST: Final = (INVITE, JOIN, KNOCK, LEAVE, BAN)
|
LIST: Final = {INVITE, JOIN, KNOCK, LEAVE, BAN}
|
||||||
|
|
||||||
|
|
||||||
class PresenceState:
|
class PresenceState:
|
||||||
|
|
|
@ -218,3 +218,13 @@ class RatelimitConfig(Config):
|
||||||
"rc_media_create",
|
"rc_media_create",
|
||||||
defaults={"per_second": 10, "burst_count": 50},
|
defaults={"per_second": 10, "burst_count": 50},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.remote_media_downloads = RatelimitSettings(
|
||||||
|
key="rc_remote_media_downloads",
|
||||||
|
per_second=self.parse_size(
|
||||||
|
config.get("remote_media_download_per_second", "87K")
|
||||||
|
),
|
||||||
|
burst_count=self.parse_size(
|
||||||
|
config.get("remote_media_download_burst_count", "500M")
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
|
@ -56,6 +56,7 @@ from synapse.api.errors import (
|
||||||
SynapseError,
|
SynapseError,
|
||||||
UnsupportedRoomVersionError,
|
UnsupportedRoomVersionError,
|
||||||
)
|
)
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.api.room_versions import (
|
from synapse.api.room_versions import (
|
||||||
KNOWN_ROOM_VERSIONS,
|
KNOWN_ROOM_VERSIONS,
|
||||||
EventFormatVersions,
|
EventFormatVersions,
|
||||||
|
@ -1877,6 +1878,8 @@ class FederationClient(FederationBase):
|
||||||
output_stream: BinaryIO,
|
output_stream: BinaryIO,
|
||||||
max_size: int,
|
max_size: int,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||||
try:
|
try:
|
||||||
return await self.transport_layer.download_media_v3(
|
return await self.transport_layer.download_media_v3(
|
||||||
|
@ -1885,6 +1888,8 @@ class FederationClient(FederationBase):
|
||||||
output_stream=output_stream,
|
output_stream=output_stream,
|
||||||
max_size=max_size,
|
max_size=max_size,
|
||||||
max_timeout_ms=max_timeout_ms,
|
max_timeout_ms=max_timeout_ms,
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
# If an error is received that is due to an unrecognised endpoint,
|
# If an error is received that is due to an unrecognised endpoint,
|
||||||
|
@ -1905,6 +1910,8 @@ class FederationClient(FederationBase):
|
||||||
output_stream=output_stream,
|
output_stream=output_stream,
|
||||||
max_size=max_size,
|
max_size=max_size,
|
||||||
max_timeout_ms=max_timeout_ms,
|
max_timeout_ms=max_timeout_ms,
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -43,6 +43,7 @@ import ijson
|
||||||
|
|
||||||
from synapse.api.constants import Direction, Membership
|
from synapse.api.constants import Direction, Membership
|
||||||
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
from synapse.api.errors import Codes, HttpResponseException, SynapseError
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.api.room_versions import RoomVersion
|
from synapse.api.room_versions import RoomVersion
|
||||||
from synapse.api.urls import (
|
from synapse.api.urls import (
|
||||||
FEDERATION_UNSTABLE_PREFIX,
|
FEDERATION_UNSTABLE_PREFIX,
|
||||||
|
@ -819,6 +820,8 @@ class TransportLayerClient:
|
||||||
output_stream: BinaryIO,
|
output_stream: BinaryIO,
|
||||||
max_size: int,
|
max_size: int,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||||
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
|
path = f"/_matrix/media/r0/download/{destination}/{media_id}"
|
||||||
|
|
||||||
|
@ -834,6 +837,8 @@ class TransportLayerClient:
|
||||||
"allow_remote": "false",
|
"allow_remote": "false",
|
||||||
"timeout_ms": str(max_timeout_ms),
|
"timeout_ms": str(max_timeout_ms),
|
||||||
},
|
},
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def download_media_v3(
|
async def download_media_v3(
|
||||||
|
@ -843,6 +848,8 @@ class TransportLayerClient:
|
||||||
output_stream: BinaryIO,
|
output_stream: BinaryIO,
|
||||||
max_size: int,
|
max_size: int,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||||
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
|
path = f"/_matrix/media/v3/download/{destination}/{media_id}"
|
||||||
|
|
||||||
|
@ -862,6 +869,8 @@ class TransportLayerClient:
|
||||||
"allow_redirect": "true",
|
"allow_redirect": "true",
|
||||||
},
|
},
|
||||||
follow_redirects=True,
|
follow_redirects=True,
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -42,17 +42,6 @@ if TYPE_CHECKING:
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# Everything except `Membership.LEAVE` because we want everything that's *still*
|
|
||||||
# relevant to the user. There are few more things to include in the sync response
|
|
||||||
# (kicks, newly_left) but those are handled separately.
|
|
||||||
MEMBERSHIP_TO_DISPLAY_IN_SYNC = (
|
|
||||||
Membership.INVITE,
|
|
||||||
Membership.JOIN,
|
|
||||||
Membership.KNOCK,
|
|
||||||
Membership.BAN,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> bool:
|
def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) -> bool:
|
||||||
"""
|
"""
|
||||||
Returns True if the membership event should be included in the sync response,
|
Returns True if the membership event should be included in the sync response,
|
||||||
|
@ -65,7 +54,10 @@ def filter_membership_for_sync(*, membership: str, user_id: str, sender: str) ->
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return (
|
return (
|
||||||
membership in MEMBERSHIP_TO_DISPLAY_IN_SYNC
|
# Everything except `Membership.LEAVE` because we want everything that's *still*
|
||||||
|
# relevant to the user. There are few more things to include in the sync response
|
||||||
|
# (newly_left) but those are handled separately.
|
||||||
|
membership in (Membership.LIST - {Membership.LEAVE})
|
||||||
# Include kicks
|
# Include kicks
|
||||||
or (membership == Membership.LEAVE and sender != user_id)
|
or (membership == Membership.LEAVE and sender != user_id)
|
||||||
)
|
)
|
||||||
|
@ -233,7 +225,6 @@ class SlidingSyncResult:
|
||||||
|
|
||||||
class SlidingSyncHandler:
|
class SlidingSyncHandler:
|
||||||
def __init__(self, hs: "HomeServer"):
|
def __init__(self, hs: "HomeServer"):
|
||||||
self.hs_config = hs.config
|
|
||||||
self.clock = hs.get_clock()
|
self.clock = hs.get_clock()
|
||||||
self.store = hs.get_datastores().main
|
self.store = hs.get_datastores().main
|
||||||
self.auth_blocking = hs.get_auth_blocking()
|
self.auth_blocking = hs.get_auth_blocking()
|
||||||
|
@ -385,13 +376,23 @@ class SlidingSyncHandler:
|
||||||
Fetch room IDs that should be listed for this user in the sync response (the
|
Fetch room IDs that should be listed for this user in the sync response (the
|
||||||
full room list that will be filtered, sorted, and sliced).
|
full room list that will be filtered, sorted, and sliced).
|
||||||
|
|
||||||
We're looking for rooms that the user has not left (`invite`, `knock`, `join`,
|
We're looking for rooms where the user has the following state in the token
|
||||||
and `ban`), or kicks (`leave` where the `sender` is different from the
|
range (> `from_token` and <= `to_token`):
|
||||||
`state_key`), or newly_left rooms that are > `from_token` and <= `to_token`.
|
|
||||||
|
- `invite`, `join`, `knock`, `ban` membership events
|
||||||
|
- Kicks (`leave` membership events where `sender` is different from the
|
||||||
|
`user_id`/`state_key`)
|
||||||
|
- `newly_left` (rooms that were left during the given token range)
|
||||||
|
- In order for bans/kicks to not show up in sync, you need to `/forget` those
|
||||||
|
rooms. This doesn't modify the event itself though and only adds the
|
||||||
|
`forgotten` flag to the `room_memberships` table in Synapse. There isn't a way
|
||||||
|
to tell when a room was forgotten at the moment so we can't factor it into the
|
||||||
|
from/to range.
|
||||||
"""
|
"""
|
||||||
user_id = user.to_string()
|
user_id = user.to_string()
|
||||||
|
|
||||||
# First grab a current snapshot rooms for the user
|
# First grab a current snapshot rooms for the user
|
||||||
|
# (also handles forgotten rooms)
|
||||||
room_for_user_list = await self.store.get_rooms_for_local_user_where_membership_is(
|
room_for_user_list = await self.store.get_rooms_for_local_user_where_membership_is(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
# We want to fetch any kind of membership (joined and left rooms) in order
|
# We want to fetch any kind of membership (joined and left rooms) in order
|
||||||
|
@ -441,10 +442,7 @@ class SlidingSyncHandler:
|
||||||
# Then assemble the `RoomStreamToken`
|
# Then assemble the `RoomStreamToken`
|
||||||
membership_snapshot_token = RoomStreamToken(
|
membership_snapshot_token = RoomStreamToken(
|
||||||
# Minimum position in the `instance_map`
|
# Minimum position in the `instance_map`
|
||||||
stream=min(
|
stream=min(instance_to_max_stream_ordering_map.values()),
|
||||||
stream_ordering
|
|
||||||
for stream_ordering in instance_to_max_stream_ordering_map.values()
|
|
||||||
),
|
|
||||||
instance_map=immutabledict(instance_to_max_stream_ordering_map),
|
instance_map=immutabledict(instance_to_max_stream_ordering_map),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -454,26 +452,13 @@ class SlidingSyncHandler:
|
||||||
if membership_snapshot_token.is_before_or_eq(to_token.room_key):
|
if membership_snapshot_token.is_before_or_eq(to_token.room_key):
|
||||||
return sync_room_id_set
|
return sync_room_id_set
|
||||||
|
|
||||||
# We assume the `from_token` is before or at-least equal to the `to_token`
|
|
||||||
assert from_token is None or from_token.room_key.is_before_or_eq(
|
|
||||||
to_token.room_key
|
|
||||||
), f"{from_token.room_key if from_token else None} < {to_token.room_key}"
|
|
||||||
|
|
||||||
# We assume the `from_token`/`to_token` is before the `membership_snapshot_token`
|
|
||||||
assert from_token is None or from_token.room_key.is_before_or_eq(
|
|
||||||
membership_snapshot_token
|
|
||||||
), f"{from_token.room_key if from_token else None} < {membership_snapshot_token}"
|
|
||||||
assert to_token.room_key.is_before_or_eq(
|
|
||||||
membership_snapshot_token
|
|
||||||
), f"{to_token.room_key} < {membership_snapshot_token}"
|
|
||||||
|
|
||||||
# Since we fetched the users room list at some point in time after the from/to
|
# Since we fetched the users room list at some point in time after the from/to
|
||||||
# tokens, we need to revert/rewind some membership changes to match the point in
|
# tokens, we need to revert/rewind some membership changes to match the point in
|
||||||
# time of the `to_token`. In particular, we need to make these fixups:
|
# time of the `to_token`. In particular, we need to make these fixups:
|
||||||
#
|
#
|
||||||
# - 1) Add back newly_left rooms (> `from_token` and <= `to_token`)
|
# - 1a) Remove rooms that the user joined after the `to_token`
|
||||||
# - 2a) Remove rooms that the user joined after the `to_token`
|
# - 1b) Add back rooms that the user left after the `to_token`
|
||||||
# - 2b) Add back rooms that the user left after the `to_token`
|
# - 2) Add back newly_left rooms (> `from_token` and <= `to_token`)
|
||||||
#
|
#
|
||||||
# Below, we're doing two separate lookups for membership changes. We could
|
# Below, we're doing two separate lookups for membership changes. We could
|
||||||
# request everything for both fixups in one range, [`from_token.room_key`,
|
# request everything for both fixups in one range, [`from_token.room_key`,
|
||||||
|
@ -484,41 +469,7 @@ class SlidingSyncHandler:
|
||||||
|
|
||||||
# 1) -----------------------------------------------------
|
# 1) -----------------------------------------------------
|
||||||
|
|
||||||
# 1) Fetch membership changes that fall in the range from `from_token` up to `to_token`
|
# 1) Fetch membership changes that fall in the range from `to_token` up to
|
||||||
membership_change_events_in_from_to_range = []
|
|
||||||
if from_token:
|
|
||||||
membership_change_events_in_from_to_range = (
|
|
||||||
await self.store.get_membership_changes_for_user(
|
|
||||||
user_id,
|
|
||||||
from_key=from_token.room_key,
|
|
||||||
to_key=to_token.room_key,
|
|
||||||
excluded_rooms=self.rooms_to_exclude_globally,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 1) Assemble a list of the last membership events in some given ranges. Someone
|
|
||||||
# could have left and joined multiple times during the given range but we only
|
|
||||||
# care about end-result so we grab the last one.
|
|
||||||
last_membership_change_by_room_id_in_from_to_range: Dict[str, EventBase] = {}
|
|
||||||
for event in membership_change_events_in_from_to_range:
|
|
||||||
assert event.internal_metadata.stream_ordering
|
|
||||||
last_membership_change_by_room_id_in_from_to_range[event.room_id] = event
|
|
||||||
|
|
||||||
# 1) Fixup
|
|
||||||
for (
|
|
||||||
last_membership_change_in_from_to_range
|
|
||||||
) in last_membership_change_by_room_id_in_from_to_range.values():
|
|
||||||
room_id = last_membership_change_in_from_to_range.room_id
|
|
||||||
|
|
||||||
# 1) Add back newly_left rooms (> `from_token` and <= `to_token`). We
|
|
||||||
# include newly_left rooms because the last event that the user should see
|
|
||||||
# is their own leave event
|
|
||||||
if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
|
|
||||||
sync_room_id_set.add(room_id)
|
|
||||||
|
|
||||||
# 2) -----------------------------------------------------
|
|
||||||
|
|
||||||
# 2) Fetch membership changes that fall in the range from `to_token` up to
|
|
||||||
# `membership_snapshot_token`
|
# `membership_snapshot_token`
|
||||||
membership_change_events_after_to_token = (
|
membership_change_events_after_to_token = (
|
||||||
await self.store.get_membership_changes_for_user(
|
await self.store.get_membership_changes_for_user(
|
||||||
|
@ -529,7 +480,7 @@ class SlidingSyncHandler:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2) Assemble a list of the last membership events in some given ranges. Someone
|
# 1) Assemble a list of the last membership events in some given ranges. Someone
|
||||||
# could have left and joined multiple times during the given range but we only
|
# could have left and joined multiple times during the given range but we only
|
||||||
# care about end-result so we grab the last one.
|
# care about end-result so we grab the last one.
|
||||||
last_membership_change_by_room_id_after_to_token: Dict[str, EventBase] = {}
|
last_membership_change_by_room_id_after_to_token: Dict[str, EventBase] = {}
|
||||||
|
@ -537,15 +488,13 @@ class SlidingSyncHandler:
|
||||||
# backward to the previous membership that would apply to the from/to range.
|
# backward to the previous membership that would apply to the from/to range.
|
||||||
first_membership_change_by_room_id_after_to_token: Dict[str, EventBase] = {}
|
first_membership_change_by_room_id_after_to_token: Dict[str, EventBase] = {}
|
||||||
for event in membership_change_events_after_to_token:
|
for event in membership_change_events_after_to_token:
|
||||||
assert event.internal_metadata.stream_ordering
|
|
||||||
|
|
||||||
last_membership_change_by_room_id_after_to_token[event.room_id] = event
|
last_membership_change_by_room_id_after_to_token[event.room_id] = event
|
||||||
# Only set if we haven't already set it
|
# Only set if we haven't already set it
|
||||||
first_membership_change_by_room_id_after_to_token.setdefault(
|
first_membership_change_by_room_id_after_to_token.setdefault(
|
||||||
event.room_id, event
|
event.room_id, event
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2) Fixup
|
# 1) Fixup
|
||||||
for (
|
for (
|
||||||
last_membership_change_after_to_token
|
last_membership_change_after_to_token
|
||||||
) in last_membership_change_by_room_id_after_to_token.values():
|
) in last_membership_change_by_room_id_after_to_token.values():
|
||||||
|
@ -602,7 +551,7 @@ class SlidingSyncHandler:
|
||||||
sender=last_membership_change_after_to_token.sender,
|
sender=last_membership_change_after_to_token.sender,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2a) Add back rooms that the user left after the `to_token`
|
# 1a) Add back rooms that the user left after the `to_token`
|
||||||
#
|
#
|
||||||
# For example, if the last membership event after the `to_token` is a leave
|
# For example, if the last membership event after the `to_token` is a leave
|
||||||
# event, then the room was excluded from `sync_room_id_set` when we first
|
# event, then the room was excluded from `sync_room_id_set` when we first
|
||||||
|
@ -613,7 +562,7 @@ class SlidingSyncHandler:
|
||||||
and should_prev_membership_be_included
|
and should_prev_membership_be_included
|
||||||
):
|
):
|
||||||
sync_room_id_set.add(room_id)
|
sync_room_id_set.add(room_id)
|
||||||
# 2b) Remove rooms that the user joined (hasn't left) after the `to_token`
|
# 1b) Remove rooms that the user joined (hasn't left) after the `to_token`
|
||||||
#
|
#
|
||||||
# For example, if the last membership event after the `to_token` is a "join"
|
# For example, if the last membership event after the `to_token` is a "join"
|
||||||
# event, then the room was included `sync_room_id_set` when we first crafted
|
# event, then the room was included `sync_room_id_set` when we first crafted
|
||||||
|
@ -625,6 +574,41 @@ class SlidingSyncHandler:
|
||||||
):
|
):
|
||||||
sync_room_id_set.discard(room_id)
|
sync_room_id_set.discard(room_id)
|
||||||
|
|
||||||
|
# 2) -----------------------------------------------------
|
||||||
|
# We fix-up newly_left rooms after the first fixup because it may have removed
|
||||||
|
# some left rooms that we can figure out our newly_left in the following code
|
||||||
|
|
||||||
|
# 2) Fetch membership changes that fall in the range from `from_token` up to `to_token`
|
||||||
|
membership_change_events_in_from_to_range = []
|
||||||
|
if from_token:
|
||||||
|
membership_change_events_in_from_to_range = (
|
||||||
|
await self.store.get_membership_changes_for_user(
|
||||||
|
user_id,
|
||||||
|
from_key=from_token.room_key,
|
||||||
|
to_key=to_token.room_key,
|
||||||
|
excluded_rooms=self.rooms_to_exclude_globally,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2) Assemble a list of the last membership events in some given ranges. Someone
|
||||||
|
# could have left and joined multiple times during the given range but we only
|
||||||
|
# care about end-result so we grab the last one.
|
||||||
|
last_membership_change_by_room_id_in_from_to_range: Dict[str, EventBase] = {}
|
||||||
|
for event in membership_change_events_in_from_to_range:
|
||||||
|
last_membership_change_by_room_id_in_from_to_range[event.room_id] = event
|
||||||
|
|
||||||
|
# 2) Fixup
|
||||||
|
for (
|
||||||
|
last_membership_change_in_from_to_range
|
||||||
|
) in last_membership_change_by_room_id_in_from_to_range.values():
|
||||||
|
room_id = last_membership_change_in_from_to_range.room_id
|
||||||
|
|
||||||
|
# 2) Add back newly_left rooms (> `from_token` and <= `to_token`). We
|
||||||
|
# include newly_left rooms because the last event that the user should see
|
||||||
|
# is their own leave event
|
||||||
|
if last_membership_change_in_from_to_range.membership == Membership.LEAVE:
|
||||||
|
sync_room_id_set.add(room_id)
|
||||||
|
|
||||||
return sync_room_id_set
|
return sync_room_id_set
|
||||||
|
|
||||||
async def filter_rooms(
|
async def filter_rooms(
|
||||||
|
|
|
@ -57,7 +57,7 @@ from twisted.internet.interfaces import IReactorTime
|
||||||
from twisted.internet.task import Cooperator
|
from twisted.internet.task import Cooperator
|
||||||
from twisted.web.client import ResponseFailed
|
from twisted.web.client import ResponseFailed
|
||||||
from twisted.web.http_headers import Headers
|
from twisted.web.http_headers import Headers
|
||||||
from twisted.web.iweb import IAgent, IBodyProducer, IResponse
|
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
|
||||||
|
|
||||||
import synapse.metrics
|
import synapse.metrics
|
||||||
import synapse.util.retryutils
|
import synapse.util.retryutils
|
||||||
|
@ -68,6 +68,7 @@ from synapse.api.errors import (
|
||||||
RequestSendFailed,
|
RequestSendFailed,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.crypto.context_factory import FederationPolicyForHTTPS
|
from synapse.crypto.context_factory import FederationPolicyForHTTPS
|
||||||
from synapse.http import QuieterFileBodyProducer
|
from synapse.http import QuieterFileBodyProducer
|
||||||
from synapse.http.client import (
|
from synapse.http.client import (
|
||||||
|
@ -1411,9 +1412,11 @@ class MatrixFederationHttpClient:
|
||||||
destination: str,
|
destination: str,
|
||||||
path: str,
|
path: str,
|
||||||
output_stream: BinaryIO,
|
output_stream: BinaryIO,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
|
max_size: int,
|
||||||
args: Optional[QueryParams] = None,
|
args: Optional[QueryParams] = None,
|
||||||
retry_on_dns_fail: bool = True,
|
retry_on_dns_fail: bool = True,
|
||||||
max_size: Optional[int] = None,
|
|
||||||
ignore_backoff: bool = False,
|
ignore_backoff: bool = False,
|
||||||
follow_redirects: bool = False,
|
follow_redirects: bool = False,
|
||||||
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
) -> Tuple[int, Dict[bytes, List[bytes]]]:
|
||||||
|
@ -1422,6 +1425,10 @@ class MatrixFederationHttpClient:
|
||||||
destination: The remote server to send the HTTP request to.
|
destination: The remote server to send the HTTP request to.
|
||||||
path: The HTTP path to GET.
|
path: The HTTP path to GET.
|
||||||
output_stream: File to write the response body to.
|
output_stream: File to write the response body to.
|
||||||
|
download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
|
||||||
|
requester IP
|
||||||
|
ip_address: IP address of the requester
|
||||||
|
max_size: maximum allowable size in bytes of the file
|
||||||
args: Optional dictionary used to create the query string.
|
args: Optional dictionary used to create the query string.
|
||||||
ignore_backoff: true to ignore the historical backoff data
|
ignore_backoff: true to ignore the historical backoff data
|
||||||
and try the request anyway.
|
and try the request anyway.
|
||||||
|
@ -1441,11 +1448,27 @@ class MatrixFederationHttpClient:
|
||||||
federation whitelist
|
federation whitelist
|
||||||
RequestSendFailed: If there were problems connecting to the
|
RequestSendFailed: If there were problems connecting to the
|
||||||
remote, due to e.g. DNS failures, connection timeouts etc.
|
remote, due to e.g. DNS failures, connection timeouts etc.
|
||||||
|
SynapseError: If the requested file exceeds ratelimits
|
||||||
"""
|
"""
|
||||||
request = MatrixFederationRequest(
|
request = MatrixFederationRequest(
|
||||||
method="GET", destination=destination, path=path, query=args
|
method="GET", destination=destination, path=path, query=args
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# check for a minimum balance of 1MiB in ratelimiter before initiating request
|
||||||
|
send_req, _ = await download_ratelimiter.can_do_action(
|
||||||
|
requester=None, key=ip_address, n_actions=1048576, update=False
|
||||||
|
)
|
||||||
|
|
||||||
|
if not send_req:
|
||||||
|
msg = "Requested file size exceeds ratelimits"
|
||||||
|
logger.warning(
|
||||||
|
"{%s} [%s] %s",
|
||||||
|
request.txn_id,
|
||||||
|
request.destination,
|
||||||
|
msg,
|
||||||
|
)
|
||||||
|
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
|
||||||
|
|
||||||
response = await self._send_request(
|
response = await self._send_request(
|
||||||
request,
|
request,
|
||||||
retry_on_dns_fail=retry_on_dns_fail,
|
retry_on_dns_fail=retry_on_dns_fail,
|
||||||
|
@ -1455,12 +1478,36 @@ class MatrixFederationHttpClient:
|
||||||
|
|
||||||
headers = dict(response.headers.getAllRawHeaders())
|
headers = dict(response.headers.getAllRawHeaders())
|
||||||
|
|
||||||
|
expected_size = response.length
|
||||||
|
# if we don't get an expected length then use the max length
|
||||||
|
if expected_size == UNKNOWN_LENGTH:
|
||||||
|
expected_size = max_size
|
||||||
|
logger.debug(
|
||||||
|
f"File size unknown, assuming file is max allowable size: {max_size}"
|
||||||
|
)
|
||||||
|
|
||||||
|
read_body, _ = await download_ratelimiter.can_do_action(
|
||||||
|
requester=None,
|
||||||
|
key=ip_address,
|
||||||
|
n_actions=expected_size,
|
||||||
|
)
|
||||||
|
if not read_body:
|
||||||
|
msg = "Requested file size exceeds ratelimits"
|
||||||
|
logger.warning(
|
||||||
|
"{%s} [%s] %s",
|
||||||
|
request.txn_id,
|
||||||
|
request.destination,
|
||||||
|
msg,
|
||||||
|
)
|
||||||
|
raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
d = read_body_with_max_size(response, output_stream, max_size)
|
# add a byte of headroom to max size as function errs at >=
|
||||||
|
d = read_body_with_max_size(response, output_stream, expected_size + 1)
|
||||||
d.addTimeout(self.default_timeout_seconds, self.reactor)
|
d.addTimeout(self.default_timeout_seconds, self.reactor)
|
||||||
length = await make_deferred_yieldable(d)
|
length = await make_deferred_yieldable(d)
|
||||||
except BodyExceededMaxSize:
|
except BodyExceededMaxSize:
|
||||||
msg = "Requested file is too large > %r bytes" % (max_size,)
|
msg = "Requested file is too large > %r bytes" % (expected_size,)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"{%s} [%s] %s",
|
"{%s} [%s] %s",
|
||||||
request.txn_id,
|
request.txn_id,
|
||||||
|
|
|
@ -42,6 +42,7 @@ from synapse.api.errors import (
|
||||||
SynapseError,
|
SynapseError,
|
||||||
cs_error,
|
cs_error,
|
||||||
)
|
)
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.config.repository import ThumbnailRequirement
|
from synapse.config.repository import ThumbnailRequirement
|
||||||
from synapse.http.server import respond_with_json
|
from synapse.http.server import respond_with_json
|
||||||
from synapse.http.site import SynapseRequest
|
from synapse.http.site import SynapseRequest
|
||||||
|
@ -111,6 +112,12 @@ class MediaRepository:
|
||||||
)
|
)
|
||||||
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
|
self.prevent_media_downloads_from = hs.config.media.prevent_media_downloads_from
|
||||||
|
|
||||||
|
self.download_ratelimiter = Ratelimiter(
|
||||||
|
store=hs.get_storage_controllers().main,
|
||||||
|
clock=hs.get_clock(),
|
||||||
|
cfg=hs.config.ratelimiting.remote_media_downloads,
|
||||||
|
)
|
||||||
|
|
||||||
# List of StorageProviders where we should search for media and
|
# List of StorageProviders where we should search for media and
|
||||||
# potentially upload to.
|
# potentially upload to.
|
||||||
storage_providers = []
|
storage_providers = []
|
||||||
|
@ -464,6 +471,7 @@ class MediaRepository:
|
||||||
media_id: str,
|
media_id: str,
|
||||||
name: Optional[str],
|
name: Optional[str],
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
ip_address: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Respond to requests for remote media.
|
"""Respond to requests for remote media.
|
||||||
|
|
||||||
|
@ -475,6 +483,7 @@ class MediaRepository:
|
||||||
the filename in the Content-Disposition header of the response.
|
the filename in the Content-Disposition header of the response.
|
||||||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
media to be uploaded.
|
media to be uploaded.
|
||||||
|
ip_address: the IP address of the requester
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Resolves once a response has successfully been written to request
|
Resolves once a response has successfully been written to request
|
||||||
|
@ -500,7 +509,11 @@ class MediaRepository:
|
||||||
key = (server_name, media_id)
|
key = (server_name, media_id)
|
||||||
async with self.remote_media_linearizer.queue(key):
|
async with self.remote_media_linearizer.queue(key):
|
||||||
responder, media_info = await self._get_remote_media_impl(
|
responder, media_info = await self._get_remote_media_impl(
|
||||||
server_name, media_id, max_timeout_ms
|
server_name,
|
||||||
|
media_id,
|
||||||
|
max_timeout_ms,
|
||||||
|
self.download_ratelimiter,
|
||||||
|
ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We deliberately stream the file outside the lock
|
# We deliberately stream the file outside the lock
|
||||||
|
@ -517,7 +530,7 @@ class MediaRepository:
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
|
|
||||||
async def get_remote_media_info(
|
async def get_remote_media_info(
|
||||||
self, server_name: str, media_id: str, max_timeout_ms: int
|
self, server_name: str, media_id: str, max_timeout_ms: int, ip_address: str
|
||||||
) -> RemoteMedia:
|
) -> RemoteMedia:
|
||||||
"""Gets the media info associated with the remote file, downloading
|
"""Gets the media info associated with the remote file, downloading
|
||||||
if necessary.
|
if necessary.
|
||||||
|
@ -527,6 +540,7 @@ class MediaRepository:
|
||||||
media_id: The media ID of the content (as defined by the remote server).
|
media_id: The media ID of the content (as defined by the remote server).
|
||||||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
media to be uploaded.
|
media to be uploaded.
|
||||||
|
ip_address: IP address of the requester
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The media info of the file
|
The media info of the file
|
||||||
|
@ -542,7 +556,11 @@ class MediaRepository:
|
||||||
key = (server_name, media_id)
|
key = (server_name, media_id)
|
||||||
async with self.remote_media_linearizer.queue(key):
|
async with self.remote_media_linearizer.queue(key):
|
||||||
responder, media_info = await self._get_remote_media_impl(
|
responder, media_info = await self._get_remote_media_impl(
|
||||||
server_name, media_id, max_timeout_ms
|
server_name,
|
||||||
|
media_id,
|
||||||
|
max_timeout_ms,
|
||||||
|
self.download_ratelimiter,
|
||||||
|
ip_address,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Ensure we actually use the responder so that it releases resources
|
# Ensure we actually use the responder so that it releases resources
|
||||||
|
@ -553,7 +571,12 @@ class MediaRepository:
|
||||||
return media_info
|
return media_info
|
||||||
|
|
||||||
async def _get_remote_media_impl(
|
async def _get_remote_media_impl(
|
||||||
self, server_name: str, media_id: str, max_timeout_ms: int
|
self,
|
||||||
|
server_name: str,
|
||||||
|
media_id: str,
|
||||||
|
max_timeout_ms: int,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
) -> Tuple[Optional[Responder], RemoteMedia]:
|
) -> Tuple[Optional[Responder], RemoteMedia]:
|
||||||
"""Looks for media in local cache, if not there then attempt to
|
"""Looks for media in local cache, if not there then attempt to
|
||||||
download from remote server.
|
download from remote server.
|
||||||
|
@ -564,6 +587,9 @@ class MediaRepository:
|
||||||
remote server).
|
remote server).
|
||||||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
media to be uploaded.
|
media to be uploaded.
|
||||||
|
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
|
||||||
|
requester IP.
|
||||||
|
ip_address: the IP address of the requester
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of responder and the media info of the file.
|
A tuple of responder and the media info of the file.
|
||||||
|
@ -596,7 +622,7 @@ class MediaRepository:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
media_info = await self._download_remote_file(
|
media_info = await self._download_remote_file(
|
||||||
server_name, media_id, max_timeout_ms
|
server_name, media_id, max_timeout_ms, download_ratelimiter, ip_address
|
||||||
)
|
)
|
||||||
except SynapseError:
|
except SynapseError:
|
||||||
raise
|
raise
|
||||||
|
@ -630,6 +656,8 @@ class MediaRepository:
|
||||||
server_name: str,
|
server_name: str,
|
||||||
media_id: str,
|
media_id: str,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: str,
|
||||||
) -> RemoteMedia:
|
) -> RemoteMedia:
|
||||||
"""Attempt to download the remote file from the given server name,
|
"""Attempt to download the remote file from the given server name,
|
||||||
using the given file_id as the local id.
|
using the given file_id as the local id.
|
||||||
|
@ -641,6 +669,9 @@ class MediaRepository:
|
||||||
locally generated.
|
locally generated.
|
||||||
max_timeout_ms: the maximum number of milliseconds to wait for the
|
max_timeout_ms: the maximum number of milliseconds to wait for the
|
||||||
media to be uploaded.
|
media to be uploaded.
|
||||||
|
download_ratelimiter: a ratelimiter limiting remote media downloads, keyed to
|
||||||
|
requester IP
|
||||||
|
ip_address: the IP address of the requester
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The media info of the file.
|
The media info of the file.
|
||||||
|
@ -658,6 +689,8 @@ class MediaRepository:
|
||||||
output_stream=f,
|
output_stream=f,
|
||||||
max_size=self.max_upload_size,
|
max_size=self.max_upload_size,
|
||||||
max_timeout_ms=max_timeout_ms,
|
max_timeout_ms=max_timeout_ms,
|
||||||
|
download_ratelimiter=download_ratelimiter,
|
||||||
|
ip_address=ip_address,
|
||||||
)
|
)
|
||||||
except RequestSendFailed as e:
|
except RequestSendFailed as e:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
@ -359,9 +359,10 @@ class ThumbnailProvider:
|
||||||
desired_method: str,
|
desired_method: str,
|
||||||
desired_type: str,
|
desired_type: str,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
ip_address: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
media_info = await self.media_repo.get_remote_media_info(
|
media_info = await self.media_repo.get_remote_media_info(
|
||||||
server_name, media_id, max_timeout_ms
|
server_name, media_id, max_timeout_ms, ip_address
|
||||||
)
|
)
|
||||||
if not media_info:
|
if not media_info:
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
|
@ -422,12 +423,13 @@ class ThumbnailProvider:
|
||||||
method: str,
|
method: str,
|
||||||
m_type: str,
|
m_type: str,
|
||||||
max_timeout_ms: int,
|
max_timeout_ms: int,
|
||||||
|
ip_address: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: Don't download the whole remote file
|
# TODO: Don't download the whole remote file
|
||||||
# We should proxy the thumbnail from the remote server instead of
|
# We should proxy the thumbnail from the remote server instead of
|
||||||
# downloading the remote file and generating our own thumbnails.
|
# downloading the remote file and generating our own thumbnails.
|
||||||
media_info = await self.media_repo.get_remote_media_info(
|
media_info = await self.media_repo.get_remote_media_info(
|
||||||
server_name, media_id, max_timeout_ms
|
server_name, media_id, max_timeout_ms, ip_address
|
||||||
)
|
)
|
||||||
if not media_info:
|
if not media_info:
|
||||||
return
|
return
|
||||||
|
|
|
@ -174,6 +174,7 @@ class UnstableThumbnailResource(RestServlet):
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
ip_address = request.getClientAddress().host
|
||||||
remote_resp_function = (
|
remote_resp_function = (
|
||||||
self.thumbnailer.select_or_generate_remote_thumbnail
|
self.thumbnailer.select_or_generate_remote_thumbnail
|
||||||
if self.dynamic_thumbnails
|
if self.dynamic_thumbnails
|
||||||
|
@ -188,6 +189,7 @@ class UnstableThumbnailResource(RestServlet):
|
||||||
method,
|
method,
|
||||||
m_type,
|
m_type,
|
||||||
max_timeout_ms,
|
max_timeout_ms,
|
||||||
|
ip_address,
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
||||||
|
|
|
@ -292,6 +292,9 @@ class RoomStateEventRestServlet(RestServlet):
|
||||||
try:
|
try:
|
||||||
if event_type == EventTypes.Member:
|
if event_type == EventTypes.Member:
|
||||||
membership = content.get("membership", None)
|
membership = content.get("membership", None)
|
||||||
|
if not isinstance(membership, str):
|
||||||
|
raise SynapseError(400, "Invalid membership (must be a string)")
|
||||||
|
|
||||||
event_id, _ = await self.room_member_handler.update_membership(
|
event_id, _ = await self.room_member_handler.update_membership(
|
||||||
requester,
|
requester,
|
||||||
target=UserID.from_string(state_key),
|
target=UserID.from_string(state_key),
|
||||||
|
|
|
@ -97,6 +97,12 @@ class DownloadResource(RestServlet):
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
ip_address = request.getClientAddress().host
|
||||||
await self.media_repo.get_remote_media(
|
await self.media_repo.get_remote_media(
|
||||||
request, server_name, media_id, file_name, max_timeout_ms
|
request,
|
||||||
|
server_name,
|
||||||
|
media_id,
|
||||||
|
file_name,
|
||||||
|
max_timeout_ms,
|
||||||
|
ip_address,
|
||||||
)
|
)
|
||||||
|
|
|
@ -104,6 +104,7 @@ class ThumbnailResource(RestServlet):
|
||||||
respond_404(request)
|
respond_404(request)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
ip_address = request.getClientAddress().host
|
||||||
remote_resp_function = (
|
remote_resp_function = (
|
||||||
self.thumbnail_provider.select_or_generate_remote_thumbnail
|
self.thumbnail_provider.select_or_generate_remote_thumbnail
|
||||||
if self.dynamic_thumbnails
|
if self.dynamic_thumbnails
|
||||||
|
@ -118,5 +119,6 @@ class ThumbnailResource(RestServlet):
|
||||||
method,
|
method,
|
||||||
m_type,
|
m_type,
|
||||||
max_timeout_ms,
|
max_timeout_ms,
|
||||||
|
ip_address,
|
||||||
)
|
)
|
||||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||||
|
|
|
@ -1281,7 +1281,7 @@ def _parse_words_with_regex(search_term: str) -> List[str]:
|
||||||
Break down search term into words, when we don't have ICU available.
|
Break down search term into words, when we don't have ICU available.
|
||||||
See: `_parse_words`
|
See: `_parse_words`
|
||||||
"""
|
"""
|
||||||
return re.findall(r"([\w\-]+)", search_term, re.UNICODE)
|
return re.findall(r"([\w-]+)", search_term, re.UNICODE)
|
||||||
|
|
||||||
|
|
||||||
def _parse_words_with_icu(search_term: str) -> List[str]:
|
def _parse_words_with_icu(search_term: str) -> List[str]:
|
||||||
|
@ -1303,15 +1303,69 @@ def _parse_words_with_icu(search_term: str) -> List[str]:
|
||||||
if j < 0:
|
if j < 0:
|
||||||
break
|
break
|
||||||
|
|
||||||
result = search_term[i:j]
|
# We want to make sure that we split on `@` and `:` specifically, as
|
||||||
|
# they occur in user IDs.
|
||||||
|
for result in re.split(r"[@:]+", search_term[i:j]):
|
||||||
|
results.append(result.strip())
|
||||||
|
|
||||||
|
i = j
|
||||||
|
|
||||||
|
# libicu will break up words that have punctuation in them, but to handle
|
||||||
|
# cases where user IDs have '-', '.' and '_' in them we want to *not* break
|
||||||
|
# those into words and instead allow the DB to tokenise them how it wants.
|
||||||
|
#
|
||||||
|
# In particular, user-71 in postgres gets tokenised to "user, -71", and this
|
||||||
|
# will not match a query for "user, 71".
|
||||||
|
new_results: List[str] = []
|
||||||
|
i = 0
|
||||||
|
while i < len(results):
|
||||||
|
curr = results[i]
|
||||||
|
|
||||||
|
prev = None
|
||||||
|
next = None
|
||||||
|
if i > 0:
|
||||||
|
prev = results[i - 1]
|
||||||
|
if i + 1 < len(results):
|
||||||
|
next = results[i + 1]
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
# libicu considers spaces and punctuation between words as words, but we don't
|
# libicu considers spaces and punctuation between words as words, but we don't
|
||||||
# want to include those in results as they would result in syntax errors in SQL
|
# want to include those in results as they would result in syntax errors in SQL
|
||||||
# queries (e.g. "foo bar" would result in the search query including "foo & &
|
# queries (e.g. "foo bar" would result in the search query including "foo & &
|
||||||
# bar").
|
# bar").
|
||||||
if len(re.findall(r"([\w\-]+)", result, re.UNICODE)):
|
if not curr:
|
||||||
results.append(result)
|
continue
|
||||||
|
|
||||||
i = j
|
if curr in ["-", ".", "_"]:
|
||||||
|
prefix = ""
|
||||||
|
suffix = ""
|
||||||
|
|
||||||
return results
|
# Check if the next item is a word, and if so use it as the suffix.
|
||||||
|
# We check for if its a word as we don't want to concatenate
|
||||||
|
# multiple punctuation marks.
|
||||||
|
if next is not None and re.match(r"\w", next):
|
||||||
|
suffix = next
|
||||||
|
i += 1 # We're using next, so we skip it in the outer loop.
|
||||||
|
else:
|
||||||
|
# We want to avoid creating terms like "user-", as we should
|
||||||
|
# strip trailing punctuation.
|
||||||
|
continue
|
||||||
|
|
||||||
|
if prev and re.match(r"\w", prev) and new_results:
|
||||||
|
prefix = new_results[-1]
|
||||||
|
new_results.pop()
|
||||||
|
|
||||||
|
# We might not have a prefix here, but that's fine as we want to
|
||||||
|
# ensure that we don't strip preceding punctuation e.g. '-71'
|
||||||
|
# shouldn't be converted to '71'.
|
||||||
|
|
||||||
|
new_results.append(f"{prefix}{curr}{suffix}")
|
||||||
|
continue
|
||||||
|
elif not re.match(r"\w", curr):
|
||||||
|
# Ignore other punctuation
|
||||||
|
continue
|
||||||
|
|
||||||
|
new_results.append(curr)
|
||||||
|
|
||||||
|
return new_results
|
||||||
|
|
|
@ -18,6 +18,7 @@
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
import logging
|
import logging
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
|
||||||
|
@ -26,9 +27,11 @@ from synapse.api.room_versions import RoomVersions
|
||||||
from synapse.rest import admin
|
from synapse.rest import admin
|
||||||
from synapse.rest.client import knock, login, room
|
from synapse.rest.client import knock, login, room
|
||||||
from synapse.server import HomeServer
|
from synapse.server import HomeServer
|
||||||
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||||
from synapse.types import JsonDict, UserID
|
from synapse.types import JsonDict, UserID
|
||||||
from synapse.util import Clock
|
from synapse.util import Clock
|
||||||
|
|
||||||
|
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||||
from tests.unittest import HomeserverTestCase
|
from tests.unittest import HomeserverTestCase
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -309,7 +312,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||||
def test_only_newly_left_rooms_show_up(self) -> None:
|
def test_only_newly_left_rooms_show_up(self) -> None:
|
||||||
"""
|
"""
|
||||||
Test that newly_left rooms still show up in the sync response but rooms that
|
Test that newly_left rooms still show up in the sync response but rooms that
|
||||||
were left before the `from_token` don't show up. See condition "1)" comments in
|
were left before the `from_token` don't show up. See condition "2)" comments in
|
||||||
the `get_sync_room_ids_for_user` method.
|
the `get_sync_room_ids_for_user` method.
|
||||||
"""
|
"""
|
||||||
user1_id = self.register_user("user1", "pass")
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
@ -340,7 +343,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||||
|
|
||||||
def test_no_joins_after_to_token(self) -> None:
|
def test_no_joins_after_to_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
Rooms we join after the `to_token` should *not* show up. See condition "2b)"
|
Rooms we join after the `to_token` should *not* show up. See condition "1b)"
|
||||||
comments in the `get_sync_room_ids_for_user()` method.
|
comments in the `get_sync_room_ids_for_user()` method.
|
||||||
"""
|
"""
|
||||||
user1_id = self.register_user("user1", "pass")
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
@ -369,7 +372,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||||
def test_join_during_range_and_left_room_after_to_token(self) -> None:
|
def test_join_during_range_and_left_room_after_to_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
Room still shows up if we left the room but were joined during the
|
Room still shows up if we left the room but were joined during the
|
||||||
from_token/to_token. See condition "2a)" comments in the
|
from_token/to_token. See condition "1a)" comments in the
|
||||||
`get_sync_room_ids_for_user()` method.
|
`get_sync_room_ids_for_user()` method.
|
||||||
"""
|
"""
|
||||||
user1_id = self.register_user("user1", "pass")
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
@ -399,7 +402,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||||
def test_join_before_range_and_left_room_after_to_token(self) -> None:
|
def test_join_before_range_and_left_room_after_to_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
Room still shows up if we left the room but were joined before the `from_token`
|
Room still shows up if we left the room but were joined before the `from_token`
|
||||||
so it should show up. See condition "2a)" comments in the
|
so it should show up. See condition "1a)" comments in the
|
||||||
`get_sync_room_ids_for_user()` method.
|
`get_sync_room_ids_for_user()` method.
|
||||||
"""
|
"""
|
||||||
user1_id = self.register_user("user1", "pass")
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
@ -426,7 +429,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||||
def test_kicked_before_range_and_left_after_to_token(self) -> None:
|
def test_kicked_before_range_and_left_after_to_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
Room still shows up if we left the room but were kicked before the `from_token`
|
Room still shows up if we left the room but were kicked before the `from_token`
|
||||||
so it should show up. See condition "2a)" comments in the
|
so it should show up. See condition "1a)" comments in the
|
||||||
`get_sync_room_ids_for_user()` method.
|
`get_sync_room_ids_for_user()` method.
|
||||||
"""
|
"""
|
||||||
user1_id = self.register_user("user1", "pass")
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
@ -474,8 +477,8 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||||
def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None:
|
def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
Newly left room should show up. But we're also testing that joining and leaving
|
Newly left room should show up. But we're also testing that joining and leaving
|
||||||
after the `to_token` doesn't mess with the results. See condition "2a)" comments
|
after the `to_token` doesn't mess with the results. See condition "2)" and "1a)"
|
||||||
in the `get_sync_room_ids_for_user()` method.
|
comments in the `get_sync_room_ids_for_user()` method.
|
||||||
"""
|
"""
|
||||||
user1_id = self.register_user("user1", "pass")
|
user1_id = self.register_user("user1", "pass")
|
||||||
user1_tok = self.login(user1_id, "pass")
|
user1_tok = self.login(user1_id, "pass")
|
||||||
|
@ -508,10 +511,46 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||||
# Room should still show up because it's newly_left during the from/to range
|
# Room should still show up because it's newly_left during the from/to range
|
||||||
self.assertEqual(room_id_results, {room_id1})
|
self.assertEqual(room_id_results, {room_id1})
|
||||||
|
|
||||||
|
def test_newly_left_during_range_and_join_after_to_token(self) -> None:
|
||||||
|
"""
|
||||||
|
Newly left room should show up. But we're also testing that joining after the
|
||||||
|
`to_token` doesn't mess with the results. See condition "2)" and "1b)" comments
|
||||||
|
in the `get_sync_room_ids_for_user()` method.
|
||||||
|
"""
|
||||||
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
user1_tok = self.login(user1_id, "pass")
|
||||||
|
user2_id = self.register_user("user2", "pass")
|
||||||
|
user2_tok = self.login(user2_id, "pass")
|
||||||
|
|
||||||
|
before_room1_token = self.event_sources.get_current_token()
|
||||||
|
|
||||||
|
# We create the room with user2 so the room isn't left with no members when we
|
||||||
|
# leave and can still re-join.
|
||||||
|
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True)
|
||||||
|
# Join and leave the room during the from/to range
|
||||||
|
self.helper.join(room_id1, user1_id, tok=user1_tok)
|
||||||
|
self.helper.leave(room_id1, user1_id, tok=user1_tok)
|
||||||
|
|
||||||
|
after_room1_token = self.event_sources.get_current_token()
|
||||||
|
|
||||||
|
# Join the room after we already have our tokens
|
||||||
|
self.helper.join(room_id1, user1_id, tok=user1_tok)
|
||||||
|
|
||||||
|
room_id_results = self.get_success(
|
||||||
|
self.sliding_sync_handler.get_sync_room_ids_for_user(
|
||||||
|
UserID.from_string(user1_id),
|
||||||
|
from_token=before_room1_token,
|
||||||
|
to_token=after_room1_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Room should still show up because it's newly_left during the from/to range
|
||||||
|
self.assertEqual(room_id_results, {room_id1})
|
||||||
|
|
||||||
def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
|
def test_leave_before_range_and_join_leave_after_to_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
Old left room shouldn't show up. But we're also testing that joining and leaving
|
Old left room shouldn't show up. But we're also testing that joining and leaving
|
||||||
after the `to_token` doesn't mess with the results. See condition "2a)" comments
|
after the `to_token` doesn't mess with the results. See condition "1a)" comments
|
||||||
in the `get_sync_room_ids_for_user()` method.
|
in the `get_sync_room_ids_for_user()` method.
|
||||||
"""
|
"""
|
||||||
user1_id = self.register_user("user1", "pass")
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
@ -546,7 +585,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||||
def test_leave_before_range_and_join_after_to_token(self) -> None:
|
def test_leave_before_range_and_join_after_to_token(self) -> None:
|
||||||
"""
|
"""
|
||||||
Old left room shouldn't show up. But we're also testing that joining after the
|
Old left room shouldn't show up. But we're also testing that joining after the
|
||||||
`to_token` doesn't mess with the results. See condition "2b)" comments in the
|
`to_token` doesn't mess with the results. See condition "1b)" comments in the
|
||||||
`get_sync_room_ids_for_user()` method.
|
`get_sync_room_ids_for_user()` method.
|
||||||
"""
|
"""
|
||||||
user1_id = self.register_user("user1", "pass")
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
@ -664,7 +703,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Make it look like we joined after the token range but we were invited before the
|
Make it look like we joined after the token range but we were invited before the
|
||||||
from/to range so the room should still show up. See condition "2a)" comments in
|
from/to range so the room should still show up. See condition "1a)" comments in
|
||||||
the `get_sync_room_ids_for_user()` method.
|
the `get_sync_room_ids_for_user()` method.
|
||||||
"""
|
"""
|
||||||
user1_id = self.register_user("user1", "pass")
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
@ -759,6 +798,224 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||||
|
"""
|
||||||
|
Tests Sliding Sync handler `get_sync_room_ids_for_user()` to make sure it works with
|
||||||
|
sharded event stream_writers enabled
|
||||||
|
"""
|
||||||
|
|
||||||
|
servlets = [
|
||||||
|
admin.register_servlets_for_client_rest_resource,
|
||||||
|
room.register_servlets,
|
||||||
|
login.register_servlets,
|
||||||
|
]
|
||||||
|
|
||||||
|
def default_config(self) -> dict:
|
||||||
|
config = super().default_config()
|
||||||
|
# Enable sliding sync
|
||||||
|
config["experimental_features"] = {"msc3575_enabled": True}
|
||||||
|
|
||||||
|
# Enable shared event stream_writers
|
||||||
|
config["stream_writers"] = {"events": ["worker1", "worker2", "worker3"]}
|
||||||
|
config["instance_map"] = {
|
||||||
|
"main": {"host": "testserv", "port": 8765},
|
||||||
|
"worker1": {"host": "testserv", "port": 1001},
|
||||||
|
"worker2": {"host": "testserv", "port": 1002},
|
||||||
|
"worker3": {"host": "testserv", "port": 1003},
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.sliding_sync_handler = self.hs.get_sliding_sync_handler()
|
||||||
|
self.store = self.hs.get_datastores().main
|
||||||
|
self.event_sources = hs.get_event_sources()
|
||||||
|
|
||||||
|
def _create_room(self, room_id: str, user_id: str, tok: str) -> None:
|
||||||
|
"""
|
||||||
|
Create a room with a specific room_id. We use this so that that we have a
|
||||||
|
consistent room_id across test runs that hashes to the same value and will be
|
||||||
|
sharded to a known worker in the tests.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# We control the room ID generation by patching out the
|
||||||
|
# `_generate_room_id` method
|
||||||
|
with patch(
|
||||||
|
"synapse.handlers.room.RoomCreationHandler._generate_room_id"
|
||||||
|
) as mock:
|
||||||
|
mock.side_effect = lambda: room_id
|
||||||
|
self.helper.create_room_as(user_id, tok=tok)
|
||||||
|
|
||||||
|
def test_sharded_event_persisters(self) -> None:
|
||||||
|
"""
|
||||||
|
This test should catch bugs that would come from flawed stream position
|
||||||
|
(`stream_ordering`) comparisons or making `RoomStreamToken`'s naively. To
|
||||||
|
compare event positions properly, you need to consider both the `instance_name`
|
||||||
|
and `stream_ordering` together.
|
||||||
|
|
||||||
|
The test creates three event persister workers and a room that is sharded to
|
||||||
|
each worker. On worker2, we make the event stream position stuck so that it lags
|
||||||
|
behind the other workers and we start getting `RoomStreamToken` that have an
|
||||||
|
`instance_map` component (i.e. q`m{min_pos}~{writer1}.{pos1}~{writer2}.{pos2}`).
|
||||||
|
|
||||||
|
We then send some events to advance the stream positions of worker1 and worker3
|
||||||
|
but worker2 is lagging behind because it's stuck. We are specifically testing
|
||||||
|
that `get_sync_room_ids_for_user(from_token=xxx, to_token=xxx)` should work
|
||||||
|
correctly in these adverse conditions.
|
||||||
|
"""
|
||||||
|
user1_id = self.register_user("user1", "pass")
|
||||||
|
user1_tok = self.login(user1_id, "pass")
|
||||||
|
user2_id = self.register_user("user2", "pass")
|
||||||
|
user2_tok = self.login(user2_id, "pass")
|
||||||
|
|
||||||
|
self.make_worker_hs(
|
||||||
|
"synapse.app.generic_worker",
|
||||||
|
{"worker_name": "worker1"},
|
||||||
|
)
|
||||||
|
|
||||||
|
worker_hs2 = self.make_worker_hs(
|
||||||
|
"synapse.app.generic_worker",
|
||||||
|
{"worker_name": "worker2"},
|
||||||
|
)
|
||||||
|
|
||||||
|
self.make_worker_hs(
|
||||||
|
"synapse.app.generic_worker",
|
||||||
|
{"worker_name": "worker3"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Specially crafted room IDs that get persisted on different workers.
|
||||||
|
#
|
||||||
|
# Sharded to worker1
|
||||||
|
room_id1 = "!fooo:test"
|
||||||
|
# Sharded to worker2
|
||||||
|
room_id2 = "!bar:test"
|
||||||
|
# Sharded to worker3
|
||||||
|
room_id3 = "!quux:test"
|
||||||
|
|
||||||
|
# Create rooms on the different workers.
|
||||||
|
self._create_room(room_id1, user2_id, user2_tok)
|
||||||
|
self._create_room(room_id2, user2_id, user2_tok)
|
||||||
|
self._create_room(room_id3, user2_id, user2_tok)
|
||||||
|
join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok)
|
||||||
|
join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok)
|
||||||
|
# Leave room2
|
||||||
|
self.helper.leave(room_id2, user1_id, tok=user1_tok)
|
||||||
|
join_response3 = self.helper.join(room_id3, user1_id, tok=user1_tok)
|
||||||
|
# Leave room3
|
||||||
|
self.helper.leave(room_id3, user1_id, tok=user1_tok)
|
||||||
|
|
||||||
|
# Ensure that the events were sharded to different workers.
|
||||||
|
pos1 = self.get_success(
|
||||||
|
self.store.get_position_for_event(join_response1["event_id"])
|
||||||
|
)
|
||||||
|
self.assertEqual(pos1.instance_name, "worker1")
|
||||||
|
pos2 = self.get_success(
|
||||||
|
self.store.get_position_for_event(join_response2["event_id"])
|
||||||
|
)
|
||||||
|
self.assertEqual(pos2.instance_name, "worker2")
|
||||||
|
pos3 = self.get_success(
|
||||||
|
self.store.get_position_for_event(join_response3["event_id"])
|
||||||
|
)
|
||||||
|
self.assertEqual(pos3.instance_name, "worker3")
|
||||||
|
|
||||||
|
before_stuck_activity_token = self.event_sources.get_current_token()
|
||||||
|
|
||||||
|
# We now gut wrench into the events stream `MultiWriterIdGenerator` on worker2 to
|
||||||
|
# mimic it getting stuck persisting an event. This ensures that when we send an
|
||||||
|
# event on worker1/worker3 we end up in a state where worker2 events stream
|
||||||
|
# position lags that on worker1/worker3, resulting in a RoomStreamToken with a
|
||||||
|
# non-empty `instance_map` component.
|
||||||
|
#
|
||||||
|
# Worker2's event stream position will not advance until we call `__aexit__`
|
||||||
|
# again.
|
||||||
|
worker_store2 = worker_hs2.get_datastores().main
|
||||||
|
assert isinstance(worker_store2._stream_id_gen, MultiWriterIdGenerator)
|
||||||
|
actx = worker_store2._stream_id_gen.get_next()
|
||||||
|
self.get_success(actx.__aenter__())
|
||||||
|
|
||||||
|
# For room_id1/worker1: leave and join the room to advance the stream position
|
||||||
|
# and generate membership changes.
|
||||||
|
self.helper.leave(room_id1, user1_id, tok=user1_tok)
|
||||||
|
self.helper.join(room_id1, user1_id, tok=user1_tok)
|
||||||
|
# For room_id2/worker2: which is currently stuck, join the room.
|
||||||
|
join_on_worker2_response = self.helper.join(room_id2, user1_id, tok=user1_tok)
|
||||||
|
# For room_id3/worker3: leave and join the room to advance the stream position
|
||||||
|
# and generate membership changes.
|
||||||
|
self.helper.leave(room_id3, user1_id, tok=user1_tok)
|
||||||
|
join_on_worker3_response = self.helper.join(room_id3, user1_id, tok=user1_tok)
|
||||||
|
|
||||||
|
# Get a token while things are stuck after our activity
|
||||||
|
stuck_activity_token = self.event_sources.get_current_token()
|
||||||
|
logger.info("stuck_activity_token %s", stuck_activity_token)
|
||||||
|
# Let's make sure we're working with a token that has an `instance_map`
|
||||||
|
self.assertNotEqual(len(stuck_activity_token.room_key.instance_map), 0)
|
||||||
|
|
||||||
|
# Just double check that the join event on worker2 (that is stuck) happened
|
||||||
|
# after the position recorded for worker2 in the token but before the max
|
||||||
|
# position in the token. This is crucial for the behavior we're trying to test.
|
||||||
|
join_on_worker2_pos = self.get_success(
|
||||||
|
self.store.get_position_for_event(join_on_worker2_response["event_id"])
|
||||||
|
)
|
||||||
|
logger.info("join_on_worker2_pos %s", join_on_worker2_pos)
|
||||||
|
# Ensure the join technially came after our token
|
||||||
|
self.assertGreater(
|
||||||
|
join_on_worker2_pos.stream,
|
||||||
|
stuck_activity_token.room_key.get_stream_pos_for_instance("worker2"),
|
||||||
|
)
|
||||||
|
# But less than the max stream position of some other worker
|
||||||
|
self.assertLess(
|
||||||
|
join_on_worker2_pos.stream,
|
||||||
|
# max
|
||||||
|
stuck_activity_token.room_key.get_max_stream_pos(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Just double check that the join event on worker3 happened after the min stream
|
||||||
|
# value in the token but still within the position recorded for worker3. This is
|
||||||
|
# crucial for the behavior we're trying to test.
|
||||||
|
join_on_worker3_pos = self.get_success(
|
||||||
|
self.store.get_position_for_event(join_on_worker3_response["event_id"])
|
||||||
|
)
|
||||||
|
logger.info("join_on_worker3_pos %s", join_on_worker3_pos)
|
||||||
|
# Ensure the join came after the min but still encapsulated by the token
|
||||||
|
self.assertGreaterEqual(
|
||||||
|
join_on_worker3_pos.stream,
|
||||||
|
# min
|
||||||
|
stuck_activity_token.room_key.stream,
|
||||||
|
)
|
||||||
|
self.assertLessEqual(
|
||||||
|
join_on_worker3_pos.stream,
|
||||||
|
stuck_activity_token.room_key.get_stream_pos_for_instance("worker3"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# We finish the fake persisting an event we started above and advance worker2's
|
||||||
|
# event stream position (unstuck worker2).
|
||||||
|
self.get_success(actx.__aexit__(None, None, None))
|
||||||
|
|
||||||
|
# The function under test
|
||||||
|
room_id_results = self.get_success(
|
||||||
|
self.sliding_sync_handler.get_sync_room_ids_for_user(
|
||||||
|
UserID.from_string(user1_id),
|
||||||
|
from_token=before_stuck_activity_token,
|
||||||
|
to_token=stuck_activity_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
room_id_results,
|
||||||
|
{
|
||||||
|
room_id1,
|
||||||
|
# room_id2 shouldn't show up because we left before the from/to range
|
||||||
|
# and the join event during the range happened while worker2 was stuck.
|
||||||
|
# This means that from the perspective of the master, where the
|
||||||
|
# `stuck_activity_token` is generated, the stream position for worker2
|
||||||
|
# wasn't advanced to the join yet. Looking at the `instance_map`, the
|
||||||
|
# join technically comes after `stuck_activity_token``.
|
||||||
|
#
|
||||||
|
# room_id2,
|
||||||
|
room_id3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FilterRoomsTestCase(HomeserverTestCase):
|
class FilterRoomsTestCase(HomeserverTestCase):
|
||||||
"""
|
"""
|
||||||
Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms
|
Tests Sliding Sync handler `filter_rooms()` to make sure it includes/excludes rooms
|
||||||
|
|
|
@ -1061,6 +1061,45 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||||
{alice: ProfileInfo(display_name=None, avatar_url=MXC_DUMMY)},
|
{alice: ProfileInfo(display_name=None, avatar_url=MXC_DUMMY)},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_search_punctuation(self) -> None:
|
||||||
|
"""Test that you can search for a user that includes punctuation"""
|
||||||
|
|
||||||
|
searching_user = self.register_user("searcher", "password")
|
||||||
|
searching_user_tok = self.login("searcher", "password")
|
||||||
|
|
||||||
|
room_id = self.helper.create_room_as(
|
||||||
|
searching_user,
|
||||||
|
room_version=RoomVersions.V1.identifier,
|
||||||
|
tok=searching_user_tok,
|
||||||
|
)
|
||||||
|
|
||||||
|
# We want to test searching for users of the form e.g. "user-1", with
|
||||||
|
# various punctuation. We also test both where the prefix is numeric and
|
||||||
|
# alphanumeric, as e.g. postgres tokenises "user-1" as "user" and "-1".
|
||||||
|
i = 1
|
||||||
|
for char in ["-", ".", "_"]:
|
||||||
|
for use_numeric in [False, True]:
|
||||||
|
if use_numeric:
|
||||||
|
prefix1 = f"{i}"
|
||||||
|
prefix2 = f"{i+1}"
|
||||||
|
else:
|
||||||
|
prefix1 = f"a{i}"
|
||||||
|
prefix2 = f"a{i+1}"
|
||||||
|
|
||||||
|
local_user_1 = self.register_user(f"user{char}{prefix1}", "password")
|
||||||
|
local_user_2 = self.register_user(f"user{char}{prefix2}", "password")
|
||||||
|
|
||||||
|
self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
|
||||||
|
self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
|
||||||
|
|
||||||
|
results = self.get_success(
|
||||||
|
self.handler.search_users(searching_user, local_user_1, 20)
|
||||||
|
)["results"]
|
||||||
|
received_user_id_ordering = [result["user_id"] for result in results]
|
||||||
|
self.assertSequenceEqual(received_user_id_ordering[:1], [local_user_1])
|
||||||
|
|
||||||
|
i += 2
|
||||||
|
|
||||||
|
|
||||||
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
|
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
|
||||||
servlets = [
|
servlets = [
|
||||||
|
|
|
@ -25,7 +25,7 @@ import tempfile
|
||||||
from binascii import unhexlify
|
from binascii import unhexlify
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
|
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Tuple, Union
|
||||||
from unittest.mock import Mock
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
|
@ -37,9 +37,12 @@ from twisted.internet import defer
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred
|
||||||
from twisted.python.failure import Failure
|
from twisted.python.failure import Failure
|
||||||
from twisted.test.proto_helpers import MemoryReactor
|
from twisted.test.proto_helpers import MemoryReactor
|
||||||
|
from twisted.web.http_headers import Headers
|
||||||
|
from twisted.web.iweb import UNKNOWN_LENGTH, IResponse
|
||||||
from twisted.web.resource import Resource
|
from twisted.web.resource import Resource
|
||||||
|
|
||||||
from synapse.api.errors import Codes, HttpResponseException
|
from synapse.api.errors import Codes, HttpResponseException
|
||||||
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.http.types import QueryParams
|
from synapse.http.types import QueryParams
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
|
@ -59,6 +62,7 @@ from synapse.util import Clock
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
from tests.server import FakeChannel
|
from tests.server import FakeChannel
|
||||||
from tests.test_utils import SMALL_PNG
|
from tests.test_utils import SMALL_PNG
|
||||||
|
from tests.unittest import override_config
|
||||||
from tests.utils import default_config
|
from tests.utils import default_config
|
||||||
|
|
||||||
|
|
||||||
|
@ -251,9 +255,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||||
destination: str,
|
destination: str,
|
||||||
path: str,
|
path: str,
|
||||||
output_stream: BinaryIO,
|
output_stream: BinaryIO,
|
||||||
|
download_ratelimiter: Ratelimiter,
|
||||||
|
ip_address: Any,
|
||||||
|
max_size: int,
|
||||||
args: Optional[QueryParams] = None,
|
args: Optional[QueryParams] = None,
|
||||||
retry_on_dns_fail: bool = True,
|
retry_on_dns_fail: bool = True,
|
||||||
max_size: Optional[int] = None,
|
|
||||||
ignore_backoff: bool = False,
|
ignore_backoff: bool = False,
|
||||||
follow_redirects: bool = False,
|
follow_redirects: bool = False,
|
||||||
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
|
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
|
||||||
|
@ -878,3 +884,218 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
|
||||||
tok=self.tok,
|
tok=self.tok,
|
||||||
expect_code=400,
|
expect_code=400,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase):
|
||||||
|
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||||
|
config = self.default_config()
|
||||||
|
|
||||||
|
self.storage_path = self.mktemp()
|
||||||
|
self.media_store_path = self.mktemp()
|
||||||
|
os.mkdir(self.storage_path)
|
||||||
|
os.mkdir(self.media_store_path)
|
||||||
|
config["media_store_path"] = self.media_store_path
|
||||||
|
|
||||||
|
provider_config = {
|
||||||
|
"module": "synapse.media.storage_provider.FileStorageProviderBackend",
|
||||||
|
"store_local": True,
|
||||||
|
"store_synchronous": False,
|
||||||
|
"store_remote": True,
|
||||||
|
"config": {"directory": self.storage_path},
|
||||||
|
}
|
||||||
|
|
||||||
|
config["media_storage_providers"] = [provider_config]
|
||||||
|
|
||||||
|
return self.setup_test_homeserver(config=config)
|
||||||
|
|
||||||
|
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||||
|
self.repo = hs.get_media_repository()
|
||||||
|
self.client = hs.get_federation_http_client()
|
||||||
|
self.store = hs.get_datastores().main
|
||||||
|
|
||||||
|
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||||
|
# We need to manually set the resource tree to include media, the
|
||||||
|
# default only does `/_matrix/client` APIs.
|
||||||
|
return {"/_matrix/media": self.hs.get_media_repository_resource()}
|
||||||
|
|
||||||
|
# mock actually reading file body
|
||||||
|
def read_body_with_max_size_30MiB(*args: Any, **kwargs: Any) -> Deferred:
|
||||||
|
d: Deferred = defer.Deferred()
|
||||||
|
d.callback(31457280)
|
||||||
|
return d
|
||||||
|
|
||||||
|
def read_body_with_max_size_50MiB(*args: Any, **kwargs: Any) -> Deferred:
|
||||||
|
d: Deferred = defer.Deferred()
|
||||||
|
d.callback(52428800)
|
||||||
|
return d
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"synapse.http.matrixfederationclient.read_body_with_max_size",
|
||||||
|
read_body_with_max_size_30MiB,
|
||||||
|
)
|
||||||
|
def test_download_ratelimit_default(self) -> None:
|
||||||
|
"""
|
||||||
|
Test remote media download ratelimiting against default configuration - 500MB bucket
|
||||||
|
and 87kb/second drain rate
|
||||||
|
"""
|
||||||
|
|
||||||
|
# mock out actually sending the request, returns a 30MiB response
|
||||||
|
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
|
||||||
|
resp = MagicMock(spec=IResponse)
|
||||||
|
resp.code = 200
|
||||||
|
resp.length = 31457280
|
||||||
|
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
|
||||||
|
resp.phrase = b"OK"
|
||||||
|
return resp
|
||||||
|
|
||||||
|
self.client._send_request = _send_request # type: ignore
|
||||||
|
|
||||||
|
# first request should go through
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
assert channel.code == 200
|
||||||
|
|
||||||
|
# next 15 should go through
|
||||||
|
for i in range(15):
|
||||||
|
channel2 = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
assert channel2.code == 200
|
||||||
|
|
||||||
|
# 17th will hit ratelimit
|
||||||
|
channel3 = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
assert channel3.code == 429
|
||||||
|
|
||||||
|
# however, a request from a different IP will go through
|
||||||
|
channel4 = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
|
||||||
|
shorthand=False,
|
||||||
|
client_ip="187.233.230.159",
|
||||||
|
)
|
||||||
|
assert channel4.code == 200
|
||||||
|
|
||||||
|
# at 87Kib/s it should take about 2 minutes for enough to drain from bucket that another
|
||||||
|
# 30MiB download is authorized - The last download was blocked at 503,316,480.
|
||||||
|
# The next download will be authorized when bucket hits 492,830,720
|
||||||
|
# (524,288,000 total capacity - 31,457,280 download size) so 503,316,480 - 492,830,720 ~= 10,485,760
|
||||||
|
# needs to drain before another download will be authorized, that will take ~=
|
||||||
|
# 2 minutes (10,485,760/89,088/60)
|
||||||
|
self.reactor.pump([2.0 * 60.0])
|
||||||
|
|
||||||
|
# enough has drained and next request goes through
|
||||||
|
channel5 = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyb",
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
assert channel5.code == 200
|
||||||
|
|
||||||
|
@override_config(
|
||||||
|
{
|
||||||
|
"remote_media_download_per_second": "50M",
|
||||||
|
"remote_media_download_burst_count": "50M",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
@patch(
|
||||||
|
"synapse.http.matrixfederationclient.read_body_with_max_size",
|
||||||
|
read_body_with_max_size_50MiB,
|
||||||
|
)
|
||||||
|
def test_download_rate_limit_config(self) -> None:
|
||||||
|
"""
|
||||||
|
Test that download rate limit config options are correctly picked up and applied
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
|
||||||
|
resp = MagicMock(spec=IResponse)
|
||||||
|
resp.code = 200
|
||||||
|
resp.length = 52428800
|
||||||
|
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
|
||||||
|
resp.phrase = b"OK"
|
||||||
|
return resp
|
||||||
|
|
||||||
|
self.client._send_request = _send_request # type: ignore
|
||||||
|
|
||||||
|
# first request should go through
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyz",
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
assert channel.code == 200
|
||||||
|
|
||||||
|
# immediate second request should fail
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy1",
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
assert channel.code == 429
|
||||||
|
|
||||||
|
# advance half a second
|
||||||
|
self.reactor.pump([0.5])
|
||||||
|
|
||||||
|
# request still fails
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy2",
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
assert channel.code == 429
|
||||||
|
|
||||||
|
# advance another half second
|
||||||
|
self.reactor.pump([0.5])
|
||||||
|
|
||||||
|
# enough has drained from bucket and request is successful
|
||||||
|
channel = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy3",
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
assert channel.code == 200
|
||||||
|
|
||||||
|
@patch(
|
||||||
|
"synapse.http.matrixfederationclient.read_body_with_max_size",
|
||||||
|
read_body_with_max_size_30MiB,
|
||||||
|
)
|
||||||
|
def test_download_ratelimit_max_size_sub(self) -> None:
|
||||||
|
"""
|
||||||
|
Test that if no content-length is provided, the default max size is applied instead
|
||||||
|
"""
|
||||||
|
|
||||||
|
# mock out actually sending the request
|
||||||
|
async def _send_request(*args: Any, **kwargs: Any) -> IResponse:
|
||||||
|
resp = MagicMock(spec=IResponse)
|
||||||
|
resp.code = 200
|
||||||
|
resp.length = UNKNOWN_LENGTH
|
||||||
|
resp.headers = Headers({"Content-Type": ["application/octet-stream"]})
|
||||||
|
resp.phrase = b"OK"
|
||||||
|
return resp
|
||||||
|
|
||||||
|
self.client._send_request = _send_request # type: ignore
|
||||||
|
|
||||||
|
# ten requests should go through using the max size (500MB/50MB)
|
||||||
|
for i in range(10):
|
||||||
|
channel2 = self.make_request(
|
||||||
|
"GET",
|
||||||
|
f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}",
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
assert channel2.code == 200
|
||||||
|
|
||||||
|
# eleventh will hit ratelimit
|
||||||
|
channel3 = self.make_request(
|
||||||
|
"GET",
|
||||||
|
"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx",
|
||||||
|
shorthand=False,
|
||||||
|
)
|
||||||
|
assert channel3.code == 429
|
||||||
|
|
|
@ -711,6 +711,10 @@ class UserDirectoryICUTestCase(HomeserverTestCase):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.assertEqual(_parse_words_with_icu("user-1"), ["user-1"])
|
||||||
|
self.assertEqual(_parse_words_with_icu("user-ab"), ["user-ab"])
|
||||||
|
self.assertEqual(_parse_words_with_icu("user.--1"), ["user", "-1"])
|
||||||
|
|
||||||
def test_regex_word_boundary_punctuation(self) -> None:
|
def test_regex_word_boundary_punctuation(self) -> None:
|
||||||
"""
|
"""
|
||||||
Tests the behaviour of punctuation with the non-ICU tokeniser
|
Tests the behaviour of punctuation with the non-ICU tokeniser
|
||||||
|
|
Loading…
Reference in a new issue