Describe which rate limiter was hit in logs (#16135)

This commit is contained in:
David Robertson 2023-08-30 00:39:39 +01:00 committed by GitHub
parent e9235d92f2
commit 62a1a9be52
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 235 additions and 121 deletions

1
changelog.d/16135.misc Normal file
View file

@ -0,0 +1 @@
Describe which rate limiter was hit in logs.

View file

@ -211,6 +211,11 @@ class SynapseError(CodeMessageException):
def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, **self._additional_fields) return cs_error(self.msg, self.errcode, **self._additional_fields)
@property
def debug_context(self) -> Optional[str]:
"""Override this to add debugging context that shouldn't be sent to clients."""
return None
class InvalidAPICallError(SynapseError): class InvalidAPICallError(SynapseError):
"""You called an existing API endpoint, but fed that endpoint """You called an existing API endpoint, but fed that endpoint
@ -508,8 +513,8 @@ class LimitExceededError(SynapseError):
def __init__( def __init__(
self, self,
limiter_name: str,
code: int = 429, code: int = 429,
msg: str = "Too Many Requests",
retry_after_ms: Optional[int] = None, retry_after_ms: Optional[int] = None,
errcode: str = Codes.LIMIT_EXCEEDED, errcode: str = Codes.LIMIT_EXCEEDED,
): ):
@ -518,12 +523,17 @@ class LimitExceededError(SynapseError):
if self.include_retry_after_header and retry_after_ms is not None if self.include_retry_after_header and retry_after_ms is not None
else None else None
) )
super().__init__(code, msg, errcode, headers=headers) super().__init__(code, "Too Many Requests", errcode, headers=headers)
self.retry_after_ms = retry_after_ms self.retry_after_ms = retry_after_ms
self.limiter_name = limiter_name
def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict": def error_dict(self, config: Optional["HomeServerConfig"]) -> "JsonDict":
return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms) return cs_error(self.msg, self.errcode, retry_after_ms=self.retry_after_ms)
@property
def debug_context(self) -> Optional[str]:
return self.limiter_name
class RoomKeysVersionError(SynapseError): class RoomKeysVersionError(SynapseError):
"""A client has tried to upload to a non-current version of the room_keys store""" """A client has tried to upload to a non-current version of the room_keys store"""

View file

@ -61,12 +61,16 @@ class Ratelimiter:
""" """
def __init__( def __init__(
self, store: DataStore, clock: Clock, rate_hz: float, burst_count: int self,
store: DataStore,
clock: Clock,
cfg: RatelimitSettings,
): ):
self.clock = clock self.clock = clock
self.rate_hz = rate_hz self.rate_hz = cfg.per_second
self.burst_count = burst_count self.burst_count = cfg.burst_count
self.store = store self.store = store
self._limiter_name = cfg.key
# An ordered dictionary representing the token buckets tracked by this rate # An ordered dictionary representing the token buckets tracked by this rate
# limiter. Each entry maps a key of arbitrary type to a tuple representing: # limiter. Each entry maps a key of arbitrary type to a tuple representing:
@ -305,7 +309,8 @@ class Ratelimiter:
if not allowed: if not allowed:
raise LimitExceededError( raise LimitExceededError(
retry_after_ms=int(1000 * (time_allowed - time_now_s)) limiter_name=self._limiter_name,
retry_after_ms=int(1000 * (time_allowed - time_now_s)),
) )
@ -322,7 +327,9 @@ class RequestRatelimiter:
# The rate_hz and burst_count are overridden on a per-user basis # The rate_hz and burst_count are overridden on a per-user basis
self.request_ratelimiter = Ratelimiter( self.request_ratelimiter = Ratelimiter(
store=self.store, clock=self.clock, rate_hz=0, burst_count=0 store=self.store,
clock=self.clock,
cfg=RatelimitSettings(key=rc_message.key, per_second=0, burst_count=0),
) )
self._rc_message = rc_message self._rc_message = rc_message
@ -332,8 +339,7 @@ class RequestRatelimiter:
self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter( self.admin_redaction_ratelimiter: Optional[Ratelimiter] = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=rc_admin_redaction.per_second, cfg=rc_admin_redaction,
burst_count=rc_admin_redaction.burst_count,
) )
else: else:
self.admin_redaction_ratelimiter = None self.admin_redaction_ratelimiter = None

View file

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional from typing import Any, Dict, Optional, cast
import attr import attr
@ -21,16 +21,47 @@ from synapse.types import JsonDict
from ._base import Config from ._base import Config
@attr.s(slots=True, frozen=True, auto_attribs=True)
class RatelimitSettings: class RatelimitSettings:
def __init__( key: str
self, per_second: float
config: Dict[str, float], burst_count: int
@classmethod
def parse(
cls,
config: Dict[str, Any],
key: str,
defaults: Optional[Dict[str, float]] = None, defaults: Optional[Dict[str, float]] = None,
): ) -> "RatelimitSettings":
"""Parse config[key] as a new-style rate limiter config.
The key may refer to a nested dictionary using a full stop (.) to separate
each nested key. For example, use the key "a.b.c" to parse the following:
a:
b:
c:
per_second: 10
burst_count: 200
If this lookup fails, we'll fallback to the defaults.
"""
defaults = defaults or {"per_second": 0.17, "burst_count": 3.0} defaults = defaults or {"per_second": 0.17, "burst_count": 3.0}
self.per_second = config.get("per_second", defaults["per_second"]) rl_config = config
self.burst_count = int(config.get("burst_count", defaults["burst_count"])) for part in key.split("."):
rl_config = rl_config.get(part, {})
# By this point we should have hit the rate limiter parameters.
# We don't actually check this though!
rl_config = cast(Dict[str, float], rl_config)
return cls(
key=key,
per_second=rl_config.get("per_second", defaults["per_second"]),
burst_count=int(rl_config.get("burst_count", defaults["burst_count"])),
)
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@ -49,15 +80,14 @@ class RatelimitConfig(Config):
# Load the new-style messages config if it exists. Otherwise fall back # Load the new-style messages config if it exists. Otherwise fall back
# to the old method. # to the old method.
if "rc_message" in config: if "rc_message" in config:
self.rc_message = RatelimitSettings( self.rc_message = RatelimitSettings.parse(
config["rc_message"], defaults={"per_second": 0.2, "burst_count": 10.0} config, "rc_message", defaults={"per_second": 0.2, "burst_count": 10.0}
) )
else: else:
self.rc_message = RatelimitSettings( self.rc_message = RatelimitSettings(
{ key="rc_messages",
"per_second": config.get("rc_messages_per_second", 0.2), per_second=config.get("rc_messages_per_second", 0.2),
"burst_count": config.get("rc_message_burst_count", 10.0), burst_count=config.get("rc_message_burst_count", 10.0),
}
) )
# Load the new-style federation config, if it exists. Otherwise, fall # Load the new-style federation config, if it exists. Otherwise, fall
@ -79,51 +109,59 @@ class RatelimitConfig(Config):
} }
) )
self.rc_registration = RatelimitSettings(config.get("rc_registration", {})) self.rc_registration = RatelimitSettings.parse(config, "rc_registration", {})
self.rc_registration_token_validity = RatelimitSettings( self.rc_registration_token_validity = RatelimitSettings.parse(
config.get("rc_registration_token_validity", {}), config,
"rc_registration_token_validity",
defaults={"per_second": 0.1, "burst_count": 5}, defaults={"per_second": 0.1, "burst_count": 5},
) )
# It is reasonable to login with a bunch of devices at once (i.e. when # It is reasonable to login with a bunch of devices at once (i.e. when
# setting up an account), but it is *not* valid to continually be # setting up an account), but it is *not* valid to continually be
# logging into new devices. # logging into new devices.
rc_login_config = config.get("rc_login", {}) self.rc_login_address = RatelimitSettings.parse(
self.rc_login_address = RatelimitSettings( config,
rc_login_config.get("address", {}), "rc_login.address",
defaults={"per_second": 0.003, "burst_count": 5}, defaults={"per_second": 0.003, "burst_count": 5},
) )
self.rc_login_account = RatelimitSettings( self.rc_login_account = RatelimitSettings.parse(
rc_login_config.get("account", {}), config,
"rc_login.account",
defaults={"per_second": 0.003, "burst_count": 5}, defaults={"per_second": 0.003, "burst_count": 5},
) )
self.rc_login_failed_attempts = RatelimitSettings( self.rc_login_failed_attempts = RatelimitSettings.parse(
rc_login_config.get("failed_attempts", {}) config,
"rc_login.failed_attempts",
{},
) )
self.federation_rr_transactions_per_room_per_second = config.get( self.federation_rr_transactions_per_room_per_second = config.get(
"federation_rr_transactions_per_room_per_second", 50 "federation_rr_transactions_per_room_per_second", 50
) )
rc_admin_redaction = config.get("rc_admin_redaction")
self.rc_admin_redaction = None self.rc_admin_redaction = None
if rc_admin_redaction: if "rc_admin_redaction" in config:
self.rc_admin_redaction = RatelimitSettings(rc_admin_redaction) self.rc_admin_redaction = RatelimitSettings.parse(
config, "rc_admin_redaction", {}
)
self.rc_joins_local = RatelimitSettings( self.rc_joins_local = RatelimitSettings.parse(
config.get("rc_joins", {}).get("local", {}), config,
"rc_joins.local",
defaults={"per_second": 0.1, "burst_count": 10}, defaults={"per_second": 0.1, "burst_count": 10},
) )
self.rc_joins_remote = RatelimitSettings( self.rc_joins_remote = RatelimitSettings.parse(
config.get("rc_joins", {}).get("remote", {}), config,
"rc_joins.remote",
defaults={"per_second": 0.01, "burst_count": 10}, defaults={"per_second": 0.01, "burst_count": 10},
) )
# Track the rate of joins to a given room. If there are too many, temporarily # Track the rate of joins to a given room. If there are too many, temporarily
# prevent local joins and remote joins via this server. # prevent local joins and remote joins via this server.
self.rc_joins_per_room = RatelimitSettings( self.rc_joins_per_room = RatelimitSettings.parse(
config.get("rc_joins_per_room", {}), config,
"rc_joins_per_room",
defaults={"per_second": 1, "burst_count": 10}, defaults={"per_second": 1, "burst_count": 10},
) )
@ -132,31 +170,37 @@ class RatelimitConfig(Config):
# * For requests received over federation this is keyed by the origin. # * For requests received over federation this is keyed by the origin.
# #
# Note that this isn't exposed in the configuration as it is obscure. # Note that this isn't exposed in the configuration as it is obscure.
self.rc_key_requests = RatelimitSettings( self.rc_key_requests = RatelimitSettings.parse(
config.get("rc_key_requests", {}), config,
"rc_key_requests",
defaults={"per_second": 20, "burst_count": 100}, defaults={"per_second": 20, "burst_count": 100},
) )
self.rc_3pid_validation = RatelimitSettings( self.rc_3pid_validation = RatelimitSettings.parse(
config.get("rc_3pid_validation") or {}, config,
"rc_3pid_validation",
defaults={"per_second": 0.003, "burst_count": 5}, defaults={"per_second": 0.003, "burst_count": 5},
) )
self.rc_invites_per_room = RatelimitSettings( self.rc_invites_per_room = RatelimitSettings.parse(
config.get("rc_invites", {}).get("per_room", {}), config,
"rc_invites.per_room",
defaults={"per_second": 0.3, "burst_count": 10}, defaults={"per_second": 0.3, "burst_count": 10},
) )
self.rc_invites_per_user = RatelimitSettings( self.rc_invites_per_user = RatelimitSettings.parse(
config.get("rc_invites", {}).get("per_user", {}), config,
"rc_invites.per_user",
defaults={"per_second": 0.003, "burst_count": 5}, defaults={"per_second": 0.003, "burst_count": 5},
) )
self.rc_invites_per_issuer = RatelimitSettings( self.rc_invites_per_issuer = RatelimitSettings.parse(
config.get("rc_invites", {}).get("per_issuer", {}), config,
"rc_invites.per_issuer",
defaults={"per_second": 0.3, "burst_count": 10}, defaults={"per_second": 0.3, "burst_count": 10},
) )
self.rc_third_party_invite = RatelimitSettings( self.rc_third_party_invite = RatelimitSettings.parse(
config.get("rc_third_party_invite", {}), config,
"rc_third_party_invite",
defaults={"per_second": 0.0025, "burst_count": 5}, defaults={"per_second": 0.0025, "burst_count": 5},
) )

View file

@ -218,19 +218,17 @@ class AuthHandler:
self._failed_uia_attempts_ratelimiter = Ratelimiter( self._failed_uia_attempts_ratelimiter = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, cfg=self.hs.config.ratelimiting.rc_login_failed_attempts,
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
) )
# The number of seconds to keep a UI auth session active. # The number of seconds to keep a UI auth session active.
self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout self._ui_auth_session_timeout = hs.config.auth.ui_auth_session_timeout
# Ratelimitier for failed /login attempts # Ratelimiter for failed /login attempts
self._failed_login_attempts_ratelimiter = Ratelimiter( self._failed_login_attempts_ratelimiter = Ratelimiter(
store=self.store, store=self.store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_failed_attempts.per_second, cfg=self.hs.config.ratelimiting.rc_login_failed_attempts,
burst_count=self.hs.config.ratelimiting.rc_login_failed_attempts.burst_count,
) )
self._clock = self.hs.get_clock() self._clock = self.hs.get_clock()

View file

@ -90,8 +90,7 @@ class DeviceMessageHandler:
self._ratelimiter = Ratelimiter( self._ratelimiter = Ratelimiter(
store=self.store, store=self.store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_key_requests.per_second, cfg=hs.config.ratelimiting.rc_key_requests,
burst_count=hs.config.ratelimiting.rc_key_requests.burst_count,
) )
async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None: async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:

View file

@ -66,14 +66,12 @@ class IdentityHandler:
self._3pid_validation_ratelimiter_ip = Ratelimiter( self._3pid_validation_ratelimiter_ip = Ratelimiter(
store=self.store, store=self.store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, cfg=hs.config.ratelimiting.rc_3pid_validation,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
) )
self._3pid_validation_ratelimiter_address = Ratelimiter( self._3pid_validation_ratelimiter_address = Ratelimiter(
store=self.store, store=self.store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_3pid_validation.per_second, cfg=hs.config.ratelimiting.rc_3pid_validation,
burst_count=hs.config.ratelimiting.rc_3pid_validation.burst_count,
) )
async def ratelimit_request_token_requests( async def ratelimit_request_token_requests(

View file

@ -112,8 +112,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._join_rate_limiter_local = Ratelimiter( self._join_rate_limiter_local = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second, cfg=hs.config.ratelimiting.rc_joins_local,
burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
) )
# Tracks joins from local users to rooms this server isn't a member of. # Tracks joins from local users to rooms this server isn't a member of.
# I.e. joins this server makes by requesting /make_join /send_join from # I.e. joins this server makes by requesting /make_join /send_join from
@ -121,8 +120,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._join_rate_limiter_remote = Ratelimiter( self._join_rate_limiter_remote = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second, cfg=hs.config.ratelimiting.rc_joins_remote,
burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
) )
# TODO: find a better place to keep this Ratelimiter. # TODO: find a better place to keep this Ratelimiter.
# It needs to be # It needs to be
@ -135,8 +133,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._join_rate_per_room_limiter = Ratelimiter( self._join_rate_per_room_limiter = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second, cfg=hs.config.ratelimiting.rc_joins_per_room,
burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
) )
# Ratelimiter for invites, keyed by room (across all issuers, all # Ratelimiter for invites, keyed by room (across all issuers, all
@ -144,8 +141,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._invites_per_room_limiter = Ratelimiter( self._invites_per_room_limiter = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_room.per_second, cfg=hs.config.ratelimiting.rc_invites_per_room,
burst_count=hs.config.ratelimiting.rc_invites_per_room.burst_count,
) )
# Ratelimiter for invites, keyed by recipient (across all rooms, all # Ratelimiter for invites, keyed by recipient (across all rooms, all
@ -153,8 +149,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._invites_per_recipient_limiter = Ratelimiter( self._invites_per_recipient_limiter = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_user.per_second, cfg=hs.config.ratelimiting.rc_invites_per_user,
burst_count=hs.config.ratelimiting.rc_invites_per_user.burst_count,
) )
# Ratelimiter for invites, keyed by issuer (across all rooms, all # Ratelimiter for invites, keyed by issuer (across all rooms, all
@ -162,15 +157,13 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self._invites_per_issuer_limiter = Ratelimiter( self._invites_per_issuer_limiter = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_invites_per_issuer.per_second, cfg=hs.config.ratelimiting.rc_invites_per_issuer,
burst_count=hs.config.ratelimiting.rc_invites_per_issuer.burst_count,
) )
self._third_party_invite_limiter = Ratelimiter( self._third_party_invite_limiter = Ratelimiter(
store=self.store, store=self.store,
clock=self.clock, clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_third_party_invite.per_second, cfg=hs.config.ratelimiting.rc_third_party_invite,
burst_count=hs.config.ratelimiting.rc_third_party_invite.burst_count,
) )
self.request_ratelimiter = hs.get_request_ratelimiter() self.request_ratelimiter = hs.get_request_ratelimiter()

View file

@ -35,6 +35,7 @@ from synapse.api.errors import (
UnsupportedRoomVersionError, UnsupportedRoomVersionError,
) )
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.config.ratelimiting import RatelimitSettings
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import JsonDict, Requester, StrCollection from synapse.types import JsonDict, Requester, StrCollection
from synapse.util.caches.response_cache import ResponseCache from synapse.util.caches.response_cache import ResponseCache
@ -94,7 +95,9 @@ class RoomSummaryHandler:
self._server_name = hs.hostname self._server_name = hs.hostname
self._federation_client = hs.get_federation_client() self._federation_client = hs.get_federation_client()
self._ratelimiter = Ratelimiter( self._ratelimiter = Ratelimiter(
store=self._store, clock=hs.get_clock(), rate_hz=5, burst_count=10 store=self._store,
clock=hs.get_clock(),
cfg=RatelimitSettings("<room summary>", per_second=5, burst_count=10),
) )
# If a user tries to fetch the same page multiple times in quick succession, # If a user tries to fetch the same page multiple times in quick succession,

View file

@ -115,7 +115,13 @@ def return_json_error(
if exc.headers is not None: if exc.headers is not None:
for header, value in exc.headers.items(): for header, value in exc.headers.items():
request.setHeader(header, value) request.setHeader(header, value)
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg) error_ctx = exc.debug_context
if error_ctx:
logger.info(
"%s SynapseError: %s - %s (%s)", request, error_code, exc.msg, error_ctx
)
else:
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
elif f.check(CancelledError): elif f.check(CancelledError):
error_code = HTTP_STATUS_REQUEST_CANCELLED error_code = HTTP_STATUS_REQUEST_CANCELLED
error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN} error_dict = {"error": "Request cancelled", "errcode": Codes.UNKNOWN}

View file

@ -120,14 +120,12 @@ class LoginRestServlet(RestServlet):
self._address_ratelimiter = Ratelimiter( self._address_ratelimiter = Ratelimiter(
store=self._main_store, store=self._main_store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_address.per_second, cfg=self.hs.config.ratelimiting.rc_login_address,
burst_count=self.hs.config.ratelimiting.rc_login_address.burst_count,
) )
self._account_ratelimiter = Ratelimiter( self._account_ratelimiter = Ratelimiter(
store=self._main_store, store=self._main_store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=self.hs.config.ratelimiting.rc_login_account.per_second, cfg=self.hs.config.ratelimiting.rc_login_account,
burst_count=self.hs.config.ratelimiting.rc_login_account.burst_count,
) )
# ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.

View file

@ -16,6 +16,7 @@ import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.config.ratelimiting import RatelimitSettings
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@ -66,15 +67,18 @@ class LoginTokenRequestServlet(RestServlet):
self.token_timeout = hs.config.auth.login_via_existing_token_timeout self.token_timeout = hs.config.auth.login_via_existing_token_timeout
self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth self._require_ui_auth = hs.config.auth.login_via_existing_require_ui_auth
# Ratelimit aggressively to a maxmimum of 1 request per minute. # Ratelimit aggressively to a maximum of 1 request per minute.
# #
# This endpoint can be used to spawn additional sessions and could be # This endpoint can be used to spawn additional sessions and could be
# abused by a malicious client to create many sessions. # abused by a malicious client to create many sessions.
self._ratelimiter = Ratelimiter( self._ratelimiter = Ratelimiter(
store=self._main_store, store=self._main_store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=1 / 60, cfg=RatelimitSettings(
burst_count=1, key="<login token request>",
per_second=1 / 60,
burst_count=1,
),
) )
@interactive_auth_handler @interactive_auth_handler

View file

@ -376,8 +376,7 @@ class RegistrationTokenValidityRestServlet(RestServlet):
self.ratelimiter = Ratelimiter( self.ratelimiter = Ratelimiter(
store=self.store, store=self.store,
clock=hs.get_clock(), clock=hs.get_clock(),
rate_hz=hs.config.ratelimiting.rc_registration_token_validity.per_second, cfg=hs.config.ratelimiting.rc_registration_token_validity,
burst_count=hs.config.ratelimiting.rc_registration_token_validity.burst_count,
) )
async def on_GET(self, request: Request) -> Tuple[int, JsonDict]: async def on_GET(self, request: Request) -> Tuple[int, JsonDict]:

View file

@ -408,8 +408,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return Ratelimiter( return Ratelimiter(
store=self.get_datastores().main, store=self.get_datastores().main,
clock=self.get_clock(), clock=self.get_clock(),
rate_hz=self.config.ratelimiting.rc_registration.per_second, cfg=self.config.ratelimiting.rc_registration,
burst_count=self.config.ratelimiting.rc_registration.burst_count,
) )
@cache_in_self @cache_in_self

View file

@ -291,7 +291,8 @@ class _PerHostRatelimiter:
if self.metrics_name: if self.metrics_name:
rate_limit_reject_counter.labels(self.metrics_name).inc() rate_limit_reject_counter.labels(self.metrics_name).inc()
raise LimitExceededError( raise LimitExceededError(
retry_after_ms=int(self.window_size / self.sleep_limit) limiter_name="rc_federation",
retry_after_ms=int(self.window_size / self.sleep_limit),
) )
self.request_times.append(time_now) self.request_times.append(time_now)

View file

@ -1,6 +1,5 @@
# Copyright 2023 The Matrix.org Foundation C.I.C. # Copyright 2023 The Matrix.org Foundation C.I.C.
# #
#
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # You may obtain a copy of the License at
@ -13,24 +12,32 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import json
from synapse.api.errors import LimitExceededError from synapse.api.errors import LimitExceededError
from tests import unittest from tests import unittest
class ErrorsTestCase(unittest.TestCase): class LimitExceededErrorTestCase(unittest.TestCase):
def test_key_appears_in_context_but_not_error_dict(self) -> None:
err = LimitExceededError("needle")
serialised = json.dumps(err.error_dict(None))
self.assertIn("needle", err.debug_context)
self.assertNotIn("needle", serialised)
# Create a sub-class to avoid mutating the class-level property. # Create a sub-class to avoid mutating the class-level property.
class LimitExceededErrorHeaders(LimitExceededError): class LimitExceededErrorHeaders(LimitExceededError):
include_retry_after_header = True include_retry_after_header = True
def test_limit_exceeded_header(self) -> None: def test_limit_exceeded_header(self) -> None:
err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=100) err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=100)
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100) self.assertEqual(err.error_dict(None).get("retry_after_ms"), 100)
assert err.headers is not None assert err.headers is not None
self.assertEqual(err.headers.get("Retry-After"), "1") self.assertEqual(err.headers.get("Retry-After"), "1")
def test_limit_exceeded_rounding(self) -> None: def test_limit_exceeded_rounding(self) -> None:
err = ErrorsTestCase.LimitExceededErrorHeaders(retry_after_ms=3001) err = self.LimitExceededErrorHeaders(limiter_name="test", retry_after_ms=3001)
self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001) self.assertEqual(err.error_dict(None).get("retry_after_ms"), 3001)
assert err.headers is not None assert err.headers is not None
self.assertEqual(err.headers.get("Retry-After"), "4") self.assertEqual(err.headers.get("Retry-After"), "4")

View file

@ -1,5 +1,6 @@
from synapse.api.ratelimiting import LimitExceededError, Ratelimiter from synapse.api.ratelimiting import LimitExceededError, Ratelimiter
from synapse.appservice import ApplicationService from synapse.appservice import ApplicationService
from synapse.config.ratelimiting import RatelimitSettings
from synapse.types import create_requester from synapse.types import create_requester
from tests import unittest from tests import unittest
@ -10,8 +11,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
burst_count=1,
) )
allowed, time_allowed = self.get_success_or_raise( allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(None, key="test_id", _time_now_s=0) limiter.can_do_action(None, key="test_id", _time_now_s=0)
@ -43,8 +43,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(
burst_count=1, key="",
per_second=0.1,
burst_count=1,
),
) )
allowed, time_allowed = self.get_success_or_raise( allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(as_requester, _time_now_s=0) limiter.can_do_action(as_requester, _time_now_s=0)
@ -76,8 +79,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(
burst_count=1, key="",
per_second=0.1,
burst_count=1,
),
) )
allowed, time_allowed = self.get_success_or_raise( allowed, time_allowed = self.get_success_or_raise(
limiter.can_do_action(as_requester, _time_now_s=0) limiter.can_do_action(as_requester, _time_now_s=0)
@ -101,8 +107,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
burst_count=1,
) )
# Shouldn't raise # Shouldn't raise
@ -128,8 +133,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
burst_count=1,
) )
# First attempt should be allowed # First attempt should be allowed
@ -177,8 +181,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
burst_count=1,
) )
# First attempt should be allowed # First attempt should be allowed
@ -208,8 +211,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(key="", per_second=0.1, burst_count=1),
burst_count=1,
) )
self.get_success_or_raise( self.get_success_or_raise(
limiter.can_do_action(None, key="test_id_1", _time_now_s=0) limiter.can_do_action(None, key="test_id_1", _time_now_s=0)
@ -244,7 +246,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
) )
) )
limiter = Ratelimiter(store=store, clock=self.clock, rate_hz=0.1, burst_count=1) limiter = Ratelimiter(
store=store,
clock=self.clock,
cfg=RatelimitSettings("", per_second=0.1, burst_count=1),
)
# Shouldn't raise # Shouldn't raise
for _ in range(20): for _ in range(20):
@ -254,8 +260,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(
burst_count=3, key="",
per_second=0.1,
burst_count=3,
),
) )
# Test that 4 actions aren't allowed with a maximum burst of 3. # Test that 4 actions aren't allowed with a maximum burst of 3.
allowed, time_allowed = self.get_success_or_raise( allowed, time_allowed = self.get_success_or_raise(
@ -321,8 +330,7 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings("", per_second=0.1, burst_count=3),
burst_count=3,
) )
def consume_at(time: float) -> bool: def consume_at(time: float) -> bool:
@ -346,8 +354,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(
burst_count=3, "",
per_second=0.1,
burst_count=3,
),
) )
# Observe two actions, leaving room in the bucket for one more. # Observe two actions, leaving room in the bucket for one more.
@ -369,8 +380,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(
burst_count=3, "",
per_second=0.1,
burst_count=3,
),
) )
# Observe three actions, filling up the bucket. # Observe three actions, filling up the bucket.
@ -398,8 +412,11 @@ class TestRatelimiter(unittest.HomeserverTestCase):
limiter = Ratelimiter( limiter = Ratelimiter(
store=self.hs.get_datastores().main, store=self.hs.get_datastores().main,
clock=self.clock, clock=self.clock,
rate_hz=0.1, cfg=RatelimitSettings(
burst_count=3, "",
per_second=0.1,
burst_count=3,
),
) )
# Observe four actions, exceeding the bucket. # Observe four actions, exceeding the bucket.

View file

@ -12,11 +12,42 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
from synapse.config.ratelimiting import RatelimitSettings
from tests.unittest import TestCase from tests.unittest import TestCase
from tests.utils import default_config from tests.utils import default_config
class ParseRatelimitSettingsTestcase(TestCase):
def test_depth_1(self) -> None:
cfg = {
"a": {
"per_second": 5,
"burst_count": 10,
}
}
parsed = RatelimitSettings.parse(cfg, "a")
self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
def test_depth_2(self) -> None:
cfg = {
"a": {
"b": {
"per_second": 5,
"burst_count": 10,
},
}
}
parsed = RatelimitSettings.parse(cfg, "a.b")
self.assertEqual(parsed, RatelimitSettings("a.b", 5, 10))
def test_missing(self) -> None:
parsed = RatelimitSettings.parse(
{}, "a", defaults={"per_second": 5, "burst_count": 10}
)
self.assertEqual(parsed, RatelimitSettings("a", 5, 10))
class RatelimitConfigTestCase(TestCase): class RatelimitConfigTestCase(TestCase):
def test_parse_rc_federation(self) -> None: def test_parse_rc_federation(self) -> None:
config_dict = default_config("test") config_dict = default_config("test")