Merge remote-tracking branch 'origin/release-v1.21.3' into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2020-10-22 09:57:06 +01:00
commit ab4cd7f802
84 changed files with 1264 additions and 781 deletions

1
.gitignore vendored
View file

@ -21,6 +21,7 @@ _trial_temp*/
/.python-version /.python-version
/*.signing.key /*.signing.key
/env/ /env/
/.venv*/
/homeserver*.yaml /homeserver*.yaml
/logs /logs
/media_store/ /media_store/

1
changelog.d/8504.bugfix Normal file
View file

@ -0,0 +1 @@
Expose the `uk.half-shot.msc2778.login.application_service` to clients from the login API. This feature was added in v1.21.0, but was not exposed as a potential login flow.

1
changelog.d/8544.feature Normal file
View file

@ -0,0 +1 @@
Allow running background tasks in a separate worker process.

1
changelog.d/8545.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a long standing bug where email notifications for encrypted messages were blank.

1
changelog.d/8561.misc Normal file
View file

@ -0,0 +1 @@
Move metric registration code down into `LruCache`.

1
changelog.d/8562.misc Normal file
View file

@ -0,0 +1 @@
Add type annotations for `LruCache`.

1
changelog.d/8563.misc Normal file
View file

@ -0,0 +1 @@
Replace `DeferredCache` with the lighter-weight `LruCache` where possible.

1
changelog.d/8564.feature Normal file
View file

@ -0,0 +1 @@
Support modifying event content in `ThirdPartyRules` modules.

1
changelog.d/8566.misc Normal file
View file

@ -0,0 +1 @@
Add virtualenv-generated folders to `.gitignore`.

1
changelog.d/8567.bugfix Normal file
View file

@ -0,0 +1 @@
Fix increase in the number of `There was no active span...` errors logged when using OpenTracing.

1
changelog.d/8568.misc Normal file
View file

@ -0,0 +1 @@
Add `get_immediate` method to `DeferredCache`.

1
changelog.d/8569.misc Normal file
View file

@ -0,0 +1 @@
Fix mypy not properly checking across the codebase, additionally, fix a typing assertion error in `handlers/auth.py`.

1
changelog.d/8571.misc Normal file
View file

@ -0,0 +1 @@
Fix `synmark` benchmark runner.

1
changelog.d/8572.misc Normal file
View file

@ -0,0 +1 @@
Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s.

1
changelog.d/8577.misc Normal file
View file

@ -0,0 +1 @@
Adjust a protocol-type definition to fit `sqlite3` assertions.

1
changelog.d/8578.misc Normal file
View file

@ -0,0 +1 @@
Support macOS on the `synmark` benchmark runner.

1
changelog.d/8583.misc Normal file
View file

@ -0,0 +1 @@
Update `mypy` static type checker to 0.790.

1
changelog.d/8585.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug that prevented errors encountered during execution of the `synapse_port_db` from being correctly printed.

1
changelog.d/8587.misc Normal file
View file

@ -0,0 +1 @@
Re-organize the structured logging code to separate the TCP transport handling from the JSON formatting.

1
changelog.d/8589.removal Normal file
View file

@ -0,0 +1 @@
Drop unused `device_max_stream_id` table.

1
changelog.d/8590.misc Normal file
View file

@ -0,0 +1 @@
Implement [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409) to send typing, read receipts, and presence events to appservices.

1
changelog.d/8591.misc Normal file
View file

@ -0,0 +1 @@
Move metric registration code down into `LruCache`.

1
changelog.d/8592.misc Normal file
View file

@ -0,0 +1 @@
Remove extraneous unittest logging decorators from unit tests.

1
changelog.d/8593.misc Normal file
View file

@ -0,0 +1 @@
Minor optimisations in caching code.

1
changelog.d/8594.misc Normal file
View file

@ -0,0 +1 @@
Minor optimisations in caching code.

1
changelog.d/8599.feature Normal file
View file

@ -0,0 +1 @@
Allow running background tasks in a separate worker process.

1
changelog.d/8600.misc Normal file
View file

@ -0,0 +1 @@
Update `mypy` static type checker to 0.790.

1
changelog.d/8606.feature Normal file
View file

@ -0,0 +1 @@
Limit appservice transactions to 100 persistent and 100 ephemeral events.

1
changelog.d/8609.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to profile and base handler.

View file

@ -15,8 +15,9 @@ files =
synapse/events/builder.py, synapse/events/builder.py,
synapse/events/spamcheck.py, synapse/events/spamcheck.py,
synapse/federation, synapse/federation,
synapse/handlers/appservice.py, synapse/handlers/_base.py,
synapse/handlers/account_data.py, synapse/handlers/account_data.py,
synapse/handlers/appservice.py,
synapse/handlers/auth.py, synapse/handlers/auth.py,
synapse/handlers/cas_handler.py, synapse/handlers/cas_handler.py,
synapse/handlers/deactivate_account.py, synapse/handlers/deactivate_account.py,
@ -32,6 +33,7 @@ files =
synapse/handlers/pagination.py, synapse/handlers/pagination.py,
synapse/handlers/password_policy.py, synapse/handlers/password_policy.py,
synapse/handlers/presence.py, synapse/handlers/presence.py,
synapse/handlers/profile.py,
synapse/handlers/read_marker.py, synapse/handlers/read_marker.py,
synapse/handlers/room.py, synapse/handlers/room.py,
synapse/handlers/room_member.py, synapse/handlers/room_member.py,

View file

@ -22,6 +22,7 @@ import logging
import sys import sys
import time import time
import traceback import traceback
from typing import Optional
import yaml import yaml
@ -152,7 +153,7 @@ IGNORED_TABLES = {
# Error returned by the run function. Used at the top-level part of the script to # Error returned by the run function. Used at the top-level part of the script to
# handle errors and return codes. # handle errors and return codes.
end_error = None end_error = None # type: Optional[str]
# The exec_info for the error, if any. If error is defined but not exec_info the script # The exec_info for the error, if any. If error is defined but not exec_info the script
# will show only the error message without the stacktrace, if exec_info is defined but # will show only the error message without the stacktrace, if exec_info is defined but
# not the error then the script will show nothing outside of what's printed in the run # not the error then the script will show nothing outside of what's printed in the run
@ -635,7 +636,7 @@ class Porter(object):
self.progress.done() self.progress.done()
except Exception as e: except Exception as e:
global end_error_exec_info global end_error_exec_info
end_error = e end_error = str(e)
end_error_exec_info = sys.exc_info() end_error_exec_info = sys.exc_info()
logger.exception("") logger.exception("")
finally: finally:

View file

@ -102,6 +102,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8", "flake8",
] ]
CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
# Dependencies which are exclusively required by unit test code. This is # Dependencies which are exclusively required by unit test code. This is
# NOT a list of all modules that are necessary to run the unit tests. # NOT a list of all modules that are necessary to run the unit tests.
# Tests assume that all optional dependencies are installed. # Tests assume that all optional dependencies are installed.

View file

@ -34,7 +34,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase from synapse.events import EventBase
from synapse.logging import opentracing as opentracing from synapse.logging import opentracing as opentracing
from synapse.types import StateMap, UserID from synapse.types import StateMap, UserID
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -70,8 +69,9 @@ class Auth:
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.token_cache = LruCache(10000) self.token_cache = LruCache(
register_cache("cache", "token_cache", self.token_cache) 10000, "token_cache"
) # type: LruCache[str, Tuple[str, bool]]
self._auth_blocking = AuthBlocking(self.hs) self._auth_blocking = AuthBlocking(self.hs)

View file

@ -60,6 +60,13 @@ from synapse.types import JsonDict
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Maximum number of events to provide in an AS transaction.
MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
# Maximum number of ephemeral events to provide in an AS transaction.
MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
class ApplicationServiceScheduler: class ApplicationServiceScheduler:
""" Public facing API for this module. Does the required DI to tie the """ Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this components together. This also serves as the "event_pool", which in this
@ -136,10 +143,17 @@ class _ServiceQueuer:
self.requests_in_flight.add(service.id) self.requests_in_flight.add(service.id)
try: try:
while True: while True:
events = self.queued_events.pop(service.id, []) all_events = self.queued_events.get(service.id, [])
ephemeral = self.queued_ephemeral.pop(service.id, []) events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
all_events_ephemeral = self.queued_ephemeral.get(service.id, [])
ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
if not events and not ephemeral: if not events and not ephemeral:
return return
try: try:
await self.txn_ctrl.send(service, events, ephemeral) await self.txn_ctrl.send(service, events, ephemeral)
except Exception: except Exception:

View file

@ -98,7 +98,7 @@ class EventBuilder:
return self._state_key is not None return self._state_key is not None
async def build( async def build(
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]] self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
) -> EventBase: ) -> EventBase:
"""Transform into a fully signed and hashed event """Transform into a fully signed and hashed event

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Optional
import synapse.state import synapse.state
import synapse.storage import synapse.storage
@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
from synapse.api.ratelimiting import Ratelimiter from synapse.api.ratelimiting import Ratelimiter
from synapse.types import UserID from synapse.types import UserID
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -30,11 +34,7 @@ class BaseHandler:
Common base class for the event handlers. Common base class for the event handlers.
""" """
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer):
"""
self.store = hs.get_datastore() # type: synapse.storage.DataStore self.store = hs.get_datastore() # type: synapse.storage.DataStore
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
@ -56,7 +56,7 @@ class BaseHandler:
clock=self.clock, clock=self.clock,
rate_hz=self.hs.config.rc_admin_redaction.per_second, rate_hz=self.hs.config.rc_admin_redaction.per_second,
burst_count=self.hs.config.rc_admin_redaction.burst_count, burst_count=self.hs.config.rc_admin_redaction.burst_count,
) ) # type: Optional[Ratelimiter]
else: else:
self.admin_redaction_ratelimiter = None self.admin_redaction_ratelimiter = None
@ -127,15 +127,15 @@ class BaseHandler:
if guest_access != "can_join": if guest_access != "can_join":
if context: if context:
current_state_ids = await context.get_current_state_ids() current_state_ids = await context.get_current_state_ids()
current_state = await self.store.get_events( current_state_dict = await self.store.get_events(
list(current_state_ids.values()) list(current_state_ids.values())
) )
current_state = list(current_state_dict.values())
else: else:
current_state = await self.state_handler.get_current_state( current_state_map = await self.state_handler.get_current_state(
event.room_id event.room_id
) )
current_state = list(current_state_map.values())
current_state = list(current_state.values())
logger.info("maybe_kick_guest_users %r", current_state) logger.info("maybe_kick_guest_users %r", current_state)
await self.kick_guest_users(current_state) await self.kick_guest_users(current_state)

View file

@ -22,7 +22,7 @@ from typing import List
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable from synapse.logging.context import make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID from synapse.types import UserID
from synapse.util import stringutils from synapse.util import stringutils
@ -63,16 +63,10 @@ class AccountValidityHandler:
self._raw_from = email.utils.parseaddr(self._from_string)[1] self._raw_from = email.utils.parseaddr(self._from_string)[1]
# Check the renewal emails to send and send them every 30min. # Check the renewal emails to send and send them every 30min.
def send_emails():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"send_renewals", self._send_renewal_emails
)
if hs.config.run_background_tasks: if hs.config.run_background_tasks:
self.clock.looping_call(send_emails, 30 * 60 * 1000) self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
@wrap_as_background_process("send_renewals")
async def _send_renewal_emails(self): async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time """Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity`` configured in the ``renew_at`` parameter from the ``account_validity``

View file

@ -1122,20 +1122,22 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash. Whether self.hash(password) == stored_hash.
""" """
def _do_validate_hash(): def _do_validate_hash(checked_hash: bytes):
# Normalise the Unicode in the password # Normalise the Unicode in the password
pw = unicodedata.normalize("NFKC", password) pw = unicodedata.normalize("NFKC", password)
return bcrypt.checkpw( return bcrypt.checkpw(
pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"), pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
stored_hash, checked_hash,
) )
if stored_hash: if stored_hash:
if not isinstance(stored_hash, bytes): if not isinstance(stored_hash, bytes):
stored_hash = stored_hash.encode("ascii") stored_hash = stored_hash.encode("ascii")
return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash) return await defer_to_thread(
self.hs.get_reactor(), _do_validate_hash, stored_hash
)
else: else:
return False return False

View file

@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler):
user_id, room_id, pagin_config, membership, is_peeking user_id, room_id, pagin_config, membership, is_peeking
) )
elif membership == Membership.LEAVE: elif membership == Membership.LEAVE:
# The member_event_id will always be available if membership is set
# to leave.
assert member_event_id
result = await self._room_initial_sync_parted( result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking user_id, room_id, pagin_config, membership, member_event_id, is_peeking
) )
@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler):
user_id: str, user_id: str,
room_id: str, room_id: str,
pagin_config: PaginationConfig, pagin_config: PaginationConfig,
membership: Membership, membership: str,
member_event_id: str, member_event_id: str,
is_peeking: bool, is_peeking: bool,
) -> JsonDict: ) -> JsonDict:
@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler):
user_id: str, user_id: str,
room_id: str, room_id: str,
pagin_config: PaginationConfig, pagin_config: PaginationConfig,
membership: Membership, membership: str,
is_peeking: bool, is_peeking: bool,
) -> JsonDict: ) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id) current_state = await self.state.get_current_state(room_id=room_id)

View file

@ -1364,7 +1364,12 @@ class EventCreationHandler:
for k, v in original_event.internal_metadata.get_dict().items(): for k, v in original_event.internal_metadata.get_dict().items():
setattr(builder.internal_metadata, k, v) setattr(builder.internal_metadata, k, v)
event = await builder.build(prev_event_ids=original_event.prev_event_ids()) # the event type hasn't changed, so there's no point in re-calculating the
# auth events.
event = await builder.build(
prev_event_ids=original_event.prev_event_ids(),
auth_event_ids=original_event.auth_event_ids(),
)
# we rebuild the event context, to be on the safe side. If nothing else, # we rebuild the event context, to be on the safe side. If nothing else,
# delta_ids might need an update. # delta_ids might need an update.

View file

@ -12,9 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
import random import random
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import ( from synapse.api.errors import (
AuthError, AuthError,
@ -24,11 +24,20 @@ from synapse.api.errors import (
StoreError, StoreError,
SynapseError, SynapseError,
) )
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import UserID, create_requester, get_domain_from_id from synapse.types import (
JsonDict,
Requester,
UserID,
create_requester,
get_domain_from_id,
)
from ._base import BaseHandler from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_DISPLAYNAME_LEN = 256 MAX_DISPLAYNAME_LEN = 256
@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
PROFILE_UPDATE_MS = 60 * 1000 PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__(hs) super().__init__(hs)
self.federation = hs.get_federation_client() self.federation = hs.get_federation_client()
@ -57,10 +66,10 @@ class ProfileHandler(BaseHandler):
if hs.config.run_background_tasks: if hs.config.run_background_tasks:
self.clock.looping_call( self.clock.looping_call(
self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
) )
async def get_profile(self, user_id): async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler):
except HttpResponseException as e: except HttpResponseException as e:
raise e.to_synapse_error() raise e.to_synapse_error()
async def get_profile_from_cache(self, user_id): async def get_profile_from_cache(self, user_id: str) -> JsonDict:
"""Get the profile information from our local cache. If the user is """Get the profile information from our local cache. If the user is
ours then the profile information will always be correct. Otherwise, ours then the profile information will always be correct. Otherwise,
it may be out of date/missing. it may be out of date/missing.
@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler):
profile = await self.store.get_from_remote_profile_cache(user_id) profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {} return profile or {}
async def get_displayname(self, target_user): async def get_displayname(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
displayname = await self.store.get_profile_displayname( displayname = await self.store.get_profile_displayname(
@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler):
return result["displayname"] return result["displayname"]
async def set_displayname( async def set_displayname(
self, target_user, requester, new_displayname, by_admin=False self,
): target_user: UserID,
requester: Requester,
new_displayname: str,
by_admin: bool = False,
) -> None:
"""Set the displayname of a user """Set the displayname of a user
Args: Args:
target_user (UserID): the user whose displayname is to be changed. target_user: the user whose displayname is to be changed.
requester (Requester): The user attempting to make this change. requester: The user attempting to make this change.
new_displayname (str): The displayname to give this user. new_displayname: The displayname to give this user.
by_admin (bool): Whether this change was made by an administrator. by_admin: Whether this change was made by an administrator.
""" """
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver") raise SynapseError(400, "User is not hosted on this homeserver")
@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler):
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,) 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
) )
displayname_to_set = new_displayname # type: Optional[str]
if new_displayname == "": if new_displayname == "":
new_displayname = None displayname_to_set = None
# If the admin changes the display name of a user, the requesting user cannot send # If the admin changes the display name of a user, the requesting user cannot send
# the join event to update the displayname in the rooms. # the join event to update the displayname in the rooms.
@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler):
if by_admin: if by_admin:
requester = create_requester(target_user) requester = create_requester(target_user)
await self.store.set_profile_displayname(target_user.localpart, new_displayname) await self.store.set_profile_displayname(
target_user.localpart, displayname_to_set
)
if self.hs.config.user_directory_search_all_users: if self.hs.config.user_directory_search_all_users:
profile = await self.store.get_profileinfo(target_user.localpart) profile = await self.store.get_profileinfo(target_user.localpart)
@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
async def get_avatar_url(self, target_user): async def get_avatar_url(self, target_user: UserID) -> str:
if self.hs.is_mine(target_user): if self.hs.is_mine(target_user):
try: try:
avatar_url = await self.store.get_profile_avatar_url( avatar_url = await self.store.get_profile_avatar_url(
@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler):
return result["avatar_url"] return result["avatar_url"]
async def set_avatar_url( async def set_avatar_url(
self, target_user, requester, new_avatar_url, by_admin=False self,
target_user: UserID,
requester: Requester,
new_avatar_url: str,
by_admin: bool = False,
): ):
"""Set a new avatar URL for a user. """Set a new avatar URL for a user.
Args: Args:
target_user (UserID): the user whose avatar URL is to be changed. target_user: the user whose avatar URL is to be changed.
requester (Requester): The user attempting to make this change. requester: The user attempting to make this change.
new_avatar_url (str): The avatar URL to give this user. new_avatar_url: The avatar URL to give this user.
by_admin (bool): Whether this change was made by an administrator. by_admin: Whether this change was made by an administrator.
""" """
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
raise SynapseError(400, "User is not hosted on this homeserver") raise SynapseError(400, "User is not hosted on this homeserver")
@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler):
await self._update_join_states(requester, target_user) await self._update_join_states(requester, target_user)
async def on_profile_query(self, args): async def on_profile_query(self, args: JsonDict) -> JsonDict:
user = UserID.from_string(args["user_id"]) user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user): if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver") raise SynapseError(400, "User is not hosted on this homeserver")
@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler):
return response return response
async def _update_join_states(self, requester, target_user): async def _update_join_states(
self, requester: Requester, target_user: UserID
) -> None:
if not self.hs.is_mine(target_user): if not self.hs.is_mine(target_user):
return return
@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", room_id, str(e) "Failed to update join event for room %s - %s", room_id, str(e)
) )
async def check_profile_query_allowed(self, target_user, requester=None): async def check_profile_query_allowed(
self, target_user: UserID, requester: Optional[UserID] = None
) -> None:
"""Checks whether a profile query is allowed. If the """Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a 'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users 'requester' is provided, the query is only allowed if the two users
share a room. share a room.
Args: Args:
target_user (UserID): The owner of the queried profile. target_user: The owner of the queried profile.
requester (None|UserID): The user querying for the profile. requester: The user querying for the profile.
Raises: Raises:
SynapseError(403): The two users share no room, or one user couldn't SynapseError(403): The two users share no room, or one user couldn't
@ -370,11 +394,7 @@ class ProfileHandler(BaseHandler):
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN) raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
raise raise
def _start_update_remote_profile_cache(self): @wrap_as_background_process("Update remote profile")
return run_as_background_process(
"Update remote profile", self._update_remote_profile_cache
)
async def _update_remote_profile_cache(self): async def _update_remote_profile_cache(self):
"""Called periodically to check profiles of remote users we haven't """Called periodically to check profiles of remote users we haven't
checked in a while. checked in a while.

225
synapse/logging/_remote.py Normal file
View file

@ -0,0 +1,225 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import traceback
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
from typing import Callable, Optional
import attr
from zope.interface import implementer
from twisted.application.internet import ClientService
from twisted.internet.defer import Deferred
from twisted.internet.endpoints import (
HostnameEndpoint,
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.logger import ILogObserver, Logger, LogLevel
@attr.s
@implementer(IPushProducer)
class LogProducer:
    """
    A pull-from-buffer push producer: while resumed, it drains log events
    from its deque, formats each one, and writes the result to the transport.

    Args:
        buffer: Log buffer to read logs from.
        transport: Transport to write to.
        format_event: A callable to format the log entry to a string.
    """

    transport = attr.ib(type=ITransport)
    format_event = attr.ib(type=Callable[[dict], str])
    _buffer = attr.ib(type=deque)
    # Whether production is currently paused; toggled by the IPushProducer
    # callbacks below.
    _paused = attr.ib(default=False, type=bool, init=False)

    def pauseProducing(self):
        # Stop draining until resumeProducing is called again.
        self._paused = True

    def stopProducing(self):
        # Permanently stop: pause and discard anything still queued.
        self._paused = True
        self._buffer = deque()

    def resumeProducing(self):
        self._paused = False

        # Drain the buffer until we are paused again, run out of events, or
        # the transport drops.
        while not self._paused and self._buffer and self.transport.connected:
            try:
                next_event = self._buffer.popleft()
                line = self.format_event(next_event)

                # Push the formatted entry out over the transport.
                self.transport.write(line.encode("utf8"))
            except Exception:
                # Writing to the transport failed -- report it on the real
                # stderr (we may be the log pipeline) and stop draining.
                traceback.print_exc(file=sys.__stderr__)
                break
@attr.s
@implementer(ILogObserver)
class TCPLogObserver:
"""
An IObserver that writes JSON logs to a TCP target.
Args:
hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target.
port: The logging target's port.
format_event: A callable to format the log entry to a string.
maximum_buffer: The maximum buffer size.
"""
hs = attr.ib()
host = attr.ib(type=str)
port = attr.ib(type=int)
format_event = attr.ib(type=Callable[[dict], str])
maximum_buffer = attr.ib(type=int)
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
_connection_waiter = attr.ib(default=None, type=Optional[Deferred])
_logger = attr.ib(default=attr.Factory(Logger))
_producer = attr.ib(default=None, type=Optional[LogProducer])
def start(self) -> None:
# Connect without DNS lookups if it's a direct IP.
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
endpoint = TCP4ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
else:
raise ValueError("Unknown IP address provided: %s" % (self.host,))
except ValueError:
endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
factory = Factory.forProtocol(Protocol)
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
self._service.startService()
self._connect()
def stop(self):
self._service.stopService()
def _connect(self) -> None:
"""
Triggers an attempt to connect then write to the remote if not already writing.
"""
if self._connection_waiter:
return
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
@self._connection_waiter.addErrback
def fail(r):
r.printTraceback(file=sys.__stderr__)
self._connection_waiter = None
self._connect()
@self._connection_waiter.addCallback
def writer(r):
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
if self._producer and r.transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
# If the producer is still producing, stop it.
if self._producer:
self._producer.stopProducing()
# Make a new producer and start it.
self._producer = LogProducer(
buffer=self._buffer,
transport=r.transport,
format_event=self.format_event,
)
r.transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
def _handle_pressure(self) -> None:
    """
    Shed buffered events when the buffer exceeds its maximum size.

    Applied progressively, stopping as soon as the buffer fits:
      1. Drop all DEBUG events.
      2. Drop all INFO events.
      3. Keep only the oldest and newest maximum_buffer//2 entries,
         cutting out everything in the middle.
    """
    if len(self._buffer) <= self.maximum_buffer:
        return

    # Drop events of progressively higher importance until we fit.
    for shed_level in (LogLevel.debug, LogLevel.info):
        self._buffer = deque(
            event for event in self._buffer if event["log_level"] != shed_level
        )
        if len(self._buffer) <= self.maximum_buffer:
            return

    # Still over the limit: retain the first and last half-maximum
    # entries and discard the middle of the buffer.
    keep = self.maximum_buffer // 2
    remaining = self._buffer
    kept = deque()
    for _ in range(keep):
        kept.append(remaining.popleft())
    newest = [remaining.pop() for _ in range(keep)]
    kept.extend(reversed(newest))
    self._buffer = kept
def __call__(self, event: dict) -> None:
    """Queue a log event for sending and kick off a write attempt."""
    self._buffer.append(event)

    # Shed events if the buffer has grown beyond its configured limit.
    try:
        self._handle_pressure()
    except Exception:
        # If shedding itself fails, drop the whole buffer rather than
        # let it grow without bound, and log the exception.
        self._buffer.clear()
        self._logger.failure("Failed clearing backpressure")

    # Try and write immediately.
    self._connect()

View file

@ -18,26 +18,11 @@ Log formatters that output terse JSON.
""" """
import json import json
import sys from typing import IO
import traceback
from collections import deque
from ipaddress import IPv4Address, IPv6Address, ip_address
from math import floor
from typing import IO, Optional
import attr from twisted.logger import FileLogObserver
from zope.interface import implementer
from twisted.application.internet import ClientService from synapse.logging._remote import TCPLogObserver
from twisted.internet.defer import Deferred
from twisted.internet.endpoints import (
HostnameEndpoint,
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.logger import FileLogObserver, ILogObserver, Logger
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":")) _encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
@ -150,180 +135,22 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
return FileLogObserver(outFile, formatEvent) return FileLogObserver(outFile, formatEvent)
@attr.s def TerseJSONToTCPLogObserver(
@implementer(IPushProducer) hs, host: str, port: int, metadata: dict, maximum_buffer: int
class LogProducer: ) -> FileLogObserver:
""" """
An IPushProducer that writes logs from its buffer to its transport when it A log observer that formats events to a flattened JSON representation.
is resumed.
Args:
buffer: Log buffer to read logs from.
transport: Transport to write to.
"""
transport = attr.ib(type=ITransport)
_buffer = attr.ib(type=deque)
_paused = attr.ib(default=False, type=bool, init=False)
def pauseProducing(self):
self._paused = True
def stopProducing(self):
self._paused = True
self._buffer = deque()
def resumeProducing(self):
self._paused = False
while self._paused is False and (self._buffer and self.transport.connected):
try:
event = self._buffer.popleft()
self.transport.write(_encoder.encode(event).encode("utf8"))
self.transport.write(b"\n")
except Exception:
# Something has gone wrong writing to the transport -- log it
# and break out of the while.
traceback.print_exc(file=sys.__stderr__)
break
@attr.s
@implementer(ILogObserver)
class TerseJSONToTCPLogObserver:
"""
An IObserver that writes JSON logs to a TCP target.
Args: Args:
hs (HomeServer): The homeserver that is being logged for. hs (HomeServer): The homeserver that is being logged for.
host: The host of the logging target. host: The host of the logging target.
port: The logging target's port. port: The logging target's port.
metadata: Metadata to be added to each log entry. metadata: Metadata to be added to each log object.
maximum_buffer: The maximum buffer size.
""" """
hs = attr.ib() def formatEvent(_event: dict) -> str:
host = attr.ib(type=str) flattened = flatten_event(_event, metadata, include_time=True)
port = attr.ib(type=int) return _encoder.encode(flattened) + "\n"
metadata = attr.ib(type=dict)
maximum_buffer = attr.ib(type=int)
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
_connection_waiter = attr.ib(default=None, type=Optional[Deferred])
_logger = attr.ib(default=attr.Factory(Logger))
_producer = attr.ib(default=None, type=Optional[LogProducer])
def start(self) -> None: return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
# Connect without DNS lookups if it's a direct IP.
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
endpoint = TCP4ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(
self.hs.get_reactor(), self.host, self.port
)
except ValueError:
endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
factory = Factory.forProtocol(Protocol)
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
self._service.startService()
self._connect()
def stop(self):
self._service.stopService()
def _connect(self) -> None:
"""
Triggers an attempt to connect then write to the remote if not already writing.
"""
if self._connection_waiter:
return
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
@self._connection_waiter.addErrback
def fail(r):
r.printTraceback(file=sys.__stderr__)
self._connection_waiter = None
self._connect()
@self._connection_waiter.addCallback
def writer(r):
# We have a connection. If we already have a producer, and its
# transport is the same, just trigger a resumeProducing.
if self._producer and r.transport is self._producer.transport:
self._producer.resumeProducing()
self._connection_waiter = None
return
# If the producer is still producing, stop it.
if self._producer:
self._producer.stopProducing()
# Make a new producer and start it.
self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
r.transport.registerProducer(self._producer, True)
self._producer.resumeProducing()
self._connection_waiter = None
def _handle_pressure(self) -> None:
"""
Handle backpressure by shedding events.
The buffer will, in this order, until the buffer is below the maximum:
- Shed DEBUG events
- Shed INFO events
- Shed the middle 50% of the events.
"""
if len(self._buffer) <= self.maximum_buffer:
return
# Strip out DEBUGs
self._buffer = deque(
filter(lambda event: event["level"] != "DEBUG", self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
return
# Strip out INFOs
self._buffer = deque(
filter(lambda event: event["level"] != "INFO", self._buffer)
)
if len(self._buffer) <= self.maximum_buffer:
return
# Cut the middle entries out
buffer_split = floor(self.maximum_buffer / 2)
old_buffer = self._buffer
self._buffer = deque()
for i in range(buffer_split):
self._buffer.append(old_buffer.popleft())
end_buffer = []
for i in range(buffer_split):
end_buffer.append(old_buffer.pop())
self._buffer.extend(reversed(end_buffer))
def __call__(self, event: dict) -> None:
flattened = flatten_event(event, self.metadata, include_time=True)
self._buffer.append(flattened)
# Handle backpressure, if it exists.
try:
self._handle_pressure()
except Exception:
# If handling backpressure fails,clear the buffer and log the
# exception.
self._buffer.clear()
self._logger.failure("Failed clearing backpressure")
# Try and write immediately.
self._connect()

View file

@ -24,6 +24,7 @@ from prometheus_client.core import REGISTRY, Counter, Gauge
from twisted.internet import defer from twisted.internet import defer
from synapse.logging.context import LoggingContext, PreserveLoggingContext from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.opentracing import start_active_span
if TYPE_CHECKING: if TYPE_CHECKING:
import resource import resource
@ -197,14 +198,14 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
with BackgroundProcessLoggingContext(desc) as context: with BackgroundProcessLoggingContext(desc) as context:
context.request = "%s-%i" % (desc, count) context.request = "%s-%i" % (desc, count)
try: try:
result = func(*args, **kwargs) with start_active_span(desc, tags={"request_id": context.request}):
result = func(*args, **kwargs)
if inspect.isawaitable(result): if inspect.isawaitable(result):
result = await result result = await result
return result return result
except Exception: except Exception:
logger.exception( logger.exception(
"Background process '%s' threw an exception", desc, "Background process '%s' threw an exception", desc,

View file

@ -496,6 +496,6 @@ class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
# dedupe when we add callbacks to lru cache nodes, otherwise the number # dedupe when we add callbacks to lru cache nodes, otherwise the number
# of callbacks would grow. # of callbacks would grow.
def __call__(self): def __call__(self):
rules = self.cache.get(self.room_id, None, update_metrics=False) rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
if rules: if rules:
rules.invalidate_all() rules.invalidate_all()

View file

@ -387,8 +387,8 @@ class Mailer:
return ret return ret
async def get_message_vars(self, notif, event, room_state_ids): async def get_message_vars(self, notif, event, room_state_ids):
if event.type != EventTypes.Message: if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
return return None
sender_state_event_id = room_state_ids[("m.room.member", event.sender)] sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
sender_state_event = await self.store.get_event(sender_state_event_id) sender_state_event = await self.store.get_event(sender_state_event_id)
@ -399,10 +399,8 @@ class Mailer:
# sender_hash % the number of default images to choose from # sender_hash % the number of default images to choose from
sender_hash = string_ordinal_total(event.sender) sender_hash = string_ordinal_total(event.sender)
msgtype = event.content.get("msgtype")
ret = { ret = {
"msgtype": msgtype, "event_type": event.type,
"is_historical": event.event_id != notif["event_id"], "is_historical": event.event_id != notif["event_id"],
"id": event.event_id, "id": event.event_id,
"ts": event.origin_server_ts, "ts": event.origin_server_ts,
@ -411,6 +409,14 @@ class Mailer:
"sender_hash": sender_hash, "sender_hash": sender_hash,
} }
# Encrypted messages don't have any additional useful information.
if event.type == EventTypes.Encrypted:
return ret
msgtype = event.content.get("msgtype")
ret["msgtype"] = msgtype
if msgtype == "m.text": if msgtype == "m.text":
self.add_text_message_vars(ret, event) self.add_text_message_vars(ret, event)
elif msgtype == "m.image": elif msgtype == "m.image":

View file

@ -16,11 +16,10 @@
import logging import logging
import re import re
from typing import Any, Dict, List, Optional, Pattern, Union from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
from synapse.events import EventBase from synapse.events import EventBase
from synapse.types import UserID from synapse.types import UserID
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -174,20 +173,21 @@ class PushRuleEvaluatorForEvent:
# Similar to _glob_matches, but do not treat display_name as a glob. # Similar to _glob_matches, but do not treat display_name as a glob.
r = regex_cache.get((display_name, False, True), None) r = regex_cache.get((display_name, False, True), None)
if not r: if not r:
r = re.escape(display_name) r1 = re.escape(display_name)
r = _re_word_boundary(r) r1 = _re_word_boundary(r1)
r = re.compile(r, flags=re.IGNORECASE) r = re.compile(r1, flags=re.IGNORECASE)
regex_cache[(display_name, False, True)] = r regex_cache[(display_name, False, True)] = r
return r.search(body) return bool(r.search(body))
def _get_value(self, dotted_key: str) -> Optional[str]: def _get_value(self, dotted_key: str) -> Optional[str]:
return self._value_cache.get(dotted_key, None) return self._value_cache.get(dotted_key, None)
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches # Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
regex_cache = LruCache(50000) regex_cache = LruCache(
register_cache("cache", "regex_push_cache", regex_cache) 50000, "regex_push_cache"
) # type: LruCache[Tuple[str, bool, bool], Pattern]
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool: def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
@ -205,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
if not r: if not r:
r = _glob_to_re(glob, word_boundary) r = _glob_to_re(glob, word_boundary)
regex_cache[(glob, True, word_boundary)] = r regex_cache[(glob, True, word_boundary)] = r
return r.search(value) return bool(r.search(value))
except re.error: except re.error:
logger.warning("Failed to parse glob to regex: %r", glob) logger.warning("Failed to parse glob to regex: %r", glob)
return False return False

View file

@ -15,7 +15,7 @@
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.lrucache import LruCache
from ._base import BaseSlavedStore from ._base import BaseSlavedStore
@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.client_ip_last_seen = DeferredCache( self.client_ip_last_seen = LruCache(
name="client_ip_last_seen", keylen=4, max_entries=50000 cache_name="client_ip_last_seen", keylen=4, max_size=50000
) # type: DeferredCache[tuple, int] ) # type: LruCache[tuple, int]
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id): async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
now = int(self._clock.time_msec()) now = int(self._clock.time_msec())
@ -41,7 +41,7 @@ class SlavedClientIpStore(BaseSlavedStore):
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return return
self.client_ip_last_seen.prefill(key, now) self.client_ip_last_seen.set(key, now)
self.hs.get_tcp_replication().send_user_ip( self.hs.get_tcp_replication().send_user_ip(
user_id, access_token, ip, user_agent, device_id, now user_id, access_token, ip, user_agent, device_id, now

View file

@ -1,41 +1,47 @@
{% for message in notif.messages %} {%- for message in notif.messages %}
<tr class="{{ "historical_message" if message.is_historical else "message" }}"> <tr class="{{ "historical_message" if message.is_historical else "message" }}">
<td class="sender_avatar"> <td class="sender_avatar">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
{% if message.sender_avatar_url %} {%- if message.sender_avatar_url %}
<img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" /> <img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
{% else %} {%- else %}
{% if message.sender_hash % 3 == 0 %} {%- if message.sender_hash % 3 == 0 %}
<img class="sender_avatar" src="https://riot.im/img/external/avatar-1.png" /> <img class="sender_avatar" src="https://riot.im/img/external/avatar-1.png" />
{% elif message.sender_hash % 3 == 1 %} {%- elif message.sender_hash % 3 == 1 %}
<img class="sender_avatar" src="https://riot.im/img/external/avatar-2.png" /> <img class="sender_avatar" src="https://riot.im/img/external/avatar-2.png" />
{% else %} {%- else %}
<img class="sender_avatar" src="https://riot.im/img/external/avatar-3.png" /> <img class="sender_avatar" src="https://riot.im/img/external/avatar-3.png" />
{% endif %} {%- endif %}
{% endif %} {%- endif %}
{% endif %} {%- endif %}
</td> </td>
<td class="message_contents"> <td class="message_contents">
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %} {%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
<div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div> <div class="sender_name">{%- if message.msgtype == "m.emote" %}*{%- endif %} {{ message.sender_name }}</div>
{% endif %} {%- endif %}
<div class="message_body"> <div class="message_body">
{% if message.msgtype == "m.text" %} {%- if message.event_type == "m.room.encrypted" %}
{{ message.body_text_html }} An encrypted message.
{% elif message.msgtype == "m.emote" %} {%- elif message.event_type == "m.room.message" %}
{{ message.body_text_html }} {%- if message.msgtype == "m.text" %}
{% elif message.msgtype == "m.notice" %} {{ message.body_text_html }}
{{ message.body_text_html }} {%- elif message.msgtype == "m.emote" %}
{% elif message.msgtype == "m.image" %} {{ message.body_text_html }}
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" /> {%- elif message.msgtype == "m.notice" %}
{% elif message.msgtype == "m.file" %} {{ message.body_text_html }}
<span class="filename">{{ message.body_text_plain }}</span> {%- elif message.msgtype == "m.image" %}
{% endif %} <img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
{%- elif message.msgtype == "m.file" %}
<span class="filename">{{ message.body_text_plain }}</span>
{%- else %}
A message with unrecognised content.
{%- endif %}
{%- endif %}
</div> </div>
</td> </td>
<td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td> <td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
</tr> </tr>
{% endfor %} {%- endfor %}
<tr class="notif_link"> <tr class="notif_link">
<td></td> <td></td>
<td> <td>

View file

@ -1,16 +1,22 @@
{% for message in notif.messages %} {%- for message in notif.messages %}
{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }}) {%- if message.event_type == "m.room.encrypted" %}
{% if message.msgtype == "m.text" %} An encrypted message.
{%- elif message.event_type == "m.room.message" %}
{%- if message.msgtype == "m.emote" %}* {%- endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
{%- if message.msgtype == "m.text" %}
{{ message.body_text_plain }} {{ message.body_text_plain }}
{% elif message.msgtype == "m.emote" %} {%- elif message.msgtype == "m.emote" %}
{{ message.body_text_plain }} {{ message.body_text_plain }}
{% elif message.msgtype == "m.notice" %} {%- elif message.msgtype == "m.notice" %}
{{ message.body_text_plain }} {{ message.body_text_plain }}
{% elif message.msgtype == "m.image" %} {%- elif message.msgtype == "m.image" %}
{{ message.body_text_plain }} {{ message.body_text_plain }}
{% elif message.msgtype == "m.file" %} {%- elif message.msgtype == "m.file" %}
{{ message.body_text_plain }} {{ message.body_text_plain }}
{% endif %} {%- else %}
{% endfor %} A message with unrecognised content.
{%- endif %}
{%- endif %}
{%- endfor %}
View {{ room.title }} at {{ notif.link }} View {{ room.title }} at {{ notif.link }}

View file

@ -2,8 +2,8 @@
<html lang="en"> <html lang="en">
<head> <head>
<style type="text/css"> <style type="text/css">
{% include 'mail.css' without context %} {%- include 'mail.css' without context %}
{% include "mail-%s.css" % app_name ignore missing without context %} {%- include "mail-%s.css" % app_name ignore missing without context %}
</style> </style>
</head> </head>
<body> <body>
@ -18,21 +18,21 @@
<div class="summarytext">{{ summary_text }}</div> <div class="summarytext">{{ summary_text }}</div>
</td> </td>
<td class="logo"> <td class="logo">
{% if app_name == "Riot" %} {%- if app_name == "Riot" %}
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/> <img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %} {%- elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/> <img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
{% elif app_name == "Element" %} {%- elif app_name == "Element" %}
<img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/> <img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
{% else %} {%- else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/> <img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %} {%- endif %}
</td> </td>
</tr> </tr>
</table> </table>
{% for room in rooms %} {%- for room in rooms %}
{% include 'room.html' with context %} {%- include 'room.html' with context %}
{% endfor %} {%- endfor %}
<div class="footer"> <div class="footer">
<a href="{{ unsubscribe_link }}">Unsubscribe</a> <a href="{{ unsubscribe_link }}">Unsubscribe</a>
<br/> <br/>
@ -41,12 +41,12 @@
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
an event was received at {{ reason.received_at|format_ts("%c") }} an event was received at {{ reason.received_at|format_ts("%c") }}
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago, which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
{% if reason.last_sent_ts %} {%- if reason.last_sent_ts %}
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }}, and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago. which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
{% else %} {%- else %}
and we don't have a last time we sent a mail for this room. and we don't have a last time we sent a mail for this room.
{% endif %} {%- endif %}
</div> </div>
</div> </div>
</td> </td>

View file

@ -2,9 +2,9 @@ Hi {{ user_display_name }},
{{ summary_text }} {{ summary_text }}
{% for room in rooms %} {%- for room in rooms %}
{% include 'room.txt' with context %} {%- include 'room.txt' with context %}
{% endfor %} {%- endfor %}
You can disable these notifications at {{ unsubscribe_link }} You can disable these notifications at {{ unsubscribe_link }}

View file

@ -1,23 +1,23 @@
<table class="room"> <table class="room">
<tr class="room_header"> <tr class="room_header">
<td class="room_avatar"> <td class="room_avatar">
{% if room.avatar_url %} {%- if room.avatar_url %}
<img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" /> <img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
{% else %} {%- else %}
{% if room.hash % 3 == 0 %} {%- if room.hash % 3 == 0 %}
<img alt="" src="https://riot.im/img/external/avatar-1.png" /> <img alt="" src="https://riot.im/img/external/avatar-1.png" />
{% elif room.hash % 3 == 1 %} {%- elif room.hash % 3 == 1 %}
<img alt="" src="https://riot.im/img/external/avatar-2.png" /> <img alt="" src="https://riot.im/img/external/avatar-2.png" />
{% else %} {%- else %}
<img alt="" src="https://riot.im/img/external/avatar-3.png" /> <img alt="" src="https://riot.im/img/external/avatar-3.png" />
{% endif %} {%- endif %}
{% endif %} {%- endif %}
</td> </td>
<td class="room_name" colspan="2"> <td class="room_name" colspan="2">
{{ room.title }} {{ room.title }}
</td> </td>
</tr> </tr>
{% if room.invite %} {%- if room.invite %}
<tr> <tr>
<td></td> <td></td>
<td> <td>
@ -25,9 +25,9 @@
</td> </td>
<td></td> <td></td>
</tr> </tr>
{% else %} {%- else %}
{% for notif in room.notifs %} {%- for notif in room.notifs %}
{% include 'notif.html' with context %} {%- include 'notif.html' with context %}
{% endfor %} {%- endfor %}
{% endif %} {%- endif %}
</table> </table>

View file

@ -1,9 +1,9 @@
{{ room.title }} {{ room.title }}
{% if room.invite %} {%- if room.invite %}
You've been invited, join at {{ room.link }} You've been invited, join at {{ room.link }}
{% else %} {%- else %}
{% for notif in room.notifs %} {%- for notif in room.notifs %}
{% include 'notif.txt' with context %} {%- include 'notif.txt' with context %}
{% endfor %} {%- endfor %}
{% endif %} {%- endif %}

View file

@ -110,6 +110,8 @@ class LoginRestServlet(RestServlet):
({"type": t} for t in self.auth_handler.get_supported_login_types()) ({"type": t} for t in self.auth_handler.get_supported_login_types())
) )
flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
return 200, {"flows": flows} return 200, {"flows": flows}
def on_OPTIONS(self, request: SynapseRequest): def on_OPTIONS(self, request: SynapseRequest):

View file

@ -76,14 +76,16 @@ class SQLBaseStore(metaclass=ABCMeta):
""" """
try: try:
if key is None: cache = getattr(self, cache_name)
getattr(self, cache_name).invalidate_all()
else:
getattr(self, cache_name).invalidate(tuple(key))
except AttributeError: except AttributeError:
# We probably haven't pulled in the cache in this worker, # We probably haven't pulled in the cache in this worker,
# which is fine. # which is fine.
pass return
if key is None:
cache.invalidate_all()
else:
cache.invalidate(tuple(key))
def db_to_json(db_content): def db_to_json(db_content):

View file

@ -160,7 +160,7 @@ class LoggingDatabaseConnection:
self.conn.__enter__() self.conn.__enter__()
return self return self
def __exit__(self, exc_type, exc_value, traceback) -> bool: def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
return self.conn.__exit__(exc_type, exc_value, traceback) return self.conn.__exit__(exc_type, exc_value, traceback)
# Proxy through any unknown lookups to the DB conn class. # Proxy through any unknown lookups to the DB conn class.

View file

@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.lrucache import LruCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -410,8 +410,8 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
class ClientIpStore(ClientIpWorkerStore): class ClientIpStore(ClientIpWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs): def __init__(self, database: DatabasePool, db_conn, hs):
self.client_ip_last_seen = DeferredCache( self.client_ip_last_seen = LruCache(
name="client_ip_last_seen", keylen=4, max_entries=50000 cache_name="client_ip_last_seen", keylen=4, max_size=50000
) )
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
@ -442,7 +442,7 @@ class ClientIpStore(ClientIpWorkerStore):
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY: if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
return return
self.client_ip_last_seen.prefill(key, now) self.client_ip_last_seen.set(key, now)
self._batch_row_update[key] = (user_agent, device_id, now) self._batch_row_update[key] = (user_agent, device_id, now)

View file

@ -34,8 +34,8 @@ from synapse.storage.database import (
) )
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
from synapse.util import json_decoder, json_encoder from synapse.util import json_decoder, json_encoder
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.stringutils import shortstr from synapse.util.stringutils import shortstr
@ -1005,8 +1005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
# Map of (user_id, device_id) -> bool. If there is an entry that implies # Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists. # the device exists.
self.device_id_exists_cache = DeferredCache( self.device_id_exists_cache = LruCache(
name="device_id_exists", keylen=2, max_entries=10000 cache_name="device_id_exists", keylen=2, max_size=10000
) )
async def store_device( async def store_device(
@ -1052,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
) )
if hidden: if hidden:
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN) raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
self.device_id_exists_cache.prefill(key, True) self.device_id_exists_cache.set(key, True)
return inserted return inserted
except StoreError: except StoreError:
raise raise

View file

@ -1051,9 +1051,7 @@ class PersistEventsStore:
def prefill(): def prefill():
for cache_entry in to_prefill: for cache_entry in to_prefill:
self.store._get_event_cache.prefill( self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
(cache_entry[0].event_id,), cache_entry
)
txn.call_after(prefill) txn.call_after(prefill)

View file

@ -33,7 +33,10 @@ from synapse.api.room_versions import (
from synapse.events import EventBase, make_event_from_dict from synapse.events import EventBase, make_event_from_dict
from synapse.events.utils import prune_event from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream from synapse.replication.tcp.streams.events import EventsStream
@ -42,8 +45,8 @@ from synapse.storage.database import DatabasePool
from synapse.storage.engines import PostgresEngine from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.deferred_cache import DeferredCache
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure from synapse.util.metrics import Measure
@ -137,20 +140,16 @@ class EventsWorkerStore(SQLBaseStore):
db_conn, "events", "stream_ordering", step=-1 db_conn, "events", "stream_ordering", step=-1
) )
if not hs.config.worker.worker_app: if hs.config.run_background_tasks:
# We periodically clean out old transaction ID mappings # We periodically clean out old transaction ID mappings
self._clock.looping_call( self._clock.looping_call(
run_as_background_process, self._cleanup_old_transaction_ids, 5 * 60 * 1000,
5 * 60 * 1000,
"_cleanup_old_transaction_ids",
self._cleanup_old_transaction_ids,
) )
self._get_event_cache = DeferredCache( self._get_event_cache = LruCache(
"*getEvent*", cache_name="*getEvent*",
keylen=3, keylen=3,
max_entries=hs.config.caches.event_cache_size, max_size=hs.config.caches.event_cache_size,
apply_cache_factor_from_config=False,
) )
self._event_fetch_lock = threading.Condition() self._event_fetch_lock = threading.Condition()
@ -749,7 +748,7 @@ class EventsWorkerStore(SQLBaseStore):
event=original_ev, redacted_event=redacted_event event=original_ev, redacted_event=redacted_event
) )
self._get_event_cache.prefill((event_id,), cache_entry) self._get_event_cache.set((event_id,), cache_entry)
result_map[event_id] = cache_entry result_map[event_id] = cache_entry
return result_map return result_map
@ -1375,6 +1374,7 @@ class EventsWorkerStore(SQLBaseStore):
return mapping return mapping
@wrap_as_background_process("_cleanup_old_transaction_ids")
async def _cleanup_old_transaction_ids(self): async def _cleanup_old_transaction_ids(self):
"""Cleans out transaction id mappings older than 24hrs. """Cleans out transaction id mappings older than 24hrs.
""" """

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Dict, Optional from typing import Any, Dict, List, Optional
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore from synapse.storage._base import SQLBaseStore
@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
) )
async def set_profile_displayname( async def set_profile_displayname(
self, user_localpart: str, new_displayname: str self, user_localpart: str, new_displayname: Optional[str]
) -> None: ) -> None:
await self.db_pool.simple_update_one( await self.db_pool.simple_update_one(
table="profiles", table="profiles",
@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
async def get_remote_profile_cache_entries_that_expire( async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int self, last_checked: int
) -> Dict[str, str]: ) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked` """Get all users who haven't been checked since `last_checked`
""" """

View file

@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore):
lock=False, lock=False,
) )
user_has_pusher = self.get_if_user_has_pusher.cache.get( user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
(user_id,), None, update_metrics=False (user_id,), None, update_metrics=False
) )

View file

@ -25,7 +25,6 @@ from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.caches.stream_change_cache import StreamChangeCache
@ -413,18 +412,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
if receipt_type != "m.read": if receipt_type != "m.read":
return return
# Returns either an ObservableDeferred or the raw result res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
res = self.get_users_with_read_receipts_in_room.cache.get(
room_id, None, update_metrics=False room_id, None, update_metrics=False
) )
# first handle the ObservableDeferred case
if isinstance(res, ObservableDeferred):
if res.has_called():
res = res.get_result()
else:
res = None
if res and user_id in res: if res and user_id in res:
# We'd only be adding to the set, so no point invalidating if the # We'd only be adding to the set, so no point invalidating if the
# user is already there # user is already there

View file

@ -20,7 +20,10 @@ from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
)
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.events_worker import EventsWorkerStore
@ -67,16 +70,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
): ):
self._known_servers_count = 1 self._known_servers_count = 1
self.hs.get_clock().looping_call( self.hs.get_clock().looping_call(
run_as_background_process, self._count_known_servers, 60 * 1000,
60 * 1000,
"_count_known_servers",
self._count_known_servers,
) )
self.hs.get_clock().call_later( self.hs.get_clock().call_later(
1000, 1000, self._count_known_servers,
run_as_background_process,
"_count_known_servers",
self._count_known_servers,
) )
LaterGauge( LaterGauge(
"synapse_federation_known_servers", "synapse_federation_known_servers",
@ -85,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count, lambda: self._known_servers_count,
) )
@wrap_as_background_process("_count_known_servers")
async def _count_known_servers(self): async def _count_known_servers(self):
""" """
Count the servers that this server knows about. Count the servers that this server knows about.
@ -531,7 +529,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
# If we do then we can reuse that result and simply update it with # If we do then we can reuse that result and simply update it with
# any membership changes in `delta_ids` # any membership changes in `delta_ids`
if context.prev_group and context.delta_ids: if context.prev_group and context.delta_ids:
prev_res = self._get_joined_users_from_context.cache.get( prev_res = self._get_joined_users_from_context.cache.get_immediate(
(room_id, context.prev_group), None (room_id, context.prev_group), None
) )
if prev_res and isinstance(prev_res, dict): if prev_res and isinstance(prev_res, dict):

View file

@ -13,6 +13,5 @@
* limitations under the License. * limitations under the License.
*/ */
ALTER TABLE application_services_state ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT;
ADD COLUMN read_receipt_stream_id INT, ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT;
ADD COLUMN presence_stream_id INT;

View file

@ -0,0 +1 @@
DROP TABLE device_max_stream_id;

View file

@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, Iterable, Iterator, List, Tuple from typing import Any, Iterable, Iterator, List, Optional, Tuple
from typing_extensions import Protocol from typing_extensions import Protocol
@ -65,5 +65,5 @@ class Connection(Protocol):
def __enter__(self) -> "Connection": def __enter__(self) -> "Connection":
... ...
def __exit__(self, exc_type, exc_value, traceback) -> bool: def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
... ...

View file

@ -16,7 +16,7 @@
import logging import logging
from sys import intern from sys import intern
from typing import Callable, Dict, Optional from typing import Callable, Dict, Optional, Sized
import attr import attr
from prometheus_client.core import Gauge from prometheus_client.core import Gauge
@ -92,7 +92,7 @@ class CacheMetric:
def register_cache( def register_cache(
cache_type: str, cache_type: str,
cache_name: str, cache_name: str,
cache, cache: Sized,
collect_callback: Optional[Callable] = None, collect_callback: Optional[Callable] = None,
resizable: bool = True, resizable: bool = True,
resize_callback: Optional[Callable] = None, resize_callback: Optional[Callable] = None,
@ -100,12 +100,15 @@ def register_cache(
"""Register a cache object for metric collection and resizing. """Register a cache object for metric collection and resizing.
Args: Args:
cache_type cache_type: a string indicating the "type" of the cache. This is used
only for deduplication so isn't too important provided it's constant.
cache_name: name of the cache cache_name: name of the cache
cache: cache itself cache: cache itself, which must implement __len__(), and may optionally implement
a max_size property
collect_callback: If given, a function which is called during metric collect_callback: If given, a function which is called during metric
collection to update additional metrics. collection to update additional metrics.
resizable: Whether this cache supports being resized. resizable: Whether this cache supports being resized, in which case either
resize_callback must be provided, or the cache must support set_max_size().
resize_callback: A function which can be called to resize the cache. resize_callback: A function which can be called to resize the cache.
Returns: Returns:

View file

@ -17,14 +17,23 @@
import enum import enum
import threading import threading
from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast from typing import (
Callable,
Generic,
Iterable,
MutableMapping,
Optional,
TypeVar,
Union,
cast,
)
from prometheus_client import Gauge from prometheus_client import Gauge
from twisted.internet import defer from twisted.internet import defer
from twisted.python import failure
from synapse.util.async_helpers import ObservableDeferred from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches import register_cache
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
@ -34,7 +43,7 @@ cache_pending_metric = Gauge(
["name"], ["name"],
) )
T = TypeVar("T")
KT = TypeVar("KT") KT = TypeVar("KT")
VT = TypeVar("VT") VT = TypeVar("VT")
@ -49,15 +58,12 @@ class DeferredCache(Generic[KT, VT]):
"""Wraps an LruCache, adding support for Deferred results. """Wraps an LruCache, adding support for Deferred results.
It expects that each entry added with set() will be a Deferred; likewise get() It expects that each entry added with set() will be a Deferred; likewise get()
may return an ObservableDeferred. will return a Deferred.
""" """
__slots__ = ( __slots__ = (
"cache", "cache",
"name",
"keylen",
"thread", "thread",
"metrics",
"_pending_deferred_cache", "_pending_deferred_cache",
) )
@ -89,37 +95,27 @@ class DeferredCache(Generic[KT, VT]):
cache_type() cache_type()
) # type: MutableMapping[KT, CacheEntry] ) # type: MutableMapping[KT, CacheEntry]
def metrics_cb():
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
# cache is used for completed results and maps to the result itself, rather than # cache is used for completed results and maps to the result itself, rather than
# a Deferred. # a Deferred.
self.cache = LruCache( self.cache = LruCache(
max_size=max_entries, max_size=max_entries,
keylen=keylen, keylen=keylen,
cache_name=name,
cache_type=cache_type, cache_type=cache_type,
size_callback=(lambda d: len(d)) if iterable else None, size_callback=(lambda d: len(d)) if iterable else None,
evicted_callback=self._on_evicted, metrics_collection_callback=metrics_cb,
apply_cache_factor_from_config=apply_cache_factor_from_config, apply_cache_factor_from_config=apply_cache_factor_from_config,
) ) # type: LruCache[KT, VT]
self.name = name
self.keylen = keylen
self.thread = None # type: Optional[threading.Thread] self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
self.cache,
collect_callback=self._metrics_collection_callback,
)
@property @property
def max_entries(self): def max_entries(self):
return self.cache.max_size return self.cache.max_size
def _on_evicted(self, evicted_count):
self.metrics.inc_evictions(evicted_count)
def _metrics_collection_callback(self):
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
def check_thread(self): def check_thread(self):
expected_thread = self.thread expected_thread = self.thread
if expected_thread is None: if expected_thread is None:
@ -133,62 +129,113 @@ class DeferredCache(Generic[KT, VT]):
def get( def get(
self, self,
key: KT, key: KT,
default=_Sentinel.sentinel,
callback: Optional[Callable[[], None]] = None, callback: Optional[Callable[[], None]] = None,
update_metrics: bool = True, update_metrics: bool = True,
): ) -> defer.Deferred:
"""Looks the key up in the caches. """Looks the key up in the caches.
For symmetry with set(), this method does *not* follow the synapse logcontext
rules: the logcontext will not be cleared on return, and the Deferred will run
its callbacks in the sentinel context. In other words: wrap the result with
make_deferred_yieldable() before `await`ing it.
Args: Args:
key(tuple) key:
default: What is returned if key is not in the caches. If not callback: Gets called when the entry in the cache is invalidated
specified then function throws KeyError instead
callback(fn): Gets called when the entry in the cache is invalidated
update_metrics (bool): whether to update the cache hit rate metrics update_metrics (bool): whether to update the cache hit rate metrics
Returns: Returns:
Either an ObservableDeferred or the result itself A Deferred which completes with the result. Note that this may later fail
if there is an ongoing set() operation which later completes with a failure.
Raises:
KeyError if the key is not found in the cache
""" """
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel) val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
if val is not _Sentinel.sentinel: if val is not _Sentinel.sentinel:
val.callbacks.update(callbacks) val.callbacks.update(callbacks)
if update_metrics: if update_metrics:
self.metrics.inc_hits() m = self.cache.metrics
return val.deferred assert m # we always have a name, so should always have metrics
m.inc_hits()
return val.deferred.observe()
val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks) val2 = self.cache.get(
if val is not _Sentinel.sentinel: key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
self.metrics.inc_hits() )
return val if val2 is _Sentinel.sentinel:
if update_metrics:
self.metrics.inc_misses()
if default is _Sentinel.sentinel:
raise KeyError() raise KeyError()
else: else:
return default return defer.succeed(val2)
def get_immediate(
self, key: KT, default: T, update_metrics: bool = True
) -> Union[VT, T]:
"""If we have a *completed* cached value, return it."""
return self.cache.get(key, default, update_metrics=update_metrics)
def set( def set(
self, self,
key: KT, key: KT,
value: defer.Deferred, value: defer.Deferred,
callback: Optional[Callable[[], None]] = None, callback: Optional[Callable[[], None]] = None,
) -> ObservableDeferred: ) -> defer.Deferred:
"""Adds a new entry to the cache (or updates an existing one).
The given `value` *must* be a Deferred.
First any existing entry for the same key is invalidated. Then a new entry
is added to the cache for the given key.
Until the `value` completes, calls to `get()` for the key will also result in an
incomplete Deferred, which will ultimately complete with the same result as
`value`.
If `value` completes successfully, subsequent calls to `get()` will then return
a completed deferred with the same result. If it *fails*, the cache is
invalidated and subequent calls to `get()` will raise a KeyError.
If another call to `set()` happens before `value` completes, then (a) any
invalidation callbacks registered in the interim will be called, (b) any
`get()`s in the interim will continue to complete with the result from the
*original* `value`, (c) any future calls to `get()` will complete with the
result from the *new* `value`.
It is expected that `value` does *not* follow the synapse logcontext rules - ie,
if it is incomplete, it runs its callbacks in the sentinel context.
Args:
key: Key to be set
value: a deferred which will complete with a result to add to the cache
callback: An optional callback to be called when the entry is invalidated
"""
if not isinstance(value, defer.Deferred): if not isinstance(value, defer.Deferred):
raise TypeError("not a Deferred") raise TypeError("not a Deferred")
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
self.check_thread() self.check_thread()
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None) existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry: if existing_entry:
existing_entry.invalidate() existing_entry.invalidate()
# XXX: why don't we invalidate the entry in `self.cache` yet?
# we can save a whole load of effort if the deferred is ready.
if value.called:
result = value.result
if not isinstance(result, failure.Failure):
self.cache.set(key, result, callbacks)
return value
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
# and add callbacks to add it to the cache properly later.
observable = ObservableDeferred(value, consumeErrors=True)
observer = observable.observe()
entry = CacheEntry(deferred=observable, callbacks=callbacks)
self._pending_deferred_cache[key] = entry self._pending_deferred_cache[key] = entry
def compare_and_pop(): def compare_and_pop():
@ -232,7 +279,9 @@ class DeferredCache(Generic[KT, VT]):
# _pending_deferred_cache to the real cache. # _pending_deferred_cache to the real cache.
# #
observer.addCallbacks(cb, eb) observer.addCallbacks(cb, eb)
return observable
# we return a new Deferred which will be called before any subsequent observers.
return observable.observe()
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None): def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
callbacks = [callback] if callback else [] callbacks = [callback] if callback else []
@ -257,11 +306,12 @@ class DeferredCache(Generic[KT, VT]):
self.check_thread() self.check_thread()
if not isinstance(key, tuple): if not isinstance(key, tuple):
raise TypeError("The cache key must be a tuple not %r" % (type(key),)) raise TypeError("The cache key must be a tuple not %r" % (type(key),))
key = cast(KT, key)
self.cache.del_multi(key) self.cache.del_multi(key)
# if we have a pending lookup for this key, remove it from the # if we have a pending lookup for this key, remove it from the
# _pending_deferred_cache, as above # _pending_deferred_cache, as above
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None) entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None: if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict): for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate() entry.invalidate()

View file

@ -23,7 +23,6 @@ from twisted.internet import defer
from synapse.logging.context import make_deferred_yieldable, preserve_fn from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.util import unwrapFirstError from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.deferred_cache import DeferredCache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
keylen=self.num_args, keylen=self.num_args,
tree=self.tree, tree=self.tree,
iterable=self.iterable, iterable=self.iterable,
) # type: DeferredCache[Tuple, Any] ) # type: DeferredCache[CacheKey, Any]
def get_cache_key_gen(args, kwargs): def get_cache_key_gen(args, kwargs):
"""Given some args/kwargs return a generator that resolves into """Given some args/kwargs return a generator that resolves into
@ -202,32 +201,20 @@ class CacheDescriptor(_CacheDescriptorBase):
cache_key = get_cache_key(args, kwargs) cache_key = get_cache_key(args, kwargs)
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
try: try:
cached_result_d = cache.get(cache_key, callback=invalidate_callback) ret = cache.get(cache_key, callback=invalidate_callback)
if isinstance(cached_result_d, ObservableDeferred):
observer = cached_result_d.observe()
else:
observer = defer.succeed(cached_result_d)
except KeyError: except KeyError:
# Add our own `cache_context` to argument list if the wrapped function
# has asked for one
if self.add_cache_context:
kwargs["cache_context"] = _CacheContext.get_instance(
cache, cache_key
)
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs) ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
ret = cache.set(cache_key, ret, callback=invalidate_callback)
def onErr(f): return make_deferred_yieldable(ret)
cache.invalidate(cache_key)
return f
ret.addErrback(onErr)
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
observer = result_d.observe()
return make_deferred_yieldable(observer)
wrapped = cast(_CachedFunction, _wrapped) wrapped = cast(_CachedFunction, _wrapped)
@ -286,7 +273,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
def __get__(self, obj, objtype=None): def __get__(self, obj, objtype=None):
cached_method = getattr(obj, self.cached_method_name) cached_method = getattr(obj, self.cached_method_name)
cache = cached_method.cache cache = cached_method.cache # type: DeferredCache[CacheKey, Any]
num_args = cached_method.num_args num_args = cached_method.num_args
@functools.wraps(self.orig) @functools.wraps(self.orig)
@ -326,14 +313,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
for arg in list_args: for arg in list_args:
try: try:
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback) res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
if not isinstance(res, ObservableDeferred): if not res.called:
results[arg] = res
elif not res.has_succeeded():
res = res.observe()
res.addCallback(update_results_dict, arg) res.addCallback(update_results_dict, arg)
cached_defers.append(res) cached_defers.append(res)
else: else:
results[arg] = res.get_result() results[arg] = res.result
except KeyError: except KeyError:
missing.add(arg) missing.add(arg)

View file

@ -12,15 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import enum
import logging import logging
import threading import threading
from collections import namedtuple from collections import namedtuple
from typing import Any
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from . import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -40,24 +39,25 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
return len(self.value) return len(self.value)
class _Sentinel(enum.Enum):
# defining a sentinel in this way allows mypy to correctly handle the
# type of a dictionary lookup.
sentinel = object()
class DictionaryCache: class DictionaryCache:
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e. """Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
fetching a subset of dictionary keys for a particular key. fetching a subset of dictionary keys for a particular key.
""" """
def __init__(self, name, max_entries=1000): def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries, size_callback=len) self.cache = LruCache(
max_size=max_entries, cache_name=name, size_callback=len
) # type: LruCache[Any, DictionaryEntry]
self.name = name self.name = name
self.sequence = 0 self.sequence = 0
self.thread = None self.thread = None
# caches_by_name[name] = self.cache
class Sentinel:
__slots__ = []
self.sentinel = Sentinel()
self.metrics = register_cache("dictionary", name, self.cache)
def check_thread(self): def check_thread(self):
expected_thread = self.thread expected_thread = self.thread
@ -80,10 +80,8 @@ class DictionaryCache:
Returns: Returns:
DictionaryEntry DictionaryEntry
""" """
entry = self.cache.get(key, self.sentinel) entry = self.cache.get(key, _Sentinel.sentinel)
if entry is not self.sentinel: if entry is not _Sentinel.sentinel:
self.metrics.inc_hits()
if dict_keys is None: if dict_keys is None:
return DictionaryEntry( return DictionaryEntry(
entry.full, entry.known_absent, dict(entry.value) entry.full, entry.known_absent, dict(entry.value)
@ -95,7 +93,6 @@ class DictionaryCache:
{k: entry.value[k] for k in dict_keys if k in entry.value}, {k: entry.value[k] for k in dict_keys if k in entry.value},
) )
self.metrics.inc_misses()
return DictionaryEntry(False, set(), {}) return DictionaryEntry(False, set(), {})
def invalidate(self, key): def invalidate(self, key):

View file

@ -15,11 +15,35 @@
import threading import threading
from functools import wraps from functools import wraps
from typing import Callable, Optional, Type, Union from typing import (
Any,
Callable,
Generic,
Iterable,
Optional,
Type,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import Literal
from synapse.config import cache as cache_config from synapse.config import cache as cache_config
from synapse.util.caches import CacheMetric, register_cache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
# Function type: the type used for invalidation callbacks
FT = TypeVar("FT", bound=Callable[..., Any])
# Key and Value type for the cache
KT = TypeVar("KT")
VT = TypeVar("VT")
# a general type var, distinct from either KT or VT
T = TypeVar("T")
def enumerate_leaves(node, depth): def enumerate_leaves(node, depth):
if depth == 0: if depth == 0:
@ -41,29 +65,31 @@ class _Node:
self.callbacks = callbacks self.callbacks = callbacks
class LruCache: class LruCache(Generic[KT, VT]):
""" """
Least-recently-used cache. Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
Supports del_multi only if cache_type=TreeCache Supports del_multi only if cache_type=TreeCache
If cache_type=TreeCache, all keys must be tuples. If cache_type=TreeCache, all keys must be tuples.
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
""" """
def __init__( def __init__(
self, self,
max_size: int, max_size: int,
cache_name: Optional[str] = None,
keylen: int = 1, keylen: int = 1,
cache_type: Type[Union[dict, TreeCache]] = dict, cache_type: Type[Union[dict, TreeCache]] = dict,
size_callback: Optional[Callable] = None, size_callback: Optional[Callable] = None,
evicted_callback: Optional[Callable] = None, metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True, apply_cache_factor_from_config: bool = True,
): ):
""" """
Args: Args:
max_size: The maximum amount of entries the cache can hold max_size: The maximum amount of entries the cache can hold
cache_name: The name of this cache, for the prometheus metrics. If unset,
no metrics will be reported on this cache.
keylen: The length of the tuple used as the cache key. Ignored unless keylen: The length of the tuple used as the cache key. Ignored unless
cache_type is `TreeCache`. cache_type is `TreeCache`.
@ -73,9 +99,13 @@ class LruCache:
size_callback (func(V) -> int | None): size_callback (func(V) -> int | None):
evicted_callback (func(int)|None): metrics_collection_callback:
if not None, called on eviction with the size of the evicted metrics collection callback. This is called early in the metrics
entry collection process, before any of the metrics registered with the
prometheus Registry are collected, so can be used to update any dynamic
metrics.
Ignored if cache_name is None.
apply_cache_factor_from_config (bool): If true, `max_size` will be apply_cache_factor_from_config (bool): If true, `max_size` will be
multiplied by a cache factor derived from the homeserver config multiplied by a cache factor derived from the homeserver config
@ -94,6 +124,23 @@ class LruCache:
else: else:
self.max_size = int(max_size) self.max_size = int(max_size)
# register_cache might call our "set_cache_factor" callback; there's nothing to
# do yet when we get resized.
self._on_resize = None # type: Optional[Callable[[],None]]
if cache_name is not None:
metrics = register_cache(
"lru_cache",
cache_name,
self,
collect_callback=metrics_collection_callback,
) # type: Optional[CacheMetric]
else:
metrics = None
# this is exposed for access from outside this class
self.metrics = metrics
list_root = _Node(None, None, None, None) list_root = _Node(None, None, None, None)
list_root.next_node = list_root list_root.next_node = list_root
list_root.prev_node = list_root list_root.prev_node = list_root
@ -105,16 +152,16 @@ class LruCache:
todelete = list_root.prev_node todelete = list_root.prev_node
evicted_len = delete_node(todelete) evicted_len = delete_node(todelete)
cache.pop(todelete.key, None) cache.pop(todelete.key, None)
if evicted_callback: if metrics:
evicted_callback(evicted_len) metrics.inc_evictions(evicted_len)
def synchronized(f): def synchronized(f: FT) -> FT:
@wraps(f) @wraps(f)
def inner(*args, **kwargs): def inner(*args, **kwargs):
with lock: with lock:
return f(*args, **kwargs) return f(*args, **kwargs)
return inner return cast(FT, inner)
cached_cache_len = [0] cached_cache_len = [0]
if size_callback is not None: if size_callback is not None:
@ -168,18 +215,45 @@ class LruCache:
node.callbacks.clear() node.callbacks.clear()
return deleted_len return deleted_len
@overload
def cache_get(
key: KT,
default: Literal[None] = None,
callbacks: Iterable[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Optional[VT]:
...
@overload
def cache_get(
key: KT,
default: T,
callbacks: Iterable[Callable[[], None]] = ...,
update_metrics: bool = ...,
) -> Union[T, VT]:
...
@synchronized @synchronized
def cache_get(key, default=None, callbacks=[]): def cache_get(
key: KT,
default: Optional[T] = None,
callbacks: Iterable[Callable[[], None]] = [],
update_metrics: bool = True,
):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
move_node_to_front(node) move_node_to_front(node)
node.callbacks.update(callbacks) node.callbacks.update(callbacks)
if update_metrics and metrics:
metrics.inc_hits()
return node.value return node.value
else: else:
if update_metrics and metrics:
metrics.inc_misses()
return default return default
@synchronized @synchronized
def cache_set(key, value, callbacks=[]): def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
# We sometimes store large objects, e.g. dicts, which cause # We sometimes store large objects, e.g. dicts, which cause
@ -208,7 +282,7 @@ class LruCache:
evict() evict()
@synchronized @synchronized
def cache_set_default(key, value): def cache_set_default(key: KT, value: VT) -> VT:
node = cache.get(key, None) node = cache.get(key, None)
if node is not None: if node is not None:
return node.value return node.value
@ -217,8 +291,16 @@ class LruCache:
evict() evict()
return value return value
@overload
def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
...
@overload
def cache_pop(key: KT, default: T) -> Union[T, VT]:
...
@synchronized @synchronized
def cache_pop(key, default=None): def cache_pop(key: KT, default: Optional[T] = None):
node = cache.get(key, None) node = cache.get(key, None)
if node: if node:
delete_node(node) delete_node(node)
@ -228,18 +310,18 @@ class LruCache:
return default return default
@synchronized @synchronized
def cache_del_multi(key): def cache_del_multi(key: KT) -> None:
""" """
This will only work if constructed with cache_type=TreeCache This will only work if constructed with cache_type=TreeCache
""" """
popped = cache.pop(key) popped = cache.pop(key)
if popped is None: if popped is None:
return return
for leaf in enumerate_leaves(popped, keylen - len(key)): for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
delete_node(leaf) delete_node(leaf)
@synchronized @synchronized
def cache_clear(): def cache_clear() -> None:
list_root.next_node = list_root list_root.next_node = list_root
list_root.prev_node = list_root list_root.prev_node = list_root
for node in cache.values(): for node in cache.values():
@ -250,15 +332,21 @@ class LruCache:
cached_cache_len[0] = 0 cached_cache_len[0] = 0
@synchronized @synchronized
def cache_contains(key): def cache_contains(key: KT) -> bool:
return key in cache return key in cache
self.sentinel = object() self.sentinel = object()
# make sure that we clear out any excess entries after we get resized.
self._on_resize = evict self._on_resize = evict
self.get = cache_get self.get = cache_get
self.set = cache_set self.set = cache_set
self.setdefault = cache_set_default self.setdefault = cache_set_default
self.pop = cache_pop self.pop = cache_pop
# `invalidate` is exposed for consistency with DeferredCache, so that it can be
# invalidated by the cache invalidation replication stream.
self.invalidate = cache_pop
if cache_type is TreeCache: if cache_type is TreeCache:
self.del_multi = cache_del_multi self.del_multi = cache_del_multi
self.len = synchronized(cache_len) self.len = synchronized(cache_len)
@ -302,6 +390,7 @@ class LruCache:
new_size = int(self._original_max_size * factor) new_size = int(self._original_max_size * factor)
if new_size != self.max_size: if new_size != self.max_size:
self.max_size = new_size self.max_size = new_size
self._on_resize() if self._on_resize:
self._on_resize()
return True return True
return False return False

View file

@ -15,7 +15,10 @@
import sys import sys
from twisted.internet import epollreactor try:
from twisted.internet.epollreactor import EPollReactor as Reactor
except ImportError:
from twisted.internet.pollreactor import PollReactor as Reactor
from twisted.internet.main import installReactor from twisted.internet.main import installReactor
from synapse.config.homeserver import HomeServerConfig from synapse.config.homeserver import HomeServerConfig
@ -41,7 +44,7 @@ async def make_homeserver(reactor, config=None):
config_obj = HomeServerConfig() config_obj = HomeServerConfig()
config_obj.parse_config_dict(config, "", "") config_obj.parse_config_dict(config, "", "")
hs = await setup_test_homeserver( hs = setup_test_homeserver(
cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
) )
stor = hs.get_datastore() stor = hs.get_datastore()
@ -63,7 +66,7 @@ def make_reactor():
Instantiate and install a Twisted reactor suitable for testing (i.e. not the Instantiate and install a Twisted reactor suitable for testing (i.e. not the
default global one). default global one).
""" """
reactor = epollreactor.EPollReactor() reactor = Reactor()
if "twisted.internet.reactor" in sys.modules: if "twisted.internet.reactor" in sys.modules:
del sys.modules["twisted.internet.reactor"] del sys.modules["twisted.internet.reactor"]

View file

@ -260,6 +260,31 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], []) self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
self.assertEquals(3, self.txn_ctrl.send.call_count) self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_large_txns(self):
srv_1_defer = defer.Deferred()
srv_2_defer = defer.Deferred()
send_return_list = [srv_1_defer, srv_2_defer]
def do_send(x, y, z):
return make_deferred_yieldable(send_return_list.pop(0))
self.txn_ctrl.send = Mock(side_effect=do_send)
service = Mock(id=4, name="service")
event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
for event in event_list:
self.queuer.enqueue_event(service, event)
# Expect the first event to be sent immediately.
self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
srv_1_defer.callback(service)
# Then send the next 100 events
self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
srv_2_defer.callback(service)
# Then the final 99 events
self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
self.assertEquals(3, self.txn_ctrl.send.call_count)
def test_send_single_ephemeral_no_queue(self): def test_send_single_ephemeral_no_queue(self):
# Expect the event to be sent immediately. # Expect the event to be sent immediately.
service = Mock(id=4, name="service") service = Mock(id=4, name="service")
@ -296,3 +321,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
# Expect the queued events to be sent # Expect the queued events to be sent
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3) self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
self.assertEquals(2, self.txn_ctrl.send.call_count) self.assertEquals(2, self.txn_ctrl.send.call_count)
def test_send_large_txns_ephemeral(self):
d = defer.Deferred()
self.txn_ctrl.send = Mock(
side_effect=lambda x, y, z: make_deferred_yieldable(d)
)
# Expect the event to be sent immediately.
service = Mock(id=4, name="service")
first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
event_list = first_chunk + second_chunk
self.queuer.enqueue_ephemeral(service, event_list)
self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
d.callback(service)
self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
self.assertEquals(2, self.txn_ctrl.send.call_count)

View file

@ -78,7 +78,7 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
"server_name", "server_name",
"name", "name",
] ]
self.assertEqual(set(log.keys()), set(expected_log_keys)) self.assertCountEqual(log.keys(), expected_log_keys)
# It contains the data we expect. # It contains the data we expect.
self.assertEqual(log["name"], "wally") self.assertEqual(log["name"], "wally")

View file

@ -158,8 +158,21 @@ class EmailPusherTests(HomeserverTestCase):
# We should get emailed about those messages # We should get emailed about those messages
self._check_for_mail() self._check_for_mail()
def test_encrypted_message(self):
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
self.helper.invite(
room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
)
self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
# The other user sends some messages
self.helper.send_event(room, "m.room.encrypted", {}, tok=self.others[0].token)
# We should get emailed about that message
self._check_for_mail()
def _check_for_mail(self): def _check_for_mail(self):
"Check that the user receives an email notification" """Check that the user receives an email notification"""
# Get the stream ordering before it gets sent # Get the stream ordering before it gets sent
pushers = self.get_success( pushers = self.get_success(

View file

@ -352,7 +352,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.render(request) self.render(request)
self.assertEqual(request.code, 401) self.assertEqual(request.code, 401)
@unittest.INFO
def test_pending_invites(self): def test_pending_invites(self):
"""Tests that deactivating a user rejects every pending invite for them.""" """Tests that deactivating a user rejects every pending invite for them."""
store = self.hs.get_datastore() store = self.hs.get_datastore()

View file

@ -104,7 +104,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
self.assertEqual(len(attempts), 1) self.assertEqual(len(attempts), 1)
self.assertEqual(attempts[0][0]["response"], "a") self.assertEqual(attempts[0][0]["response"], "a")
@unittest.INFO
def test_fallback_captcha(self): def test_fallback_captcha(self):
"""Ensure that fallback auth via a captcha works.""" """Ensure that fallback auth via a captcha works."""
# Returns a 401 as per the spec # Returns a 401 as per the spec

View file

@ -15,237 +15,9 @@
# limitations under the License. # limitations under the License.
from mock import Mock
from twisted.internet import defer
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached
from tests import unittest from tests import unittest
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def test_passthrough(self):
class A:
@cached()
def func(self, key):
return key
a = A()
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks
def test_hit(self):
callcount = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks
def test_invalidate(self):
callcount = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
class A:
@cached()
def func(self, key):
return key
A().func.invalidate(("what",))
@defer.inlineCallbacks
def test_max_entries(self):
callcount = [0]
class A:
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
return key
a = A()
for k in range(0, 12):
yield a.func(k)
self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
for k in range(0, 12):
yield a.func(k)
self.assertTrue(
callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
)
def test_prefill(self):
callcount = [0]
d = defer.succeed(123)
class A:
@cached()
def func(self, key):
callcount[0] += 1
return d
a = A()
a.func.prefill(("foo",), ObservableDeferred(d))
self.assertEquals(a.func("foo").result, d.result)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
callcount = [0]
callcount2 = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 1)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
callcount = [0]
callcount2 = [0]
class A:
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
yield a.func2("foo2")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3)
@defer.inlineCallbacks
def test_double_get(self):
callcount = [0]
callcount2 = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
yield a.func2("foo")
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)
class UpsertManyTests(unittest.HomeserverTestCase): class UpsertManyTests(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs): def prepare(self, reactor, clock, hs):
self.storage = hs.get_datastore() self.storage = hs.get_datastore()

View file

@ -13,15 +13,16 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import unittest
from functools import partial from functools import partial
from twisted.internet import defer from twisted.internet import defer
from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.deferred_cache import DeferredCache
from tests.unittest import TestCase
class DeferredCacheTestCase(unittest.TestCase):
class DeferredCacheTestCase(TestCase):
def test_empty(self): def test_empty(self):
cache = DeferredCache("test") cache = DeferredCache("test")
failed = False failed = False
@ -36,7 +37,118 @@ class DeferredCacheTestCase(unittest.TestCase):
cache = DeferredCache("test") cache = DeferredCache("test")
cache.prefill("foo", 123) cache.prefill("foo", 123)
self.assertEquals(cache.get("foo"), 123) self.assertEquals(self.successResultOf(cache.get("foo")), 123)
def test_hit_deferred(self):
cache = DeferredCache("test")
origin_d = defer.Deferred()
set_d = cache.set("k1", origin_d)
# get should return an incomplete deferred
get_d = cache.get("k1")
self.assertFalse(get_d.called)
# add a callback that will make sure that the set_d gets called before the get_d
def check1(r):
self.assertTrue(set_d.called)
return r
# TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
# maybe we should fix that?
# get_d.addCallback(check1)
# now fire off all the deferreds
origin_d.callback(99)
self.assertEqual(self.successResultOf(origin_d), 99)
self.assertEqual(self.successResultOf(set_d), 99)
self.assertEqual(self.successResultOf(get_d), 99)
def test_callbacks(self):
"""Invalidation callbacks are called at the right time"""
cache = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
origin_d = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
# we don't expect the invalidation callback for the original value to have
# been called yet, even though get() will now return a different result.
# I'm not sure if that is by design or not.
self.assertEqual(callbacks, set())
# now fire off all the deferreds
origin_d.callback(20)
self.assertEqual(self.successResultOf(set_d), 20)
self.assertEqual(self.successResultOf(get_d), 20)
# now the original invalidation callback should have been called, but none of
# the others
self.assertEqual(callbacks, {"prefill"})
callbacks.clear()
# another update should invalidate both the previous results
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"set", "get"})
def test_set_fail(self):
cache = DeferredCache("test")
callbacks = set()
# start with an entry, with a callback
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
# now replace that entry with a pending result
origin_d = defer.Deferred()
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
# ... and also make a get request
get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
# none of the callbacks should have been called yet
self.assertEqual(callbacks, set())
# oh noes! fails!
e = Exception("oops")
origin_d.errback(e)
self.assertIs(self.failureResultOf(set_d, Exception).value, e)
self.assertIs(self.failureResultOf(get_d, Exception).value, e)
# the callbacks for the failed requests should have been called.
# I'm not sure if this is deliberate or not.
self.assertEqual(callbacks, {"get", "set"})
callbacks.clear()
# the old value should still be returned now?
get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
self.assertEqual(self.successResultOf(get_d2), 10)
# replacing the value now should run the callbacks for those requests
# which got the original result
cache.prefill("k1", 30)
self.assertEqual(callbacks, {"prefill", "get2"})
def test_get_immediate(self):
cache = DeferredCache("test")
d1 = defer.Deferred()
cache.set("key1", d1)
# get_immediate should return default
v = cache.get_immediate("key1", 1)
self.assertEqual(v, 1)
# now complete the set
d1.callback(2)
# get_immediate should return result
v = cache.get_immediate("key1", 1)
self.assertEqual(v, 2)
def test_invalidate(self): def test_invalidate(self):
cache = DeferredCache("test") cache = DeferredCache("test")
@ -66,23 +178,24 @@ class DeferredCacheTestCase(unittest.TestCase):
d2 = defer.Deferred() d2 = defer.Deferred()
cache.set("key2", d2, partial(record_callback, 1)) cache.set("key2", d2, partial(record_callback, 1))
# lookup should return observable deferreds # lookup should return pending deferreds
self.assertFalse(cache.get("key1").has_called()) self.assertFalse(cache.get("key1").called)
self.assertFalse(cache.get("key2").has_called()) self.assertFalse(cache.get("key2").called)
# let one of the lookups complete # let one of the lookups complete
d2.callback("result2") d2.callback("result2")
# for now at least, the cache will return real results rather than an # now the cache will return a completed deferred
# observabledeferred self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
self.assertEqual(cache.get("key2"), "result2")
# now do the invalidation # now do the invalidation
cache.invalidate_all() cache.invalidate_all()
# lookup should return none # lookup should fail
self.assertIsNone(cache.get("key1", None)) with self.assertRaises(KeyError):
self.assertIsNone(cache.get("key2", None)) cache.get("key1")
with self.assertRaises(KeyError):
cache.get("key2")
# both callbacks should have been callbacked # both callbacks should have been callbacked
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called") self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
@ -90,7 +203,8 @@ class DeferredCacheTestCase(unittest.TestCase):
# letting the other lookup complete should do nothing # letting the other lookup complete should do nothing
d1.callback("result1") d1.callback("result1")
self.assertIsNone(cache.get("key1", None)) with self.assertRaises(KeyError):
cache.get("key1", None)
def test_eviction(self): def test_eviction(self):
cache = DeferredCache( cache = DeferredCache(

View file

@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Set
import mock import mock
@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase):
d = obj.fn(1) d = obj.fn(1)
self.failureResultOf(d, SynapseError) self.failureResultOf(d, SynapseError)
def test_cache_with_async_exception(self):
"""The wrapped function returns a failure
"""
class Cls:
result = None
call_count = 0
@cached()
def fn(self, arg1):
self.call_count += 1
return self.result
obj = Cls()
callbacks = set() # type: Set[str]
# set off an asynchronous request
obj.result = origin_d = defer.Deferred()
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
self.assertFalse(d1.called)
# a second request should also return a deferred, but should not call the
# function itself.
d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
self.assertFalse(d2.called)
self.assertEqual(obj.call_count, 1)
# no callbacks yet
self.assertEqual(callbacks, set())
# the original request fails
e = Exception("bzz")
origin_d.errback(e)
# ... which should cause the lookups to fail similarly
self.assertIs(self.failureResultOf(d1, Exception).value, e)
self.assertIs(self.failureResultOf(d2, Exception).value, e)
# ... and the callbacks to have been, uh, called.
self.assertEqual(callbacks, {"d1", "d2"})
# ... leaving the cache empty
self.assertEqual(len(obj.fn.cache.cache), 0)
# and a second call should work as normal
obj.result = defer.succeed(100)
d3 = obj.fn(1)
self.assertEqual(self.successResultOf(d3), 100)
self.assertEqual(obj.call_count, 2)
def test_cache_logcontexts(self): def test_cache_logcontexts(self):
"""Check that logcontexts are set and restored correctly when """Check that logcontexts are set and restored correctly when
using the cache.""" using the cache."""
@ -311,6 +363,235 @@ class DescriptorTestCase(unittest.TestCase):
self.failureResultOf(d, SynapseError) self.failureResultOf(d, SynapseError)
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
"""More tests for @cached
The following is a set of tests that got lost in a different file for a while.
There are probably duplicates of the tests in DescriptorTestCase. Ideally the
duplicates would be removed and the two sets of classes combined.
"""
@defer.inlineCallbacks
def test_passthrough(self):
class A:
@cached()
def func(self, key):
return key
a = A()
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals((yield a.func("bar")), "bar")
@defer.inlineCallbacks
def test_hit(self):
callcount = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals((yield a.func("foo")), "foo")
self.assertEquals(callcount[0], 1)
@defer.inlineCallbacks
def test_invalidate(self):
callcount = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
a = A()
yield a.func("foo")
self.assertEquals(callcount[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
def test_invalidate_missing(self):
class A:
@cached()
def func(self, key):
return key
A().func.invalidate(("what",))
@defer.inlineCallbacks
def test_max_entries(self):
callcount = [0]
class A:
@cached(max_entries=10)
def func(self, key):
callcount[0] += 1
return key
a = A()
for k in range(0, 12):
yield a.func(k)
self.assertEquals(callcount[0], 12)
# There must have been at least 2 evictions, meaning if we calculate
# all 12 values again, we must get called at least 2 more times
for k in range(0, 12):
yield a.func(k)
self.assertTrue(
callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
)
def test_prefill(self):
callcount = [0]
d = defer.succeed(123)
class A:
@cached()
def func(self, key):
callcount[0] += 1
return d
a = A()
a.func.prefill(("foo",), 456)
self.assertEquals(a.func("foo").result, 456)
self.assertEquals(callcount[0], 0)
@defer.inlineCallbacks
def test_invalidate_context(self):
callcount = [0]
callcount2 = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func.invalidate(("foo",))
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 1)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
@defer.inlineCallbacks
def test_eviction_context(self):
callcount = [0]
callcount2 = [0]
class A:
@cached(max_entries=2)
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
yield a.func2("foo")
yield a.func2("foo2")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 4)
self.assertEquals(callcount2[0], 3)
@defer.inlineCallbacks
def test_double_get(self):
callcount = [0]
callcount2 = [0]
class A:
@cached()
def func(self, key):
callcount[0] += 1
return key
@cached(cache_context=True)
def func2(self, key, cache_context):
callcount2[0] += 1
return self.func(key, on_invalidate=cache_context.invalidate)
a = A()
a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
yield a.func2("foo")
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 1)
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
yield a.func2("foo")
a.func2.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
self.assertEquals(callcount[0], 1)
self.assertEquals(callcount2[0], 2)
a.func.invalidate(("foo",))
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
yield a.func("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 3)
class CachedListDescriptorTestCase(unittest.TestCase): class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks @defer.inlineCallbacks
def test_cache(self): def test_cache(self):

View file

@ -19,7 +19,8 @@ from mock import Mock
from synapse.util.caches.lrucache import LruCache from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache from synapse.util.caches.treecache import TreeCache
from .. import unittest from tests import unittest
from tests.unittest import override_config
class LruCacheTestCase(unittest.HomeserverTestCase): class LruCacheTestCase(unittest.HomeserverTestCase):
@ -59,7 +60,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
self.assertEquals(cache.pop("key"), None) self.assertEquals(cache.pop("key"), None)
def test_del_multi(self): def test_del_multi(self):
cache = LruCache(4, 2, cache_type=TreeCache) cache = LruCache(4, keylen=2, cache_type=TreeCache)
cache[("animal", "cat")] = "mew" cache[("animal", "cat")] = "mew"
cache[("animal", "dog")] = "woof" cache[("animal", "dog")] = "woof"
cache[("vehicles", "car")] = "vroom" cache[("vehicles", "car")] = "vroom"
@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
cache.clear() cache.clear()
self.assertEquals(len(cache), 0) self.assertEquals(len(cache), 0)
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
def test_special_size(self):
cache = LruCache(10, "mycache")
self.assertEqual(cache.max_size, 100)
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
def test_get(self): def test_get(self):
@ -160,7 +166,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
m2 = Mock() m2 = Mock()
m3 = Mock() m3 = Mock()
m4 = Mock() m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache) cache = LruCache(4, keylen=2, cache_type=TreeCache)
cache.set(("a", "1"), "value", callbacks=[m1]) cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2]) cache.set(("a", "2"), "value", callbacks=[m2])

View file

@ -158,12 +158,9 @@ commands=
coverage html coverage html
[testenv:mypy] [testenv:mypy]
skip_install = True
deps = deps =
{[base]deps} {[base]deps}
mypy==0.782 extras = all,mypy
mypy-zope
extras = all
commands = mypy commands = mypy
# To find all folders that pass mypy you run: # To find all folders that pass mypy you run: