mirror of
https://github.com/element-hq/synapse
synced 2024-09-12 03:35:10 +00:00
Merge remote-tracking branch 'origin/release-v1.21.3' into matrix-org-hotfixes
This commit is contained in:
commit
ab4cd7f802
84 changed files with 1264 additions and 781 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -21,6 +21,7 @@ _trial_temp*/
|
||||||
/.python-version
|
/.python-version
|
||||||
/*.signing.key
|
/*.signing.key
|
||||||
/env/
|
/env/
|
||||||
|
/.venv*/
|
||||||
/homeserver*.yaml
|
/homeserver*.yaml
|
||||||
/logs
|
/logs
|
||||||
/media_store/
|
/media_store/
|
||||||
|
|
1
changelog.d/8504.bugfix
Normal file
1
changelog.d/8504.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Expose the `uk.half-shot.msc2778.login.application_service` to clients from the login API. This feature was added in v1.21.0, but was not exposed as a potential login flow.
|
1
changelog.d/8544.feature
Normal file
1
changelog.d/8544.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Allow running background tasks in a separate worker process.
|
1
changelog.d/8545.bugfix
Normal file
1
changelog.d/8545.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix a long standing bug where email notifications for encrypted messages were blank.
|
1
changelog.d/8561.misc
Normal file
1
changelog.d/8561.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Move metric registration code down into `LruCache`.
|
1
changelog.d/8562.misc
Normal file
1
changelog.d/8562.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type annotations for `LruCache`.
|
1
changelog.d/8563.misc
Normal file
1
changelog.d/8563.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Replace `DeferredCache` with the lighter-weight `LruCache` where possible.
|
1
changelog.d/8564.feature
Normal file
1
changelog.d/8564.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Support modifying event content in `ThirdPartyRules` modules.
|
1
changelog.d/8566.misc
Normal file
1
changelog.d/8566.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add virtualenv-generated folders to `.gitignore`.
|
1
changelog.d/8567.bugfix
Normal file
1
changelog.d/8567.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix increase in the number of `There was no active span...` errors logged when using OpenTracing.
|
1
changelog.d/8568.misc
Normal file
1
changelog.d/8568.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add `get_immediate` method to `DeferredCache`.
|
1
changelog.d/8569.misc
Normal file
1
changelog.d/8569.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix mypy not properly checking across the codebase, additionally, fix a typing assertion error in `handlers/auth.py`.
|
1
changelog.d/8571.misc
Normal file
1
changelog.d/8571.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix `synmark` benchmark runner.
|
1
changelog.d/8572.misc
Normal file
1
changelog.d/8572.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Modify `DeferredCache.get()` to return `Deferred`s instead of `ObservableDeferred`s.
|
1
changelog.d/8577.misc
Normal file
1
changelog.d/8577.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Adjust a protocol-type definition to fit `sqlite3` assertions.
|
1
changelog.d/8578.misc
Normal file
1
changelog.d/8578.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Support macOS on the `synmark` benchmark runner.
|
1
changelog.d/8583.misc
Normal file
1
changelog.d/8583.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update `mypy` static type checker to 0.790.
|
1
changelog.d/8585.bugfix
Normal file
1
changelog.d/8585.bugfix
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Fix a bug that prevented errors encountered during execution of the `synapse_port_db` from being correctly printed.
|
1
changelog.d/8587.misc
Normal file
1
changelog.d/8587.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Re-organize the structured logging code to separate the TCP transport handling from the JSON formatting.
|
1
changelog.d/8589.removal
Normal file
1
changelog.d/8589.removal
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Drop unused `device_max_stream_id` table.
|
1
changelog.d/8590.misc
Normal file
1
changelog.d/8590.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Implement [MSC2409](https://github.com/matrix-org/matrix-doc/pull/2409) to send typing, read receipts, and presence events to appservices.
|
1
changelog.d/8591.misc
Normal file
1
changelog.d/8591.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Move metric registration code down into `LruCache`.
|
1
changelog.d/8592.misc
Normal file
1
changelog.d/8592.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Remove extraneous unittest logging decorators from unit tests.
|
1
changelog.d/8593.misc
Normal file
1
changelog.d/8593.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Minor optimisations in caching code.
|
1
changelog.d/8594.misc
Normal file
1
changelog.d/8594.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Minor optimisations in caching code.
|
1
changelog.d/8599.feature
Normal file
1
changelog.d/8599.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Allow running background tasks in a separate worker process.
|
1
changelog.d/8600.misc
Normal file
1
changelog.d/8600.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Update `mypy` static type checker to 0.790.
|
1
changelog.d/8606.feature
Normal file
1
changelog.d/8606.feature
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Limit appservice transactions to 100 persistent and 100 ephemeral events.
|
1
changelog.d/8609.misc
Normal file
1
changelog.d/8609.misc
Normal file
|
@ -0,0 +1 @@
|
||||||
|
Add type hints to profile and base handler.
|
4
mypy.ini
4
mypy.ini
|
@ -15,8 +15,9 @@ files =
|
||||||
synapse/events/builder.py,
|
synapse/events/builder.py,
|
||||||
synapse/events/spamcheck.py,
|
synapse/events/spamcheck.py,
|
||||||
synapse/federation,
|
synapse/federation,
|
||||||
synapse/handlers/appservice.py,
|
synapse/handlers/_base.py,
|
||||||
synapse/handlers/account_data.py,
|
synapse/handlers/account_data.py,
|
||||||
|
synapse/handlers/appservice.py,
|
||||||
synapse/handlers/auth.py,
|
synapse/handlers/auth.py,
|
||||||
synapse/handlers/cas_handler.py,
|
synapse/handlers/cas_handler.py,
|
||||||
synapse/handlers/deactivate_account.py,
|
synapse/handlers/deactivate_account.py,
|
||||||
|
@ -32,6 +33,7 @@ files =
|
||||||
synapse/handlers/pagination.py,
|
synapse/handlers/pagination.py,
|
||||||
synapse/handlers/password_policy.py,
|
synapse/handlers/password_policy.py,
|
||||||
synapse/handlers/presence.py,
|
synapse/handlers/presence.py,
|
||||||
|
synapse/handlers/profile.py,
|
||||||
synapse/handlers/read_marker.py,
|
synapse/handlers/read_marker.py,
|
||||||
synapse/handlers/room.py,
|
synapse/handlers/room.py,
|
||||||
synapse/handlers/room_member.py,
|
synapse/handlers/room_member.py,
|
||||||
|
|
|
@ -22,6 +22,7 @@ import logging
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
@ -152,7 +153,7 @@ IGNORED_TABLES = {
|
||||||
|
|
||||||
# Error returned by the run function. Used at the top-level part of the script to
|
# Error returned by the run function. Used at the top-level part of the script to
|
||||||
# handle errors and return codes.
|
# handle errors and return codes.
|
||||||
end_error = None
|
end_error = None # type: Optional[str]
|
||||||
# The exec_info for the error, if any. If error is defined but not exec_info the script
|
# The exec_info for the error, if any. If error is defined but not exec_info the script
|
||||||
# will show only the error message without the stacktrace, if exec_info is defined but
|
# will show only the error message without the stacktrace, if exec_info is defined but
|
||||||
# not the error then the script will show nothing outside of what's printed in the run
|
# not the error then the script will show nothing outside of what's printed in the run
|
||||||
|
@ -635,7 +636,7 @@ class Porter(object):
|
||||||
self.progress.done()
|
self.progress.done()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
global end_error_exec_info
|
global end_error_exec_info
|
||||||
end_error = e
|
end_error = str(e)
|
||||||
end_error_exec_info = sys.exc_info()
|
end_error_exec_info = sys.exc_info()
|
||||||
logger.exception("")
|
logger.exception("")
|
||||||
finally:
|
finally:
|
||||||
|
|
2
setup.py
2
setup.py
|
@ -102,6 +102,8 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
|
||||||
"flake8",
|
"flake8",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
|
||||||
|
|
||||||
# Dependencies which are exclusively required by unit test code. This is
|
# Dependencies which are exclusively required by unit test code. This is
|
||||||
# NOT a list of all modules that are necessary to run the unit tests.
|
# NOT a list of all modules that are necessary to run the unit tests.
|
||||||
# Tests assume that all optional dependencies are installed.
|
# Tests assume that all optional dependencies are installed.
|
||||||
|
|
|
@ -34,7 +34,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.logging import opentracing as opentracing
|
from synapse.logging import opentracing as opentracing
|
||||||
from synapse.types import StateMap, UserID
|
from synapse.types import StateMap, UserID
|
||||||
from synapse.util.caches import register_cache
|
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
|
@ -70,8 +69,9 @@ class Auth:
|
||||||
self.store = hs.get_datastore()
|
self.store = hs.get_datastore()
|
||||||
self.state = hs.get_state_handler()
|
self.state = hs.get_state_handler()
|
||||||
|
|
||||||
self.token_cache = LruCache(10000)
|
self.token_cache = LruCache(
|
||||||
register_cache("cache", "token_cache", self.token_cache)
|
10000, "token_cache"
|
||||||
|
) # type: LruCache[str, Tuple[str, bool]]
|
||||||
|
|
||||||
self._auth_blocking = AuthBlocking(self.hs)
|
self._auth_blocking = AuthBlocking(self.hs)
|
||||||
|
|
||||||
|
|
|
@ -60,6 +60,13 @@ from synapse.types import JsonDict
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Maximum number of events to provide in an AS transaction.
|
||||||
|
MAX_PERSISTENT_EVENTS_PER_TRANSACTION = 100
|
||||||
|
|
||||||
|
# Maximum number of ephemeral events to provide in an AS transaction.
|
||||||
|
MAX_EPHEMERAL_EVENTS_PER_TRANSACTION = 100
|
||||||
|
|
||||||
|
|
||||||
class ApplicationServiceScheduler:
|
class ApplicationServiceScheduler:
|
||||||
""" Public facing API for this module. Does the required DI to tie the
|
""" Public facing API for this module. Does the required DI to tie the
|
||||||
components together. This also serves as the "event_pool", which in this
|
components together. This also serves as the "event_pool", which in this
|
||||||
|
@ -136,10 +143,17 @@ class _ServiceQueuer:
|
||||||
self.requests_in_flight.add(service.id)
|
self.requests_in_flight.add(service.id)
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
events = self.queued_events.pop(service.id, [])
|
all_events = self.queued_events.get(service.id, [])
|
||||||
ephemeral = self.queued_ephemeral.pop(service.id, [])
|
events = all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
|
||||||
|
del all_events[:MAX_PERSISTENT_EVENTS_PER_TRANSACTION]
|
||||||
|
|
||||||
|
all_events_ephemeral = self.queued_ephemeral.get(service.id, [])
|
||||||
|
ephemeral = all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
|
||||||
|
del all_events_ephemeral[:MAX_EPHEMERAL_EVENTS_PER_TRANSACTION]
|
||||||
|
|
||||||
if not events and not ephemeral:
|
if not events and not ephemeral:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.txn_ctrl.send(service, events, ephemeral)
|
await self.txn_ctrl.send(service, events, ephemeral)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
|
@ -98,7 +98,7 @@ class EventBuilder:
|
||||||
return self._state_key is not None
|
return self._state_key is not None
|
||||||
|
|
||||||
async def build(
|
async def build(
|
||||||
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]]
|
self, prev_event_ids: List[str], auth_event_ids: Optional[List[str]],
|
||||||
) -> EventBase:
|
) -> EventBase:
|
||||||
"""Transform into a fully signed and hashed event
|
"""Transform into a fully signed and hashed event
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
import synapse.state
|
import synapse.state
|
||||||
import synapse.storage
|
import synapse.storage
|
||||||
|
@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.api.ratelimiting import Ratelimiter
|
from synapse.api.ratelimiting import Ratelimiter
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -30,11 +34,7 @@ class BaseHandler:
|
||||||
Common base class for the event handlers.
|
Common base class for the event handlers.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
"""
|
|
||||||
Args:
|
|
||||||
hs (synapse.server.HomeServer):
|
|
||||||
"""
|
|
||||||
self.store = hs.get_datastore() # type: synapse.storage.DataStore
|
self.store = hs.get_datastore() # type: synapse.storage.DataStore
|
||||||
self.auth = hs.get_auth()
|
self.auth = hs.get_auth()
|
||||||
self.notifier = hs.get_notifier()
|
self.notifier = hs.get_notifier()
|
||||||
|
@ -56,7 +56,7 @@ class BaseHandler:
|
||||||
clock=self.clock,
|
clock=self.clock,
|
||||||
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
rate_hz=self.hs.config.rc_admin_redaction.per_second,
|
||||||
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
burst_count=self.hs.config.rc_admin_redaction.burst_count,
|
||||||
)
|
) # type: Optional[Ratelimiter]
|
||||||
else:
|
else:
|
||||||
self.admin_redaction_ratelimiter = None
|
self.admin_redaction_ratelimiter = None
|
||||||
|
|
||||||
|
@ -127,15 +127,15 @@ class BaseHandler:
|
||||||
if guest_access != "can_join":
|
if guest_access != "can_join":
|
||||||
if context:
|
if context:
|
||||||
current_state_ids = await context.get_current_state_ids()
|
current_state_ids = await context.get_current_state_ids()
|
||||||
current_state = await self.store.get_events(
|
current_state_dict = await self.store.get_events(
|
||||||
list(current_state_ids.values())
|
list(current_state_ids.values())
|
||||||
)
|
)
|
||||||
|
current_state = list(current_state_dict.values())
|
||||||
else:
|
else:
|
||||||
current_state = await self.state_handler.get_current_state(
|
current_state_map = await self.state_handler.get_current_state(
|
||||||
event.room_id
|
event.room_id
|
||||||
)
|
)
|
||||||
|
current_state = list(current_state_map.values())
|
||||||
current_state = list(current_state.values())
|
|
||||||
|
|
||||||
logger.info("maybe_kick_guest_users %r", current_state)
|
logger.info("maybe_kick_guest_users %r", current_state)
|
||||||
await self.kick_guest_users(current_state)
|
await self.kick_guest_users(current_state)
|
||||||
|
|
|
@ -22,7 +22,7 @@ from typing import List
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.logging.context import make_deferred_yieldable
|
from synapse.logging.context import make_deferred_yieldable
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util import stringutils
|
from synapse.util import stringutils
|
||||||
|
|
||||||
|
@ -63,16 +63,10 @@ class AccountValidityHandler:
|
||||||
self._raw_from = email.utils.parseaddr(self._from_string)[1]
|
self._raw_from = email.utils.parseaddr(self._from_string)[1]
|
||||||
|
|
||||||
# Check the renewal emails to send and send them every 30min.
|
# Check the renewal emails to send and send them every 30min.
|
||||||
def send_emails():
|
|
||||||
# run as a background process to make sure that the database transactions
|
|
||||||
# have a logcontext to report to
|
|
||||||
return run_as_background_process(
|
|
||||||
"send_renewals", self._send_renewal_emails
|
|
||||||
)
|
|
||||||
|
|
||||||
if hs.config.run_background_tasks:
|
if hs.config.run_background_tasks:
|
||||||
self.clock.looping_call(send_emails, 30 * 60 * 1000)
|
self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
|
||||||
|
|
||||||
|
@wrap_as_background_process("send_renewals")
|
||||||
async def _send_renewal_emails(self):
|
async def _send_renewal_emails(self):
|
||||||
"""Gets the list of users whose account is expiring in the amount of time
|
"""Gets the list of users whose account is expiring in the amount of time
|
||||||
configured in the ``renew_at`` parameter from the ``account_validity``
|
configured in the ``renew_at`` parameter from the ``account_validity``
|
||||||
|
|
|
@ -1122,20 +1122,22 @@ class AuthHandler(BaseHandler):
|
||||||
Whether self.hash(password) == stored_hash.
|
Whether self.hash(password) == stored_hash.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _do_validate_hash():
|
def _do_validate_hash(checked_hash: bytes):
|
||||||
# Normalise the Unicode in the password
|
# Normalise the Unicode in the password
|
||||||
pw = unicodedata.normalize("NFKC", password)
|
pw = unicodedata.normalize("NFKC", password)
|
||||||
|
|
||||||
return bcrypt.checkpw(
|
return bcrypt.checkpw(
|
||||||
pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
|
pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
|
||||||
stored_hash,
|
checked_hash,
|
||||||
)
|
)
|
||||||
|
|
||||||
if stored_hash:
|
if stored_hash:
|
||||||
if not isinstance(stored_hash, bytes):
|
if not isinstance(stored_hash, bytes):
|
||||||
stored_hash = stored_hash.encode("ascii")
|
stored_hash = stored_hash.encode("ascii")
|
||||||
|
|
||||||
return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
|
return await defer_to_thread(
|
||||||
|
self.hs.get_reactor(), _do_validate_hash, stored_hash
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -293,6 +293,10 @@ class InitialSyncHandler(BaseHandler):
|
||||||
user_id, room_id, pagin_config, membership, is_peeking
|
user_id, room_id, pagin_config, membership, is_peeking
|
||||||
)
|
)
|
||||||
elif membership == Membership.LEAVE:
|
elif membership == Membership.LEAVE:
|
||||||
|
# The member_event_id will always be available if membership is set
|
||||||
|
# to leave.
|
||||||
|
assert member_event_id
|
||||||
|
|
||||||
result = await self._room_initial_sync_parted(
|
result = await self._room_initial_sync_parted(
|
||||||
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
|
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
|
||||||
)
|
)
|
||||||
|
@ -315,7 +319,7 @@ class InitialSyncHandler(BaseHandler):
|
||||||
user_id: str,
|
user_id: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
pagin_config: PaginationConfig,
|
pagin_config: PaginationConfig,
|
||||||
membership: Membership,
|
membership: str,
|
||||||
member_event_id: str,
|
member_event_id: str,
|
||||||
is_peeking: bool,
|
is_peeking: bool,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
|
@ -367,7 +371,7 @@ class InitialSyncHandler(BaseHandler):
|
||||||
user_id: str,
|
user_id: str,
|
||||||
room_id: str,
|
room_id: str,
|
||||||
pagin_config: PaginationConfig,
|
pagin_config: PaginationConfig,
|
||||||
membership: Membership,
|
membership: str,
|
||||||
is_peeking: bool,
|
is_peeking: bool,
|
||||||
) -> JsonDict:
|
) -> JsonDict:
|
||||||
current_state = await self.state.get_current_state(room_id=room_id)
|
current_state = await self.state.get_current_state(room_id=room_id)
|
||||||
|
|
|
@ -1364,7 +1364,12 @@ class EventCreationHandler:
|
||||||
for k, v in original_event.internal_metadata.get_dict().items():
|
for k, v in original_event.internal_metadata.get_dict().items():
|
||||||
setattr(builder.internal_metadata, k, v)
|
setattr(builder.internal_metadata, k, v)
|
||||||
|
|
||||||
event = await builder.build(prev_event_ids=original_event.prev_event_ids())
|
# the event type hasn't changed, so there's no point in re-calculating the
|
||||||
|
# auth events.
|
||||||
|
event = await builder.build(
|
||||||
|
prev_event_ids=original_event.prev_event_ids(),
|
||||||
|
auth_event_ids=original_event.auth_event_ids(),
|
||||||
|
)
|
||||||
|
|
||||||
# we rebuild the event context, to be on the safe side. If nothing else,
|
# we rebuild the event context, to be on the safe side. If nothing else,
|
||||||
# delta_ids might need an update.
|
# delta_ids might need an update.
|
||||||
|
|
|
@ -12,9 +12,9 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import random
|
import random
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from synapse.api.errors import (
|
from synapse.api.errors import (
|
||||||
AuthError,
|
AuthError,
|
||||||
|
@ -24,11 +24,20 @@ from synapse.api.errors import (
|
||||||
StoreError,
|
StoreError,
|
||||||
SynapseError,
|
SynapseError,
|
||||||
)
|
)
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.types import UserID, create_requester, get_domain_from_id
|
from synapse.types import (
|
||||||
|
JsonDict,
|
||||||
|
Requester,
|
||||||
|
UserID,
|
||||||
|
create_requester,
|
||||||
|
get_domain_from_id,
|
||||||
|
)
|
||||||
|
|
||||||
from ._base import BaseHandler
|
from ._base import BaseHandler
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from synapse.app.homeserver import HomeServer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MAX_DISPLAYNAME_LEN = 256
|
MAX_DISPLAYNAME_LEN = 256
|
||||||
|
@ -45,7 +54,7 @@ class ProfileHandler(BaseHandler):
|
||||||
PROFILE_UPDATE_MS = 60 * 1000
|
PROFILE_UPDATE_MS = 60 * 1000
|
||||||
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
||||||
|
|
||||||
def __init__(self, hs):
|
def __init__(self, hs: "HomeServer"):
|
||||||
super().__init__(hs)
|
super().__init__(hs)
|
||||||
|
|
||||||
self.federation = hs.get_federation_client()
|
self.federation = hs.get_federation_client()
|
||||||
|
@ -57,10 +66,10 @@ class ProfileHandler(BaseHandler):
|
||||||
|
|
||||||
if hs.config.run_background_tasks:
|
if hs.config.run_background_tasks:
|
||||||
self.clock.looping_call(
|
self.clock.looping_call(
|
||||||
self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
|
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_profile(self, user_id):
|
async def get_profile(self, user_id: str) -> JsonDict:
|
||||||
target_user = UserID.from_string(user_id)
|
target_user = UserID.from_string(user_id)
|
||||||
|
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
|
@ -91,7 +100,7 @@ class ProfileHandler(BaseHandler):
|
||||||
except HttpResponseException as e:
|
except HttpResponseException as e:
|
||||||
raise e.to_synapse_error()
|
raise e.to_synapse_error()
|
||||||
|
|
||||||
async def get_profile_from_cache(self, user_id):
|
async def get_profile_from_cache(self, user_id: str) -> JsonDict:
|
||||||
"""Get the profile information from our local cache. If the user is
|
"""Get the profile information from our local cache. If the user is
|
||||||
ours then the profile information will always be corect. Otherwise,
|
ours then the profile information will always be corect. Otherwise,
|
||||||
it may be out of date/missing.
|
it may be out of date/missing.
|
||||||
|
@ -115,7 +124,7 @@ class ProfileHandler(BaseHandler):
|
||||||
profile = await self.store.get_from_remote_profile_cache(user_id)
|
profile = await self.store.get_from_remote_profile_cache(user_id)
|
||||||
return profile or {}
|
return profile or {}
|
||||||
|
|
||||||
async def get_displayname(self, target_user):
|
async def get_displayname(self, target_user: UserID) -> str:
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
displayname = await self.store.get_profile_displayname(
|
displayname = await self.store.get_profile_displayname(
|
||||||
|
@ -143,15 +152,19 @@ class ProfileHandler(BaseHandler):
|
||||||
return result["displayname"]
|
return result["displayname"]
|
||||||
|
|
||||||
async def set_displayname(
|
async def set_displayname(
|
||||||
self, target_user, requester, new_displayname, by_admin=False
|
self,
|
||||||
):
|
target_user: UserID,
|
||||||
|
requester: Requester,
|
||||||
|
new_displayname: str,
|
||||||
|
by_admin: bool = False,
|
||||||
|
) -> None:
|
||||||
"""Set the displayname of a user
|
"""Set the displayname of a user
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_user (UserID): the user whose displayname is to be changed.
|
target_user: the user whose displayname is to be changed.
|
||||||
requester (Requester): The user attempting to make this change.
|
requester: The user attempting to make this change.
|
||||||
new_displayname (str): The displayname to give this user.
|
new_displayname: The displayname to give this user.
|
||||||
by_admin (bool): Whether this change was made by an administrator.
|
by_admin: Whether this change was made by an administrator.
|
||||||
"""
|
"""
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||||
|
@ -176,8 +189,9 @@ class ProfileHandler(BaseHandler):
|
||||||
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
|
400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
displayname_to_set = new_displayname # type: Optional[str]
|
||||||
if new_displayname == "":
|
if new_displayname == "":
|
||||||
new_displayname = None
|
displayname_to_set = None
|
||||||
|
|
||||||
# If the admin changes the display name of a user, the requesting user cannot send
|
# If the admin changes the display name of a user, the requesting user cannot send
|
||||||
# the join event to update the displayname in the rooms.
|
# the join event to update the displayname in the rooms.
|
||||||
|
@ -185,7 +199,9 @@ class ProfileHandler(BaseHandler):
|
||||||
if by_admin:
|
if by_admin:
|
||||||
requester = create_requester(target_user)
|
requester = create_requester(target_user)
|
||||||
|
|
||||||
await self.store.set_profile_displayname(target_user.localpart, new_displayname)
|
await self.store.set_profile_displayname(
|
||||||
|
target_user.localpart, displayname_to_set
|
||||||
|
)
|
||||||
|
|
||||||
if self.hs.config.user_directory_search_all_users:
|
if self.hs.config.user_directory_search_all_users:
|
||||||
profile = await self.store.get_profileinfo(target_user.localpart)
|
profile = await self.store.get_profileinfo(target_user.localpart)
|
||||||
|
@ -195,7 +211,7 @@ class ProfileHandler(BaseHandler):
|
||||||
|
|
||||||
await self._update_join_states(requester, target_user)
|
await self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
async def get_avatar_url(self, target_user):
|
async def get_avatar_url(self, target_user: UserID) -> str:
|
||||||
if self.hs.is_mine(target_user):
|
if self.hs.is_mine(target_user):
|
||||||
try:
|
try:
|
||||||
avatar_url = await self.store.get_profile_avatar_url(
|
avatar_url = await self.store.get_profile_avatar_url(
|
||||||
|
@ -222,15 +238,19 @@ class ProfileHandler(BaseHandler):
|
||||||
return result["avatar_url"]
|
return result["avatar_url"]
|
||||||
|
|
||||||
async def set_avatar_url(
|
async def set_avatar_url(
|
||||||
self, target_user, requester, new_avatar_url, by_admin=False
|
self,
|
||||||
|
target_user: UserID,
|
||||||
|
requester: Requester,
|
||||||
|
new_avatar_url: str,
|
||||||
|
by_admin: bool = False,
|
||||||
):
|
):
|
||||||
"""Set a new avatar URL for a user.
|
"""Set a new avatar URL for a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_user (UserID): the user whose avatar URL is to be changed.
|
target_user: the user whose avatar URL is to be changed.
|
||||||
requester (Requester): The user attempting to make this change.
|
requester: The user attempting to make this change.
|
||||||
new_avatar_url (str): The avatar URL to give this user.
|
new_avatar_url: The avatar URL to give this user.
|
||||||
by_admin (bool): Whether this change was made by an administrator.
|
by_admin: Whether this change was made by an administrator.
|
||||||
"""
|
"""
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||||
|
@ -267,7 +287,7 @@ class ProfileHandler(BaseHandler):
|
||||||
|
|
||||||
await self._update_join_states(requester, target_user)
|
await self._update_join_states(requester, target_user)
|
||||||
|
|
||||||
async def on_profile_query(self, args):
|
async def on_profile_query(self, args: JsonDict) -> JsonDict:
|
||||||
user = UserID.from_string(args["user_id"])
|
user = UserID.from_string(args["user_id"])
|
||||||
if not self.hs.is_mine(user):
|
if not self.hs.is_mine(user):
|
||||||
raise SynapseError(400, "User is not hosted on this homeserver")
|
raise SynapseError(400, "User is not hosted on this homeserver")
|
||||||
|
@ -292,7 +312,9 @@ class ProfileHandler(BaseHandler):
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def _update_join_states(self, requester, target_user):
|
async def _update_join_states(
|
||||||
|
self, requester: Requester, target_user: UserID
|
||||||
|
) -> None:
|
||||||
if not self.hs.is_mine(target_user):
|
if not self.hs.is_mine(target_user):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -323,15 +345,17 @@ class ProfileHandler(BaseHandler):
|
||||||
"Failed to update join event for room %s - %s", room_id, str(e)
|
"Failed to update join event for room %s - %s", room_id, str(e)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def check_profile_query_allowed(self, target_user, requester=None):
|
async def check_profile_query_allowed(
|
||||||
|
self, target_user: UserID, requester: Optional[UserID] = None
|
||||||
|
) -> None:
|
||||||
"""Checks whether a profile query is allowed. If the
|
"""Checks whether a profile query is allowed. If the
|
||||||
'require_auth_for_profile_requests' config flag is set to True and a
|
'require_auth_for_profile_requests' config flag is set to True and a
|
||||||
'requester' is provided, the query is only allowed if the two users
|
'requester' is provided, the query is only allowed if the two users
|
||||||
share a room.
|
share a room.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
target_user (UserID): The owner of the queried profile.
|
target_user: The owner of the queried profile.
|
||||||
requester (None|UserID): The user querying for the profile.
|
requester: The user querying for the profile.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SynapseError(403): The two users share no room, or ne user couldn't
|
SynapseError(403): The two users share no room, or ne user couldn't
|
||||||
|
@ -370,11 +394,7 @@ class ProfileHandler(BaseHandler):
|
||||||
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
|
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _start_update_remote_profile_cache(self):
|
@wrap_as_background_process("Update remote profile")
|
||||||
return run_as_background_process(
|
|
||||||
"Update remote profile", self._update_remote_profile_cache
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _update_remote_profile_cache(self):
|
async def _update_remote_profile_cache(self):
|
||||||
"""Called periodically to check profiles of remote users we haven't
|
"""Called periodically to check profiles of remote users we haven't
|
||||||
checked in a while.
|
checked in a while.
|
||||||
|
|
225
synapse/logging/_remote.py
Normal file
225
synapse/logging/_remote.py
Normal file
|
@ -0,0 +1,225 @@
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from collections import deque
|
||||||
|
from ipaddress import IPv4Address, IPv6Address, ip_address
|
||||||
|
from math import floor
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import attr
|
||||||
|
from zope.interface import implementer
|
||||||
|
|
||||||
|
from twisted.application.internet import ClientService
|
||||||
|
from twisted.internet.defer import Deferred
|
||||||
|
from twisted.internet.endpoints import (
|
||||||
|
HostnameEndpoint,
|
||||||
|
TCP4ClientEndpoint,
|
||||||
|
TCP6ClientEndpoint,
|
||||||
|
)
|
||||||
|
from twisted.internet.interfaces import IPushProducer, ITransport
|
||||||
|
from twisted.internet.protocol import Factory, Protocol
|
||||||
|
from twisted.logger import ILogObserver, Logger, LogLevel
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
@implementer(IPushProducer)
|
||||||
|
class LogProducer:
|
||||||
|
"""
|
||||||
|
An IPushProducer that writes logs from its buffer to its transport when it
|
||||||
|
is resumed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
buffer: Log buffer to read logs from.
|
||||||
|
transport: Transport to write to.
|
||||||
|
format_event: A callable to format the log entry to a string.
|
||||||
|
"""
|
||||||
|
|
||||||
|
transport = attr.ib(type=ITransport)
|
||||||
|
format_event = attr.ib(type=Callable[[dict], str])
|
||||||
|
_buffer = attr.ib(type=deque)
|
||||||
|
_paused = attr.ib(default=False, type=bool, init=False)
|
||||||
|
|
||||||
|
def pauseProducing(self):
|
||||||
|
self._paused = True
|
||||||
|
|
||||||
|
def stopProducing(self):
|
||||||
|
self._paused = True
|
||||||
|
self._buffer = deque()
|
||||||
|
|
||||||
|
def resumeProducing(self):
|
||||||
|
self._paused = False
|
||||||
|
|
||||||
|
while self._paused is False and (self._buffer and self.transport.connected):
|
||||||
|
try:
|
||||||
|
# Request the next event and format it.
|
||||||
|
event = self._buffer.popleft()
|
||||||
|
msg = self.format_event(event)
|
||||||
|
|
||||||
|
# Send it as a new line over the transport.
|
||||||
|
self.transport.write(msg.encode("utf8"))
|
||||||
|
except Exception:
|
||||||
|
# Something has gone wrong writing to the transport -- log it
|
||||||
|
# and break out of the while.
|
||||||
|
traceback.print_exc(file=sys.__stderr__)
|
||||||
|
break
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s
|
||||||
|
@implementer(ILogObserver)
|
||||||
|
class TCPLogObserver:
|
||||||
|
"""
|
||||||
|
An IObserver that writes JSON logs to a TCP target.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
hs (HomeServer): The homeserver that is being logged for.
|
||||||
|
host: The host of the logging target.
|
||||||
|
port: The logging target's port.
|
||||||
|
format_event: A callable to format the log entry to a string.
|
||||||
|
maximum_buffer: The maximum buffer size.
|
||||||
|
"""
|
||||||
|
|
||||||
|
hs = attr.ib()
|
||||||
|
host = attr.ib(type=str)
|
||||||
|
port = attr.ib(type=int)
|
||||||
|
format_event = attr.ib(type=Callable[[dict], str])
|
||||||
|
maximum_buffer = attr.ib(type=int)
|
||||||
|
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
|
||||||
|
_connection_waiter = attr.ib(default=None, type=Optional[Deferred])
|
||||||
|
_logger = attr.ib(default=attr.Factory(Logger))
|
||||||
|
_producer = attr.ib(default=None, type=Optional[LogProducer])
|
||||||
|
|
||||||
|
def start(self) -> None:
|
||||||
|
|
||||||
|
# Connect without DNS lookups if it's a direct IP.
|
||||||
|
try:
|
||||||
|
ip = ip_address(self.host)
|
||||||
|
if isinstance(ip, IPv4Address):
|
||||||
|
endpoint = TCP4ClientEndpoint(
|
||||||
|
self.hs.get_reactor(), self.host, self.port
|
||||||
|
)
|
||||||
|
elif isinstance(ip, IPv6Address):
|
||||||
|
endpoint = TCP6ClientEndpoint(
|
||||||
|
self.hs.get_reactor(), self.host, self.port
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown IP address provided: %s" % (self.host,))
|
||||||
|
except ValueError:
|
||||||
|
endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
|
||||||
|
|
||||||
|
factory = Factory.forProtocol(Protocol)
|
||||||
|
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
|
||||||
|
self._service.startService()
|
||||||
|
self._connect()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._service.stopService()
|
||||||
|
|
||||||
|
def _connect(self) -> None:
|
||||||
|
"""
|
||||||
|
Triggers an attempt to connect then write to the remote if not already writing.
|
||||||
|
"""
|
||||||
|
if self._connection_waiter:
|
||||||
|
return
|
||||||
|
|
||||||
|
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
|
||||||
|
|
||||||
|
@self._connection_waiter.addErrback
|
||||||
|
def fail(r):
|
||||||
|
r.printTraceback(file=sys.__stderr__)
|
||||||
|
self._connection_waiter = None
|
||||||
|
self._connect()
|
||||||
|
|
||||||
|
@self._connection_waiter.addCallback
|
||||||
|
def writer(r):
|
||||||
|
# We have a connection. If we already have a producer, and its
|
||||||
|
# transport is the same, just trigger a resumeProducing.
|
||||||
|
if self._producer and r.transport is self._producer.transport:
|
||||||
|
self._producer.resumeProducing()
|
||||||
|
self._connection_waiter = None
|
||||||
|
return
|
||||||
|
|
||||||
|
# If the producer is still producing, stop it.
|
||||||
|
if self._producer:
|
||||||
|
self._producer.stopProducing()
|
||||||
|
|
||||||
|
# Make a new producer and start it.
|
||||||
|
self._producer = LogProducer(
|
||||||
|
buffer=self._buffer,
|
||||||
|
transport=r.transport,
|
||||||
|
format_event=self.format_event,
|
||||||
|
)
|
||||||
|
r.transport.registerProducer(self._producer, True)
|
||||||
|
self._producer.resumeProducing()
|
||||||
|
self._connection_waiter = None
|
||||||
|
|
||||||
|
def _handle_pressure(self) -> None:
|
||||||
|
"""
|
||||||
|
Handle backpressure by shedding events.
|
||||||
|
|
||||||
|
The buffer will, in this order, until the buffer is below the maximum:
|
||||||
|
- Shed DEBUG events
|
||||||
|
- Shed INFO events
|
||||||
|
- Shed the middle 50% of the events.
|
||||||
|
"""
|
||||||
|
if len(self._buffer) <= self.maximum_buffer:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Strip out DEBUGs
|
||||||
|
self._buffer = deque(
|
||||||
|
filter(lambda event: event["log_level"] != LogLevel.debug, self._buffer)
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(self._buffer) <= self.maximum_buffer:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Strip out INFOs
|
||||||
|
self._buffer = deque(
|
||||||
|
filter(lambda event: event["log_level"] != LogLevel.info, self._buffer)
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(self._buffer) <= self.maximum_buffer:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Cut the middle entries out
|
||||||
|
buffer_split = floor(self.maximum_buffer / 2)
|
||||||
|
|
||||||
|
old_buffer = self._buffer
|
||||||
|
self._buffer = deque()
|
||||||
|
|
||||||
|
for i in range(buffer_split):
|
||||||
|
self._buffer.append(old_buffer.popleft())
|
||||||
|
|
||||||
|
end_buffer = []
|
||||||
|
for i in range(buffer_split):
|
||||||
|
end_buffer.append(old_buffer.pop())
|
||||||
|
|
||||||
|
self._buffer.extend(reversed(end_buffer))
|
||||||
|
|
||||||
|
def __call__(self, event: dict) -> None:
|
||||||
|
self._buffer.append(event)
|
||||||
|
|
||||||
|
# Handle backpressure, if it exists.
|
||||||
|
try:
|
||||||
|
self._handle_pressure()
|
||||||
|
except Exception:
|
||||||
|
# If handling backpressure fails,clear the buffer and log the
|
||||||
|
# exception.
|
||||||
|
self._buffer.clear()
|
||||||
|
self._logger.failure("Failed clearing backpressure")
|
||||||
|
|
||||||
|
# Try and write immediately.
|
||||||
|
self._connect()
|
|
@ -18,26 +18,11 @@ Log formatters that output terse JSON.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import sys
|
from typing import IO
|
||||||
import traceback
|
|
||||||
from collections import deque
|
|
||||||
from ipaddress import IPv4Address, IPv6Address, ip_address
|
|
||||||
from math import floor
|
|
||||||
from typing import IO, Optional
|
|
||||||
|
|
||||||
import attr
|
from twisted.logger import FileLogObserver
|
||||||
from zope.interface import implementer
|
|
||||||
|
|
||||||
from twisted.application.internet import ClientService
|
from synapse.logging._remote import TCPLogObserver
|
||||||
from twisted.internet.defer import Deferred
|
|
||||||
from twisted.internet.endpoints import (
|
|
||||||
HostnameEndpoint,
|
|
||||||
TCP4ClientEndpoint,
|
|
||||||
TCP6ClientEndpoint,
|
|
||||||
)
|
|
||||||
from twisted.internet.interfaces import IPushProducer, ITransport
|
|
||||||
from twisted.internet.protocol import Factory, Protocol
|
|
||||||
from twisted.logger import FileLogObserver, ILogObserver, Logger
|
|
||||||
|
|
||||||
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
|
_encoder = json.JSONEncoder(ensure_ascii=False, separators=(",", ":"))
|
||||||
|
|
||||||
|
@ -150,180 +135,22 @@ def TerseJSONToConsoleLogObserver(outFile: IO[str], metadata: dict) -> FileLogOb
|
||||||
return FileLogObserver(outFile, formatEvent)
|
return FileLogObserver(outFile, formatEvent)
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
def TerseJSONToTCPLogObserver(
|
||||||
@implementer(IPushProducer)
|
hs, host: str, port: int, metadata: dict, maximum_buffer: int
|
||||||
class LogProducer:
|
) -> FileLogObserver:
|
||||||
"""
|
"""
|
||||||
An IPushProducer that writes logs from its buffer to its transport when it
|
A log observer that formats events to a flattened JSON representation.
|
||||||
is resumed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
buffer: Log buffer to read logs from.
|
|
||||||
transport: Transport to write to.
|
|
||||||
"""
|
|
||||||
|
|
||||||
transport = attr.ib(type=ITransport)
|
|
||||||
_buffer = attr.ib(type=deque)
|
|
||||||
_paused = attr.ib(default=False, type=bool, init=False)
|
|
||||||
|
|
||||||
def pauseProducing(self):
|
|
||||||
self._paused = True
|
|
||||||
|
|
||||||
def stopProducing(self):
|
|
||||||
self._paused = True
|
|
||||||
self._buffer = deque()
|
|
||||||
|
|
||||||
def resumeProducing(self):
|
|
||||||
self._paused = False
|
|
||||||
|
|
||||||
while self._paused is False and (self._buffer and self.transport.connected):
|
|
||||||
try:
|
|
||||||
event = self._buffer.popleft()
|
|
||||||
self.transport.write(_encoder.encode(event).encode("utf8"))
|
|
||||||
self.transport.write(b"\n")
|
|
||||||
except Exception:
|
|
||||||
# Something has gone wrong writing to the transport -- log it
|
|
||||||
# and break out of the while.
|
|
||||||
traceback.print_exc(file=sys.__stderr__)
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
@attr.s
|
|
||||||
@implementer(ILogObserver)
|
|
||||||
class TerseJSONToTCPLogObserver:
|
|
||||||
"""
|
|
||||||
An IObserver that writes JSON logs to a TCP target.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hs (HomeServer): The homeserver that is being logged for.
|
hs (HomeServer): The homeserver that is being logged for.
|
||||||
host: The host of the logging target.
|
host: The host of the logging target.
|
||||||
port: The logging target's port.
|
port: The logging target's port.
|
||||||
metadata: Metadata to be added to each log entry.
|
metadata: Metadata to be added to each log object.
|
||||||
|
maximum_buffer: The maximum buffer size.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
hs = attr.ib()
|
def formatEvent(_event: dict) -> str:
|
||||||
host = attr.ib(type=str)
|
flattened = flatten_event(_event, metadata, include_time=True)
|
||||||
port = attr.ib(type=int)
|
return _encoder.encode(flattened) + "\n"
|
||||||
metadata = attr.ib(type=dict)
|
|
||||||
maximum_buffer = attr.ib(type=int)
|
|
||||||
_buffer = attr.ib(default=attr.Factory(deque), type=deque)
|
|
||||||
_connection_waiter = attr.ib(default=None, type=Optional[Deferred])
|
|
||||||
_logger = attr.ib(default=attr.Factory(Logger))
|
|
||||||
_producer = attr.ib(default=None, type=Optional[LogProducer])
|
|
||||||
|
|
||||||
def start(self) -> None:
|
return TCPLogObserver(hs, host, port, formatEvent, maximum_buffer)
|
||||||
|
|
||||||
# Connect without DNS lookups if it's a direct IP.
|
|
||||||
try:
|
|
||||||
ip = ip_address(self.host)
|
|
||||||
if isinstance(ip, IPv4Address):
|
|
||||||
endpoint = TCP4ClientEndpoint(
|
|
||||||
self.hs.get_reactor(), self.host, self.port
|
|
||||||
)
|
|
||||||
elif isinstance(ip, IPv6Address):
|
|
||||||
endpoint = TCP6ClientEndpoint(
|
|
||||||
self.hs.get_reactor(), self.host, self.port
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
endpoint = HostnameEndpoint(self.hs.get_reactor(), self.host, self.port)
|
|
||||||
|
|
||||||
factory = Factory.forProtocol(Protocol)
|
|
||||||
self._service = ClientService(endpoint, factory, clock=self.hs.get_reactor())
|
|
||||||
self._service.startService()
|
|
||||||
self._connect()
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
self._service.stopService()
|
|
||||||
|
|
||||||
def _connect(self) -> None:
|
|
||||||
"""
|
|
||||||
Triggers an attempt to connect then write to the remote if not already writing.
|
|
||||||
"""
|
|
||||||
if self._connection_waiter:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._connection_waiter = self._service.whenConnected(failAfterFailures=1)
|
|
||||||
|
|
||||||
@self._connection_waiter.addErrback
|
|
||||||
def fail(r):
|
|
||||||
r.printTraceback(file=sys.__stderr__)
|
|
||||||
self._connection_waiter = None
|
|
||||||
self._connect()
|
|
||||||
|
|
||||||
@self._connection_waiter.addCallback
|
|
||||||
def writer(r):
|
|
||||||
# We have a connection. If we already have a producer, and its
|
|
||||||
# transport is the same, just trigger a resumeProducing.
|
|
||||||
if self._producer and r.transport is self._producer.transport:
|
|
||||||
self._producer.resumeProducing()
|
|
||||||
self._connection_waiter = None
|
|
||||||
return
|
|
||||||
|
|
||||||
# If the producer is still producing, stop it.
|
|
||||||
if self._producer:
|
|
||||||
self._producer.stopProducing()
|
|
||||||
|
|
||||||
# Make a new producer and start it.
|
|
||||||
self._producer = LogProducer(buffer=self._buffer, transport=r.transport)
|
|
||||||
r.transport.registerProducer(self._producer, True)
|
|
||||||
self._producer.resumeProducing()
|
|
||||||
self._connection_waiter = None
|
|
||||||
|
|
||||||
def _handle_pressure(self) -> None:
|
|
||||||
"""
|
|
||||||
Handle backpressure by shedding events.
|
|
||||||
|
|
||||||
The buffer will, in this order, until the buffer is below the maximum:
|
|
||||||
- Shed DEBUG events
|
|
||||||
- Shed INFO events
|
|
||||||
- Shed the middle 50% of the events.
|
|
||||||
"""
|
|
||||||
if len(self._buffer) <= self.maximum_buffer:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Strip out DEBUGs
|
|
||||||
self._buffer = deque(
|
|
||||||
filter(lambda event: event["level"] != "DEBUG", self._buffer)
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(self._buffer) <= self.maximum_buffer:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Strip out INFOs
|
|
||||||
self._buffer = deque(
|
|
||||||
filter(lambda event: event["level"] != "INFO", self._buffer)
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(self._buffer) <= self.maximum_buffer:
|
|
||||||
return
|
|
||||||
|
|
||||||
# Cut the middle entries out
|
|
||||||
buffer_split = floor(self.maximum_buffer / 2)
|
|
||||||
|
|
||||||
old_buffer = self._buffer
|
|
||||||
self._buffer = deque()
|
|
||||||
|
|
||||||
for i in range(buffer_split):
|
|
||||||
self._buffer.append(old_buffer.popleft())
|
|
||||||
|
|
||||||
end_buffer = []
|
|
||||||
for i in range(buffer_split):
|
|
||||||
end_buffer.append(old_buffer.pop())
|
|
||||||
|
|
||||||
self._buffer.extend(reversed(end_buffer))
|
|
||||||
|
|
||||||
def __call__(self, event: dict) -> None:
|
|
||||||
flattened = flatten_event(event, self.metadata, include_time=True)
|
|
||||||
self._buffer.append(flattened)
|
|
||||||
|
|
||||||
# Handle backpressure, if it exists.
|
|
||||||
try:
|
|
||||||
self._handle_pressure()
|
|
||||||
except Exception:
|
|
||||||
# If handling backpressure fails,clear the buffer and log the
|
|
||||||
# exception.
|
|
||||||
self._buffer.clear()
|
|
||||||
self._logger.failure("Failed clearing backpressure")
|
|
||||||
|
|
||||||
# Try and write immediately.
|
|
||||||
self._connect()
|
|
||||||
|
|
|
@ -24,6 +24,7 @@ from prometheus_client.core import REGISTRY, Counter, Gauge
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
from synapse.logging.context import LoggingContext, PreserveLoggingContext
|
||||||
|
from synapse.logging.opentracing import start_active_span
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import resource
|
import resource
|
||||||
|
@ -197,14 +198,14 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
|
||||||
|
|
||||||
with BackgroundProcessLoggingContext(desc) as context:
|
with BackgroundProcessLoggingContext(desc) as context:
|
||||||
context.request = "%s-%i" % (desc, count)
|
context.request = "%s-%i" % (desc, count)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result = func(*args, **kwargs)
|
with start_active_span(desc, tags={"request_id": context.request}):
|
||||||
|
result = func(*args, **kwargs)
|
||||||
|
|
||||||
if inspect.isawaitable(result):
|
if inspect.isawaitable(result):
|
||||||
result = await result
|
result = await result
|
||||||
|
|
||||||
return result
|
return result
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Background process '%s' threw an exception", desc,
|
"Background process '%s' threw an exception", desc,
|
||||||
|
|
|
@ -496,6 +496,6 @@ class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
|
||||||
# dedupe when we add callbacks to lru cache nodes, otherwise the number
|
# dedupe when we add callbacks to lru cache nodes, otherwise the number
|
||||||
# of callbacks would grow.
|
# of callbacks would grow.
|
||||||
def __call__(self):
|
def __call__(self):
|
||||||
rules = self.cache.get(self.room_id, None, update_metrics=False)
|
rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
|
||||||
if rules:
|
if rules:
|
||||||
rules.invalidate_all()
|
rules.invalidate_all()
|
||||||
|
|
|
@ -387,8 +387,8 @@ class Mailer:
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def get_message_vars(self, notif, event, room_state_ids):
|
async def get_message_vars(self, notif, event, room_state_ids):
|
||||||
if event.type != EventTypes.Message:
|
if event.type != EventTypes.Message and event.type != EventTypes.Encrypted:
|
||||||
return
|
return None
|
||||||
|
|
||||||
sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
|
sender_state_event_id = room_state_ids[("m.room.member", event.sender)]
|
||||||
sender_state_event = await self.store.get_event(sender_state_event_id)
|
sender_state_event = await self.store.get_event(sender_state_event_id)
|
||||||
|
@ -399,10 +399,8 @@ class Mailer:
|
||||||
# sender_hash % the number of default images to choose from
|
# sender_hash % the number of default images to choose from
|
||||||
sender_hash = string_ordinal_total(event.sender)
|
sender_hash = string_ordinal_total(event.sender)
|
||||||
|
|
||||||
msgtype = event.content.get("msgtype")
|
|
||||||
|
|
||||||
ret = {
|
ret = {
|
||||||
"msgtype": msgtype,
|
"event_type": event.type,
|
||||||
"is_historical": event.event_id != notif["event_id"],
|
"is_historical": event.event_id != notif["event_id"],
|
||||||
"id": event.event_id,
|
"id": event.event_id,
|
||||||
"ts": event.origin_server_ts,
|
"ts": event.origin_server_ts,
|
||||||
|
@ -411,6 +409,14 @@ class Mailer:
|
||||||
"sender_hash": sender_hash,
|
"sender_hash": sender_hash,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Encrypted messages don't have any additional useful information.
|
||||||
|
if event.type == EventTypes.Encrypted:
|
||||||
|
return ret
|
||||||
|
|
||||||
|
msgtype = event.content.get("msgtype")
|
||||||
|
|
||||||
|
ret["msgtype"] = msgtype
|
||||||
|
|
||||||
if msgtype == "m.text":
|
if msgtype == "m.text":
|
||||||
self.add_text_message_vars(ret, event)
|
self.add_text_message_vars(ret, event)
|
||||||
elif msgtype == "m.image":
|
elif msgtype == "m.image":
|
||||||
|
|
|
@ -16,11 +16,10 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from typing import Any, Dict, List, Optional, Pattern, Union
|
from typing import Any, Dict, List, Optional, Pattern, Tuple, Union
|
||||||
|
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.types import UserID
|
from synapse.types import UserID
|
||||||
from synapse.util.caches import register_cache
|
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -174,20 +173,21 @@ class PushRuleEvaluatorForEvent:
|
||||||
# Similar to _glob_matches, but do not treat display_name as a glob.
|
# Similar to _glob_matches, but do not treat display_name as a glob.
|
||||||
r = regex_cache.get((display_name, False, True), None)
|
r = regex_cache.get((display_name, False, True), None)
|
||||||
if not r:
|
if not r:
|
||||||
r = re.escape(display_name)
|
r1 = re.escape(display_name)
|
||||||
r = _re_word_boundary(r)
|
r1 = _re_word_boundary(r1)
|
||||||
r = re.compile(r, flags=re.IGNORECASE)
|
r = re.compile(r1, flags=re.IGNORECASE)
|
||||||
regex_cache[(display_name, False, True)] = r
|
regex_cache[(display_name, False, True)] = r
|
||||||
|
|
||||||
return r.search(body)
|
return bool(r.search(body))
|
||||||
|
|
||||||
def _get_value(self, dotted_key: str) -> Optional[str]:
|
def _get_value(self, dotted_key: str) -> Optional[str]:
|
||||||
return self._value_cache.get(dotted_key, None)
|
return self._value_cache.get(dotted_key, None)
|
||||||
|
|
||||||
|
|
||||||
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
|
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
|
||||||
regex_cache = LruCache(50000)
|
regex_cache = LruCache(
|
||||||
register_cache("cache", "regex_push_cache", regex_cache)
|
50000, "regex_push_cache"
|
||||||
|
) # type: LruCache[Tuple[str, bool, bool], Pattern]
|
||||||
|
|
||||||
|
|
||||||
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
|
def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
|
||||||
|
@ -205,7 +205,7 @@ def _glob_matches(glob: str, value: str, word_boundary: bool = False) -> bool:
|
||||||
if not r:
|
if not r:
|
||||||
r = _glob_to_re(glob, word_boundary)
|
r = _glob_to_re(glob, word_boundary)
|
||||||
regex_cache[(glob, True, word_boundary)] = r
|
regex_cache[(glob, True, word_boundary)] = r
|
||||||
return r.search(value)
|
return bool(r.search(value))
|
||||||
except re.error:
|
except re.error:
|
||||||
logger.warning("Failed to parse glob to regex: %r", glob)
|
logger.warning("Failed to parse glob to regex: %r", glob)
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -15,7 +15,7 @@
|
||||||
|
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
|
from synapse.storage.databases.main.client_ips import LAST_SEEN_GRANULARITY
|
||||||
from synapse.util.caches.deferred_cache import DeferredCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
from ._base import BaseSlavedStore
|
from ._base import BaseSlavedStore
|
||||||
|
|
||||||
|
@ -24,9 +24,9 @@ class SlavedClientIpStore(BaseSlavedStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
|
||||||
self.client_ip_last_seen = DeferredCache(
|
self.client_ip_last_seen = LruCache(
|
||||||
name="client_ip_last_seen", keylen=4, max_entries=50000
|
cache_name="client_ip_last_seen", keylen=4, max_size=50000
|
||||||
) # type: DeferredCache[tuple, int]
|
) # type: LruCache[tuple, int]
|
||||||
|
|
||||||
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
|
async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
|
||||||
now = int(self._clock.time_msec())
|
now = int(self._clock.time_msec())
|
||||||
|
@ -41,7 +41,7 @@ class SlavedClientIpStore(BaseSlavedStore):
|
||||||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.client_ip_last_seen.prefill(key, now)
|
self.client_ip_last_seen.set(key, now)
|
||||||
|
|
||||||
self.hs.get_tcp_replication().send_user_ip(
|
self.hs.get_tcp_replication().send_user_ip(
|
||||||
user_id, access_token, ip, user_agent, device_id, now
|
user_id, access_token, ip, user_agent, device_id, now
|
||||||
|
|
|
@ -1,41 +1,47 @@
|
||||||
{% for message in notif.messages %}
|
{%- for message in notif.messages %}
|
||||||
<tr class="{{ "historical_message" if message.is_historical else "message" }}">
|
<tr class="{{ "historical_message" if message.is_historical else "message" }}">
|
||||||
<td class="sender_avatar">
|
<td class="sender_avatar">
|
||||||
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
|
{%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
|
||||||
{% if message.sender_avatar_url %}
|
{%- if message.sender_avatar_url %}
|
||||||
<img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
|
<img alt="" class="sender_avatar" src="{{ message.sender_avatar_url|mxc_to_http(32,32) }}" />
|
||||||
{% else %}
|
{%- else %}
|
||||||
{% if message.sender_hash % 3 == 0 %}
|
{%- if message.sender_hash % 3 == 0 %}
|
||||||
<img class="sender_avatar" src="https://riot.im/img/external/avatar-1.png" />
|
<img class="sender_avatar" src="https://riot.im/img/external/avatar-1.png" />
|
||||||
{% elif message.sender_hash % 3 == 1 %}
|
{%- elif message.sender_hash % 3 == 1 %}
|
||||||
<img class="sender_avatar" src="https://riot.im/img/external/avatar-2.png" />
|
<img class="sender_avatar" src="https://riot.im/img/external/avatar-2.png" />
|
||||||
{% else %}
|
{%- else %}
|
||||||
<img class="sender_avatar" src="https://riot.im/img/external/avatar-3.png" />
|
<img class="sender_avatar" src="https://riot.im/img/external/avatar-3.png" />
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
</td>
|
</td>
|
||||||
<td class="message_contents">
|
<td class="message_contents">
|
||||||
{% if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
|
{%- if loop.index0 == 0 or notif.messages[loop.index0 - 1].sender_name != notif.messages[loop.index0].sender_name %}
|
||||||
<div class="sender_name">{% if message.msgtype == "m.emote" %}*{% endif %} {{ message.sender_name }}</div>
|
<div class="sender_name">{%- if message.msgtype == "m.emote" %}*{%- endif %} {{ message.sender_name }}</div>
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
<div class="message_body">
|
<div class="message_body">
|
||||||
{% if message.msgtype == "m.text" %}
|
{%- if message.event_type == "m.room.encrypted" %}
|
||||||
{{ message.body_text_html }}
|
An encrypted message.
|
||||||
{% elif message.msgtype == "m.emote" %}
|
{%- elif message.event_type == "m.room.message" %}
|
||||||
{{ message.body_text_html }}
|
{%- if message.msgtype == "m.text" %}
|
||||||
{% elif message.msgtype == "m.notice" %}
|
{{ message.body_text_html }}
|
||||||
{{ message.body_text_html }}
|
{%- elif message.msgtype == "m.emote" %}
|
||||||
{% elif message.msgtype == "m.image" %}
|
{{ message.body_text_html }}
|
||||||
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
|
{%- elif message.msgtype == "m.notice" %}
|
||||||
{% elif message.msgtype == "m.file" %}
|
{{ message.body_text_html }}
|
||||||
<span class="filename">{{ message.body_text_plain }}</span>
|
{%- elif message.msgtype == "m.image" %}
|
||||||
{% endif %}
|
<img src="{{ message.image_url|mxc_to_http(640, 480, scale) }}" />
|
||||||
|
{%- elif message.msgtype == "m.file" %}
|
||||||
|
<span class="filename">{{ message.body_text_plain }}</span>
|
||||||
|
{%- else %}
|
||||||
|
A message with unrecognised content.
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
</div>
|
</div>
|
||||||
</td>
|
</td>
|
||||||
<td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
|
<td class="message_time">{{ message.ts|format_ts("%H:%M") }}</td>
|
||||||
</tr>
|
</tr>
|
||||||
{% endfor %}
|
{%- endfor %}
|
||||||
<tr class="notif_link">
|
<tr class="notif_link">
|
||||||
<td></td>
|
<td></td>
|
||||||
<td>
|
<td>
|
||||||
|
|
|
@ -1,16 +1,22 @@
|
||||||
{% for message in notif.messages %}
|
{%- for message in notif.messages %}
|
||||||
{% if message.msgtype == "m.emote" %}* {% endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
|
{%- if message.event_type == "m.room.encrypted" %}
|
||||||
{% if message.msgtype == "m.text" %}
|
An encrypted message.
|
||||||
|
{%- elif message.event_type == "m.room.message" %}
|
||||||
|
{%- if message.msgtype == "m.emote" %}* {%- endif %}{{ message.sender_name }} ({{ message.ts|format_ts("%H:%M") }})
|
||||||
|
{%- if message.msgtype == "m.text" %}
|
||||||
{{ message.body_text_plain }}
|
{{ message.body_text_plain }}
|
||||||
{% elif message.msgtype == "m.emote" %}
|
{%- elif message.msgtype == "m.emote" %}
|
||||||
{{ message.body_text_plain }}
|
{{ message.body_text_plain }}
|
||||||
{% elif message.msgtype == "m.notice" %}
|
{%- elif message.msgtype == "m.notice" %}
|
||||||
{{ message.body_text_plain }}
|
{{ message.body_text_plain }}
|
||||||
{% elif message.msgtype == "m.image" %}
|
{%- elif message.msgtype == "m.image" %}
|
||||||
{{ message.body_text_plain }}
|
{{ message.body_text_plain }}
|
||||||
{% elif message.msgtype == "m.file" %}
|
{%- elif message.msgtype == "m.file" %}
|
||||||
{{ message.body_text_plain }}
|
{{ message.body_text_plain }}
|
||||||
{% endif %}
|
{%- else %}
|
||||||
{% endfor %}
|
A message with unrecognised content.
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
|
||||||
View {{ room.title }} at {{ notif.link }}
|
View {{ room.title }} at {{ notif.link }}
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
<html lang="en">
|
<html lang="en">
|
||||||
<head>
|
<head>
|
||||||
<style type="text/css">
|
<style type="text/css">
|
||||||
{% include 'mail.css' without context %}
|
{%- include 'mail.css' without context %}
|
||||||
{% include "mail-%s.css" % app_name ignore missing without context %}
|
{%- include "mail-%s.css" % app_name ignore missing without context %}
|
||||||
</style>
|
</style>
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
|
@ -18,21 +18,21 @@
|
||||||
<div class="summarytext">{{ summary_text }}</div>
|
<div class="summarytext">{{ summary_text }}</div>
|
||||||
</td>
|
</td>
|
||||||
<td class="logo">
|
<td class="logo">
|
||||||
{% if app_name == "Riot" %}
|
{%- if app_name == "Riot" %}
|
||||||
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
|
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
|
||||||
{% elif app_name == "Vector" %}
|
{%- elif app_name == "Vector" %}
|
||||||
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
|
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
|
||||||
{% elif app_name == "Element" %}
|
{%- elif app_name == "Element" %}
|
||||||
<img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
|
<img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
|
||||||
{% else %}
|
{%- else %}
|
||||||
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
|
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
</table>
|
</table>
|
||||||
{% for room in rooms %}
|
{%- for room in rooms %}
|
||||||
{% include 'room.html' with context %}
|
{%- include 'room.html' with context %}
|
||||||
{% endfor %}
|
{%- endfor %}
|
||||||
<div class="footer">
|
<div class="footer">
|
||||||
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
|
<a href="{{ unsubscribe_link }}">Unsubscribe</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
@ -41,12 +41,12 @@
|
||||||
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
|
Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because
|
||||||
an event was received at {{ reason.received_at|format_ts("%c") }}
|
an event was received at {{ reason.received_at|format_ts("%c") }}
|
||||||
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
|
which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago,
|
||||||
{% if reason.last_sent_ts %}
|
{%- if reason.last_sent_ts %}
|
||||||
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
|
and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }},
|
||||||
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
|
which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago.
|
||||||
{% else %}
|
{%- else %}
|
||||||
and we don't have a last time we sent a mail for this room.
|
and we don't have a last time we sent a mail for this room.
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</td>
|
</td>
|
||||||
|
|
|
@ -2,9 +2,9 @@ Hi {{ user_display_name }},
|
||||||
|
|
||||||
{{ summary_text }}
|
{{ summary_text }}
|
||||||
|
|
||||||
{% for room in rooms %}
|
{%- for room in rooms %}
|
||||||
{% include 'room.txt' with context %}
|
{%- include 'room.txt' with context %}
|
||||||
{% endfor %}
|
{%- endfor %}
|
||||||
|
|
||||||
You can disable these notifications at {{ unsubscribe_link }}
|
You can disable these notifications at {{ unsubscribe_link }}
|
||||||
|
|
||||||
|
|
|
@ -1,23 +1,23 @@
|
||||||
<table class="room">
|
<table class="room">
|
||||||
<tr class="room_header">
|
<tr class="room_header">
|
||||||
<td class="room_avatar">
|
<td class="room_avatar">
|
||||||
{% if room.avatar_url %}
|
{%- if room.avatar_url %}
|
||||||
<img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
|
<img alt="" src="{{ room.avatar_url|mxc_to_http(48,48) }}" />
|
||||||
{% else %}
|
{%- else %}
|
||||||
{% if room.hash % 3 == 0 %}
|
{%- if room.hash % 3 == 0 %}
|
||||||
<img alt="" src="https://riot.im/img/external/avatar-1.png" />
|
<img alt="" src="https://riot.im/img/external/avatar-1.png" />
|
||||||
{% elif room.hash % 3 == 1 %}
|
{%- elif room.hash % 3 == 1 %}
|
||||||
<img alt="" src="https://riot.im/img/external/avatar-2.png" />
|
<img alt="" src="https://riot.im/img/external/avatar-2.png" />
|
||||||
{% else %}
|
{%- else %}
|
||||||
<img alt="" src="https://riot.im/img/external/avatar-3.png" />
|
<img alt="" src="https://riot.im/img/external/avatar-3.png" />
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
</td>
|
</td>
|
||||||
<td class="room_name" colspan="2">
|
<td class="room_name" colspan="2">
|
||||||
{{ room.title }}
|
{{ room.title }}
|
||||||
</td>
|
</td>
|
||||||
</tr>
|
</tr>
|
||||||
{% if room.invite %}
|
{%- if room.invite %}
|
||||||
<tr>
|
<tr>
|
||||||
<td></td>
|
<td></td>
|
||||||
<td>
|
<td>
|
||||||
|
@ -25,9 +25,9 @@
|
||||||
</td>
|
</td>
|
||||||
<td></td>
|
<td></td>
|
||||||
</tr>
|
</tr>
|
||||||
{% else %}
|
{%- else %}
|
||||||
{% for notif in room.notifs %}
|
{%- for notif in room.notifs %}
|
||||||
{% include 'notif.html' with context %}
|
{%- include 'notif.html' with context %}
|
||||||
{% endfor %}
|
{%- endfor %}
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
</table>
|
</table>
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
{{ room.title }}
|
{{ room.title }}
|
||||||
|
|
||||||
{% if room.invite %}
|
{%- if room.invite %}
|
||||||
You've been invited, join at {{ room.link }}
|
You've been invited, join at {{ room.link }}
|
||||||
{% else %}
|
{%- else %}
|
||||||
{% for notif in room.notifs %}
|
{%- for notif in room.notifs %}
|
||||||
{% include 'notif.txt' with context %}
|
{%- include 'notif.txt' with context %}
|
||||||
{% endfor %}
|
{%- endfor %}
|
||||||
{% endif %}
|
{%- endif %}
|
||||||
|
|
|
@ -110,6 +110,8 @@ class LoginRestServlet(RestServlet):
|
||||||
({"type": t} for t in self.auth_handler.get_supported_login_types())
|
({"type": t} for t in self.auth_handler.get_supported_login_types())
|
||||||
)
|
)
|
||||||
|
|
||||||
|
flows.append({"type": LoginRestServlet.APPSERVICE_TYPE})
|
||||||
|
|
||||||
return 200, {"flows": flows}
|
return 200, {"flows": flows}
|
||||||
|
|
||||||
def on_OPTIONS(self, request: SynapseRequest):
|
def on_OPTIONS(self, request: SynapseRequest):
|
||||||
|
|
|
@ -76,14 +76,16 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if key is None:
|
cache = getattr(self, cache_name)
|
||||||
getattr(self, cache_name).invalidate_all()
|
|
||||||
else:
|
|
||||||
getattr(self, cache_name).invalidate(tuple(key))
|
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
# We probably haven't pulled in the cache in this worker,
|
# We probably haven't pulled in the cache in this worker,
|
||||||
# which is fine.
|
# which is fine.
|
||||||
pass
|
return
|
||||||
|
|
||||||
|
if key is None:
|
||||||
|
cache.invalidate_all()
|
||||||
|
else:
|
||||||
|
cache.invalidate(tuple(key))
|
||||||
|
|
||||||
|
|
||||||
def db_to_json(db_content):
|
def db_to_json(db_content):
|
||||||
|
|
|
@ -160,7 +160,7 @@ class LoggingDatabaseConnection:
|
||||||
self.conn.__enter__()
|
self.conn.__enter__()
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback) -> bool:
|
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
|
||||||
return self.conn.__exit__(exc_type, exc_value, traceback)
|
return self.conn.__exit__(exc_type, exc_value, traceback)
|
||||||
|
|
||||||
# Proxy through any unknown lookups to the DB conn class.
|
# Proxy through any unknown lookups to the DB conn class.
|
||||||
|
|
|
@ -19,7 +19,7 @@ from typing import Dict, Optional, Tuple
|
||||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
|
from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
|
||||||
from synapse.util.caches.deferred_cache import DeferredCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -410,8 +410,8 @@ class ClientIpWorkerStore(ClientIpBackgroundUpdateStore):
|
||||||
class ClientIpStore(ClientIpWorkerStore):
|
class ClientIpStore(ClientIpWorkerStore):
|
||||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||||
|
|
||||||
self.client_ip_last_seen = DeferredCache(
|
self.client_ip_last_seen = LruCache(
|
||||||
name="client_ip_last_seen", keylen=4, max_entries=50000
|
cache_name="client_ip_last_seen", keylen=4, max_size=50000
|
||||||
)
|
)
|
||||||
|
|
||||||
super().__init__(database, db_conn, hs)
|
super().__init__(database, db_conn, hs)
|
||||||
|
@ -442,7 +442,7 @@ class ClientIpStore(ClientIpWorkerStore):
|
||||||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.client_ip_last_seen.prefill(key, now)
|
self.client_ip_last_seen.set(key, now)
|
||||||
|
|
||||||
self._batch_row_update[key] = (user_agent, device_id, now)
|
self._batch_row_update[key] = (user_agent, device_id, now)
|
||||||
|
|
||||||
|
|
|
@ -34,8 +34,8 @@ from synapse.storage.database import (
|
||||||
)
|
)
|
||||||
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
|
from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
|
||||||
from synapse.util import json_decoder, json_encoder
|
from synapse.util import json_decoder, json_encoder
|
||||||
from synapse.util.caches.deferred_cache import DeferredCache
|
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
from synapse.util.stringutils import shortstr
|
from synapse.util.stringutils import shortstr
|
||||||
|
|
||||||
|
@ -1005,8 +1005,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
|
|
||||||
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
# Map of (user_id, device_id) -> bool. If there is an entry that implies
|
||||||
# the device exists.
|
# the device exists.
|
||||||
self.device_id_exists_cache = DeferredCache(
|
self.device_id_exists_cache = LruCache(
|
||||||
name="device_id_exists", keylen=2, max_entries=10000
|
cache_name="device_id_exists", keylen=2, max_size=10000
|
||||||
)
|
)
|
||||||
|
|
||||||
async def store_device(
|
async def store_device(
|
||||||
|
@ -1052,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||||
)
|
)
|
||||||
if hidden:
|
if hidden:
|
||||||
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
|
raise StoreError(400, "The device ID is in use", Codes.FORBIDDEN)
|
||||||
self.device_id_exists_cache.prefill(key, True)
|
self.device_id_exists_cache.set(key, True)
|
||||||
return inserted
|
return inserted
|
||||||
except StoreError:
|
except StoreError:
|
||||||
raise
|
raise
|
||||||
|
|
|
@ -1051,9 +1051,7 @@ class PersistEventsStore:
|
||||||
|
|
||||||
def prefill():
|
def prefill():
|
||||||
for cache_entry in to_prefill:
|
for cache_entry in to_prefill:
|
||||||
self.store._get_event_cache.prefill(
|
self.store._get_event_cache.set((cache_entry[0].event_id,), cache_entry)
|
||||||
(cache_entry[0].event_id,), cache_entry
|
|
||||||
)
|
|
||||||
|
|
||||||
txn.call_after(prefill)
|
txn.call_after(prefill)
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,10 @@ from synapse.api.room_versions import (
|
||||||
from synapse.events import EventBase, make_event_from_dict
|
from synapse.events import EventBase, make_event_from_dict
|
||||||
from synapse.events.utils import prune_event
|
from synapse.events.utils import prune_event
|
||||||
from synapse.logging.context import PreserveLoggingContext, current_context
|
from synapse.logging.context import PreserveLoggingContext, current_context
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import (
|
||||||
|
run_as_background_process,
|
||||||
|
wrap_as_background_process,
|
||||||
|
)
|
||||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||||
from synapse.replication.tcp.streams import BackfillStream
|
from synapse.replication.tcp.streams import BackfillStream
|
||||||
from synapse.replication.tcp.streams.events import EventsStream
|
from synapse.replication.tcp.streams.events import EventsStream
|
||||||
|
@ -42,8 +45,8 @@ from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.engines import PostgresEngine
|
from synapse.storage.engines import PostgresEngine
|
||||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||||
from synapse.types import Collection, get_domain_from_id
|
from synapse.types import Collection, get_domain_from_id
|
||||||
from synapse.util.caches.deferred_cache import DeferredCache
|
|
||||||
from synapse.util.caches.descriptors import cached
|
from synapse.util.caches.descriptors import cached
|
||||||
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.iterutils import batch_iter
|
from synapse.util.iterutils import batch_iter
|
||||||
from synapse.util.metrics import Measure
|
from synapse.util.metrics import Measure
|
||||||
|
|
||||||
|
@ -137,20 +140,16 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
db_conn, "events", "stream_ordering", step=-1
|
db_conn, "events", "stream_ordering", step=-1
|
||||||
)
|
)
|
||||||
|
|
||||||
if not hs.config.worker.worker_app:
|
if hs.config.run_background_tasks:
|
||||||
# We periodically clean out old transaction ID mappings
|
# We periodically clean out old transaction ID mappings
|
||||||
self._clock.looping_call(
|
self._clock.looping_call(
|
||||||
run_as_background_process,
|
self._cleanup_old_transaction_ids, 5 * 60 * 1000,
|
||||||
5 * 60 * 1000,
|
|
||||||
"_cleanup_old_transaction_ids",
|
|
||||||
self._cleanup_old_transaction_ids,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._get_event_cache = DeferredCache(
|
self._get_event_cache = LruCache(
|
||||||
"*getEvent*",
|
cache_name="*getEvent*",
|
||||||
keylen=3,
|
keylen=3,
|
||||||
max_entries=hs.config.caches.event_cache_size,
|
max_size=hs.config.caches.event_cache_size,
|
||||||
apply_cache_factor_from_config=False,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._event_fetch_lock = threading.Condition()
|
self._event_fetch_lock = threading.Condition()
|
||||||
|
@ -749,7 +748,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
event=original_ev, redacted_event=redacted_event
|
event=original_ev, redacted_event=redacted_event
|
||||||
)
|
)
|
||||||
|
|
||||||
self._get_event_cache.prefill((event_id,), cache_entry)
|
self._get_event_cache.set((event_id,), cache_entry)
|
||||||
result_map[event_id] = cache_entry
|
result_map[event_id] = cache_entry
|
||||||
|
|
||||||
return result_map
|
return result_map
|
||||||
|
@ -1375,6 +1374,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
return mapping
|
return mapping
|
||||||
|
|
||||||
|
@wrap_as_background_process("_cleanup_old_transaction_ids")
|
||||||
async def _cleanup_old_transaction_ids(self):
|
async def _cleanup_old_transaction_ids(self):
|
||||||
"""Cleans out transaction id mappings older than 24hrs.
|
"""Cleans out transaction id mappings older than 24hrs.
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from synapse.api.errors import StoreError
|
from synapse.api.errors import StoreError
|
||||||
from synapse.storage._base import SQLBaseStore
|
from synapse.storage._base import SQLBaseStore
|
||||||
|
@ -72,7 +72,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_profile_displayname(
|
async def set_profile_displayname(
|
||||||
self, user_localpart: str, new_displayname: str
|
self, user_localpart: str, new_displayname: Optional[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.db_pool.simple_update_one(
|
await self.db_pool.simple_update_one(
|
||||||
table="profiles",
|
table="profiles",
|
||||||
|
@ -144,7 +144,7 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||||
|
|
||||||
async def get_remote_profile_cache_entries_that_expire(
|
async def get_remote_profile_cache_entries_that_expire(
|
||||||
self, last_checked: int
|
self, last_checked: int
|
||||||
) -> Dict[str, str]:
|
) -> List[Dict[str, str]]:
|
||||||
"""Get all users who haven't been checked since `last_checked`
|
"""Get all users who haven't been checked since `last_checked`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
|
@ -303,7 +303,7 @@ class PusherStore(PusherWorkerStore):
|
||||||
lock=False,
|
lock=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
user_has_pusher = self.get_if_user_has_pusher.cache.get(
|
user_has_pusher = self.get_if_user_has_pusher.cache.get_immediate(
|
||||||
(user_id,), None, update_metrics=False
|
(user_id,), None, update_metrics=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,6 @@ from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||||
from synapse.types import JsonDict
|
from synapse.types import JsonDict
|
||||||
from synapse.util import json_encoder
|
from synapse.util import json_encoder
|
||||||
from synapse.util.async_helpers import ObservableDeferred
|
|
||||||
from synapse.util.caches.descriptors import cached, cachedList
|
from synapse.util.caches.descriptors import cached, cachedList
|
||||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||||
|
|
||||||
|
@ -413,18 +412,10 @@ class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
|
||||||
if receipt_type != "m.read":
|
if receipt_type != "m.read":
|
||||||
return
|
return
|
||||||
|
|
||||||
# Returns either an ObservableDeferred or the raw result
|
res = self.get_users_with_read_receipts_in_room.cache.get_immediate(
|
||||||
res = self.get_users_with_read_receipts_in_room.cache.get(
|
|
||||||
room_id, None, update_metrics=False
|
room_id, None, update_metrics=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# first handle the ObservableDeferred case
|
|
||||||
if isinstance(res, ObservableDeferred):
|
|
||||||
if res.has_called():
|
|
||||||
res = res.get_result()
|
|
||||||
else:
|
|
||||||
res = None
|
|
||||||
|
|
||||||
if res and user_id in res:
|
if res and user_id in res:
|
||||||
# We'd only be adding to the set, so no point invalidating if the
|
# We'd only be adding to the set, so no point invalidating if the
|
||||||
# user is already there
|
# user is already there
|
||||||
|
|
|
@ -20,7 +20,10 @@ from synapse.api.constants import EventTypes, Membership
|
||||||
from synapse.events import EventBase
|
from synapse.events import EventBase
|
||||||
from synapse.events.snapshot import EventContext
|
from synapse.events.snapshot import EventContext
|
||||||
from synapse.metrics import LaterGauge
|
from synapse.metrics import LaterGauge
|
||||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
from synapse.metrics.background_process_metrics import (
|
||||||
|
run_as_background_process,
|
||||||
|
wrap_as_background_process,
|
||||||
|
)
|
||||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||||
from synapse.storage.database import DatabasePool
|
from synapse.storage.database import DatabasePool
|
||||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||||
|
@ -67,16 +70,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
):
|
):
|
||||||
self._known_servers_count = 1
|
self._known_servers_count = 1
|
||||||
self.hs.get_clock().looping_call(
|
self.hs.get_clock().looping_call(
|
||||||
run_as_background_process,
|
self._count_known_servers, 60 * 1000,
|
||||||
60 * 1000,
|
|
||||||
"_count_known_servers",
|
|
||||||
self._count_known_servers,
|
|
||||||
)
|
)
|
||||||
self.hs.get_clock().call_later(
|
self.hs.get_clock().call_later(
|
||||||
1000,
|
1000, self._count_known_servers,
|
||||||
run_as_background_process,
|
|
||||||
"_count_known_servers",
|
|
||||||
self._count_known_servers,
|
|
||||||
)
|
)
|
||||||
LaterGauge(
|
LaterGauge(
|
||||||
"synapse_federation_known_servers",
|
"synapse_federation_known_servers",
|
||||||
|
@ -85,6 +82,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
lambda: self._known_servers_count,
|
lambda: self._known_servers_count,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@wrap_as_background_process("_count_known_servers")
|
||||||
async def _count_known_servers(self):
|
async def _count_known_servers(self):
|
||||||
"""
|
"""
|
||||||
Count the servers that this server knows about.
|
Count the servers that this server knows about.
|
||||||
|
@ -531,7 +529,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||||
# If we do then we can reuse that result and simply update it with
|
# If we do then we can reuse that result and simply update it with
|
||||||
# any membership changes in `delta_ids`
|
# any membership changes in `delta_ids`
|
||||||
if context.prev_group and context.delta_ids:
|
if context.prev_group and context.delta_ids:
|
||||||
prev_res = self._get_joined_users_from_context.cache.get(
|
prev_res = self._get_joined_users_from_context.cache.get_immediate(
|
||||||
(room_id, context.prev_group), None
|
(room_id, context.prev_group), None
|
||||||
)
|
)
|
||||||
if prev_res and isinstance(prev_res, dict):
|
if prev_res and isinstance(prev_res, dict):
|
||||||
|
|
|
@ -13,6 +13,5 @@
|
||||||
* limitations under the License.
|
* limitations under the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
ALTER TABLE application_services_state
|
ALTER TABLE application_services_state ADD COLUMN read_receipt_stream_id INT;
|
||||||
ADD COLUMN read_receipt_stream_id INT,
|
ALTER TABLE application_services_state ADD COLUMN presence_stream_id INT;
|
||||||
ADD COLUMN presence_stream_id INT;
|
|
|
@ -0,0 +1 @@
|
||||||
|
DROP TABLE device_max_stream_id;
|
|
@ -12,7 +12,7 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
from typing import Any, Iterable, Iterator, List, Tuple
|
from typing import Any, Iterable, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
from typing_extensions import Protocol
|
from typing_extensions import Protocol
|
||||||
|
|
||||||
|
@ -65,5 +65,5 @@ class Connection(Protocol):
|
||||||
def __enter__(self) -> "Connection":
|
def __enter__(self) -> "Connection":
|
||||||
...
|
...
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback) -> bool:
|
def __exit__(self, exc_type, exc_value, traceback) -> Optional[bool]:
|
||||||
...
|
...
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from sys import intern
|
from sys import intern
|
||||||
from typing import Callable, Dict, Optional
|
from typing import Callable, Dict, Optional, Sized
|
||||||
|
|
||||||
import attr
|
import attr
|
||||||
from prometheus_client.core import Gauge
|
from prometheus_client.core import Gauge
|
||||||
|
@ -92,7 +92,7 @@ class CacheMetric:
|
||||||
def register_cache(
|
def register_cache(
|
||||||
cache_type: str,
|
cache_type: str,
|
||||||
cache_name: str,
|
cache_name: str,
|
||||||
cache,
|
cache: Sized,
|
||||||
collect_callback: Optional[Callable] = None,
|
collect_callback: Optional[Callable] = None,
|
||||||
resizable: bool = True,
|
resizable: bool = True,
|
||||||
resize_callback: Optional[Callable] = None,
|
resize_callback: Optional[Callable] = None,
|
||||||
|
@ -100,12 +100,15 @@ def register_cache(
|
||||||
"""Register a cache object for metric collection and resizing.
|
"""Register a cache object for metric collection and resizing.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
cache_type
|
cache_type: a string indicating the "type" of the cache. This is used
|
||||||
|
only for deduplication so isn't too important provided it's constant.
|
||||||
cache_name: name of the cache
|
cache_name: name of the cache
|
||||||
cache: cache itself
|
cache: cache itself, which must implement __len__(), and may optionally implement
|
||||||
|
a max_size property
|
||||||
collect_callback: If given, a function which is called during metric
|
collect_callback: If given, a function which is called during metric
|
||||||
collection to update additional metrics.
|
collection to update additional metrics.
|
||||||
resizable: Whether this cache supports being resized.
|
resizable: Whether this cache supports being resized, in which case either
|
||||||
|
resize_callback must be provided, or the cache must support set_max_size().
|
||||||
resize_callback: A function which can be called to resize the cache.
|
resize_callback: A function which can be called to resize the cache.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
|
@ -17,14 +17,23 @@
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import threading
|
import threading
|
||||||
from typing import Callable, Generic, Iterable, MutableMapping, Optional, TypeVar, cast
|
from typing import (
|
||||||
|
Callable,
|
||||||
|
Generic,
|
||||||
|
Iterable,
|
||||||
|
MutableMapping,
|
||||||
|
Optional,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
)
|
||||||
|
|
||||||
from prometheus_client import Gauge
|
from prometheus_client import Gauge
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
from twisted.python import failure
|
||||||
|
|
||||||
from synapse.util.async_helpers import ObservableDeferred
|
from synapse.util.async_helpers import ObservableDeferred
|
||||||
from synapse.util.caches import register_cache
|
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
|
||||||
|
|
||||||
|
@ -34,7 +43,7 @@ cache_pending_metric = Gauge(
|
||||||
["name"],
|
["name"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
KT = TypeVar("KT")
|
KT = TypeVar("KT")
|
||||||
VT = TypeVar("VT")
|
VT = TypeVar("VT")
|
||||||
|
|
||||||
|
@ -49,15 +58,12 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
"""Wraps an LruCache, adding support for Deferred results.
|
"""Wraps an LruCache, adding support for Deferred results.
|
||||||
|
|
||||||
It expects that each entry added with set() will be a Deferred; likewise get()
|
It expects that each entry added with set() will be a Deferred; likewise get()
|
||||||
may return an ObservableDeferred.
|
will return a Deferred.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__slots__ = (
|
__slots__ = (
|
||||||
"cache",
|
"cache",
|
||||||
"name",
|
|
||||||
"keylen",
|
|
||||||
"thread",
|
"thread",
|
||||||
"metrics",
|
|
||||||
"_pending_deferred_cache",
|
"_pending_deferred_cache",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -89,37 +95,27 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
cache_type()
|
cache_type()
|
||||||
) # type: MutableMapping[KT, CacheEntry]
|
) # type: MutableMapping[KT, CacheEntry]
|
||||||
|
|
||||||
|
def metrics_cb():
|
||||||
|
cache_pending_metric.labels(name).set(len(self._pending_deferred_cache))
|
||||||
|
|
||||||
# cache is used for completed results and maps to the result itself, rather than
|
# cache is used for completed results and maps to the result itself, rather than
|
||||||
# a Deferred.
|
# a Deferred.
|
||||||
self.cache = LruCache(
|
self.cache = LruCache(
|
||||||
max_size=max_entries,
|
max_size=max_entries,
|
||||||
keylen=keylen,
|
keylen=keylen,
|
||||||
|
cache_name=name,
|
||||||
cache_type=cache_type,
|
cache_type=cache_type,
|
||||||
size_callback=(lambda d: len(d)) if iterable else None,
|
size_callback=(lambda d: len(d)) if iterable else None,
|
||||||
evicted_callback=self._on_evicted,
|
metrics_collection_callback=metrics_cb,
|
||||||
apply_cache_factor_from_config=apply_cache_factor_from_config,
|
apply_cache_factor_from_config=apply_cache_factor_from_config,
|
||||||
)
|
) # type: LruCache[KT, VT]
|
||||||
|
|
||||||
self.name = name
|
|
||||||
self.keylen = keylen
|
|
||||||
self.thread = None # type: Optional[threading.Thread]
|
self.thread = None # type: Optional[threading.Thread]
|
||||||
self.metrics = register_cache(
|
|
||||||
"cache",
|
|
||||||
name,
|
|
||||||
self.cache,
|
|
||||||
collect_callback=self._metrics_collection_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_entries(self):
|
def max_entries(self):
|
||||||
return self.cache.max_size
|
return self.cache.max_size
|
||||||
|
|
||||||
def _on_evicted(self, evicted_count):
|
|
||||||
self.metrics.inc_evictions(evicted_count)
|
|
||||||
|
|
||||||
def _metrics_collection_callback(self):
|
|
||||||
cache_pending_metric.labels(self.name).set(len(self._pending_deferred_cache))
|
|
||||||
|
|
||||||
def check_thread(self):
|
def check_thread(self):
|
||||||
expected_thread = self.thread
|
expected_thread = self.thread
|
||||||
if expected_thread is None:
|
if expected_thread is None:
|
||||||
|
@ -133,62 +129,113 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
def get(
|
def get(
|
||||||
self,
|
self,
|
||||||
key: KT,
|
key: KT,
|
||||||
default=_Sentinel.sentinel,
|
|
||||||
callback: Optional[Callable[[], None]] = None,
|
callback: Optional[Callable[[], None]] = None,
|
||||||
update_metrics: bool = True,
|
update_metrics: bool = True,
|
||||||
):
|
) -> defer.Deferred:
|
||||||
"""Looks the key up in the caches.
|
"""Looks the key up in the caches.
|
||||||
|
|
||||||
|
For symmetry with set(), this method does *not* follow the synapse logcontext
|
||||||
|
rules: the logcontext will not be cleared on return, and the Deferred will run
|
||||||
|
its callbacks in the sentinel context. In other words: wrap the result with
|
||||||
|
make_deferred_yieldable() before `await`ing it.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key(tuple)
|
key:
|
||||||
default: What is returned if key is not in the caches. If not
|
callback: Gets called when the entry in the cache is invalidated
|
||||||
specified then function throws KeyError instead
|
|
||||||
callback(fn): Gets called when the entry in the cache is invalidated
|
|
||||||
update_metrics (bool): whether to update the cache hit rate metrics
|
update_metrics (bool): whether to update the cache hit rate metrics
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Either an ObservableDeferred or the result itself
|
A Deferred which completes with the result. Note that this may later fail
|
||||||
|
if there is an ongoing set() operation which later completes with a failure.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
KeyError if the key is not found in the cache
|
||||||
"""
|
"""
|
||||||
callbacks = [callback] if callback else []
|
callbacks = [callback] if callback else []
|
||||||
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
|
val = self._pending_deferred_cache.get(key, _Sentinel.sentinel)
|
||||||
if val is not _Sentinel.sentinel:
|
if val is not _Sentinel.sentinel:
|
||||||
val.callbacks.update(callbacks)
|
val.callbacks.update(callbacks)
|
||||||
if update_metrics:
|
if update_metrics:
|
||||||
self.metrics.inc_hits()
|
m = self.cache.metrics
|
||||||
return val.deferred
|
assert m # we always have a name, so should always have metrics
|
||||||
|
m.inc_hits()
|
||||||
|
return val.deferred.observe()
|
||||||
|
|
||||||
val = self.cache.get(key, _Sentinel.sentinel, callbacks=callbacks)
|
val2 = self.cache.get(
|
||||||
if val is not _Sentinel.sentinel:
|
key, _Sentinel.sentinel, callbacks=callbacks, update_metrics=update_metrics
|
||||||
self.metrics.inc_hits()
|
)
|
||||||
return val
|
if val2 is _Sentinel.sentinel:
|
||||||
|
|
||||||
if update_metrics:
|
|
||||||
self.metrics.inc_misses()
|
|
||||||
|
|
||||||
if default is _Sentinel.sentinel:
|
|
||||||
raise KeyError()
|
raise KeyError()
|
||||||
else:
|
else:
|
||||||
return default
|
return defer.succeed(val2)
|
||||||
|
|
||||||
|
def get_immediate(
|
||||||
|
self, key: KT, default: T, update_metrics: bool = True
|
||||||
|
) -> Union[VT, T]:
|
||||||
|
"""If we have a *completed* cached value, return it."""
|
||||||
|
return self.cache.get(key, default, update_metrics=update_metrics)
|
||||||
|
|
||||||
def set(
|
def set(
|
||||||
self,
|
self,
|
||||||
key: KT,
|
key: KT,
|
||||||
value: defer.Deferred,
|
value: defer.Deferred,
|
||||||
callback: Optional[Callable[[], None]] = None,
|
callback: Optional[Callable[[], None]] = None,
|
||||||
) -> ObservableDeferred:
|
) -> defer.Deferred:
|
||||||
|
"""Adds a new entry to the cache (or updates an existing one).
|
||||||
|
|
||||||
|
The given `value` *must* be a Deferred.
|
||||||
|
|
||||||
|
First any existing entry for the same key is invalidated. Then a new entry
|
||||||
|
is added to the cache for the given key.
|
||||||
|
|
||||||
|
Until the `value` completes, calls to `get()` for the key will also result in an
|
||||||
|
incomplete Deferred, which will ultimately complete with the same result as
|
||||||
|
`value`.
|
||||||
|
|
||||||
|
If `value` completes successfully, subsequent calls to `get()` will then return
|
||||||
|
a completed deferred with the same result. If it *fails*, the cache is
|
||||||
|
invalidated and subequent calls to `get()` will raise a KeyError.
|
||||||
|
|
||||||
|
If another call to `set()` happens before `value` completes, then (a) any
|
||||||
|
invalidation callbacks registered in the interim will be called, (b) any
|
||||||
|
`get()`s in the interim will continue to complete with the result from the
|
||||||
|
*original* `value`, (c) any future calls to `get()` will complete with the
|
||||||
|
result from the *new* `value`.
|
||||||
|
|
||||||
|
It is expected that `value` does *not* follow the synapse logcontext rules - ie,
|
||||||
|
if it is incomplete, it runs its callbacks in the sentinel context.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
key: Key to be set
|
||||||
|
value: a deferred which will complete with a result to add to the cache
|
||||||
|
callback: An optional callback to be called when the entry is invalidated
|
||||||
|
"""
|
||||||
if not isinstance(value, defer.Deferred):
|
if not isinstance(value, defer.Deferred):
|
||||||
raise TypeError("not a Deferred")
|
raise TypeError("not a Deferred")
|
||||||
|
|
||||||
callbacks = [callback] if callback else []
|
callbacks = [callback] if callback else []
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
observable = ObservableDeferred(value, consumeErrors=True)
|
|
||||||
observer = observable.observe()
|
|
||||||
entry = CacheEntry(deferred=observable, callbacks=callbacks)
|
|
||||||
|
|
||||||
existing_entry = self._pending_deferred_cache.pop(key, None)
|
existing_entry = self._pending_deferred_cache.pop(key, None)
|
||||||
if existing_entry:
|
if existing_entry:
|
||||||
existing_entry.invalidate()
|
existing_entry.invalidate()
|
||||||
|
|
||||||
|
# XXX: why don't we invalidate the entry in `self.cache` yet?
|
||||||
|
|
||||||
|
# we can save a whole load of effort if the deferred is ready.
|
||||||
|
if value.called:
|
||||||
|
result = value.result
|
||||||
|
if not isinstance(result, failure.Failure):
|
||||||
|
self.cache.set(key, result, callbacks)
|
||||||
|
return value
|
||||||
|
|
||||||
|
# otherwise, we'll add an entry to the _pending_deferred_cache for now,
|
||||||
|
# and add callbacks to add it to the cache properly later.
|
||||||
|
|
||||||
|
observable = ObservableDeferred(value, consumeErrors=True)
|
||||||
|
observer = observable.observe()
|
||||||
|
entry = CacheEntry(deferred=observable, callbacks=callbacks)
|
||||||
|
|
||||||
self._pending_deferred_cache[key] = entry
|
self._pending_deferred_cache[key] = entry
|
||||||
|
|
||||||
def compare_and_pop():
|
def compare_and_pop():
|
||||||
|
@ -232,7 +279,9 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
# _pending_deferred_cache to the real cache.
|
# _pending_deferred_cache to the real cache.
|
||||||
#
|
#
|
||||||
observer.addCallbacks(cb, eb)
|
observer.addCallbacks(cb, eb)
|
||||||
return observable
|
|
||||||
|
# we return a new Deferred which will be called before any subsequent observers.
|
||||||
|
return observable.observe()
|
||||||
|
|
||||||
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
|
def prefill(self, key: KT, value: VT, callback: Callable[[], None] = None):
|
||||||
callbacks = [callback] if callback else []
|
callbacks = [callback] if callback else []
|
||||||
|
@ -257,11 +306,12 @@ class DeferredCache(Generic[KT, VT]):
|
||||||
self.check_thread()
|
self.check_thread()
|
||||||
if not isinstance(key, tuple):
|
if not isinstance(key, tuple):
|
||||||
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
|
raise TypeError("The cache key must be a tuple not %r" % (type(key),))
|
||||||
|
key = cast(KT, key)
|
||||||
self.cache.del_multi(key)
|
self.cache.del_multi(key)
|
||||||
|
|
||||||
# if we have a pending lookup for this key, remove it from the
|
# if we have a pending lookup for this key, remove it from the
|
||||||
# _pending_deferred_cache, as above
|
# _pending_deferred_cache, as above
|
||||||
entry_dict = self._pending_deferred_cache.pop(cast(KT, key), None)
|
entry_dict = self._pending_deferred_cache.pop(key, None)
|
||||||
if entry_dict is not None:
|
if entry_dict is not None:
|
||||||
for entry in iterate_tree_cache_entry(entry_dict):
|
for entry in iterate_tree_cache_entry(entry_dict):
|
||||||
entry.invalidate()
|
entry.invalidate()
|
||||||
|
|
|
@ -23,7 +23,6 @@ from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
from synapse.logging.context import make_deferred_yieldable, preserve_fn
|
||||||
from synapse.util import unwrapFirstError
|
from synapse.util import unwrapFirstError
|
||||||
from synapse.util.async_helpers import ObservableDeferred
|
|
||||||
from synapse.util.caches.deferred_cache import DeferredCache
|
from synapse.util.caches.deferred_cache import DeferredCache
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
@ -156,7 +155,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
keylen=self.num_args,
|
keylen=self.num_args,
|
||||||
tree=self.tree,
|
tree=self.tree,
|
||||||
iterable=self.iterable,
|
iterable=self.iterable,
|
||||||
) # type: DeferredCache[Tuple, Any]
|
) # type: DeferredCache[CacheKey, Any]
|
||||||
|
|
||||||
def get_cache_key_gen(args, kwargs):
|
def get_cache_key_gen(args, kwargs):
|
||||||
"""Given some args/kwargs return a generator that resolves into
|
"""Given some args/kwargs return a generator that resolves into
|
||||||
|
@ -202,32 +201,20 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||||
|
|
||||||
cache_key = get_cache_key(args, kwargs)
|
cache_key = get_cache_key(args, kwargs)
|
||||||
|
|
||||||
# Add our own `cache_context` to argument list if the wrapped function
|
|
||||||
# has asked for one
|
|
||||||
if self.add_cache_context:
|
|
||||||
kwargs["cache_context"] = _CacheContext.get_instance(cache, cache_key)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
cached_result_d = cache.get(cache_key, callback=invalidate_callback)
|
ret = cache.get(cache_key, callback=invalidate_callback)
|
||||||
|
|
||||||
if isinstance(cached_result_d, ObservableDeferred):
|
|
||||||
observer = cached_result_d.observe()
|
|
||||||
else:
|
|
||||||
observer = defer.succeed(cached_result_d)
|
|
||||||
|
|
||||||
except KeyError:
|
except KeyError:
|
||||||
|
# Add our own `cache_context` to argument list if the wrapped function
|
||||||
|
# has asked for one
|
||||||
|
if self.add_cache_context:
|
||||||
|
kwargs["cache_context"] = _CacheContext.get_instance(
|
||||||
|
cache, cache_key
|
||||||
|
)
|
||||||
|
|
||||||
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
|
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
|
||||||
|
ret = cache.set(cache_key, ret, callback=invalidate_callback)
|
||||||
|
|
||||||
def onErr(f):
|
return make_deferred_yieldable(ret)
|
||||||
cache.invalidate(cache_key)
|
|
||||||
return f
|
|
||||||
|
|
||||||
ret.addErrback(onErr)
|
|
||||||
|
|
||||||
result_d = cache.set(cache_key, ret, callback=invalidate_callback)
|
|
||||||
observer = result_d.observe()
|
|
||||||
|
|
||||||
return make_deferred_yieldable(observer)
|
|
||||||
|
|
||||||
wrapped = cast(_CachedFunction, _wrapped)
|
wrapped = cast(_CachedFunction, _wrapped)
|
||||||
|
|
||||||
|
@ -286,7 +273,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||||
|
|
||||||
def __get__(self, obj, objtype=None):
|
def __get__(self, obj, objtype=None):
|
||||||
cached_method = getattr(obj, self.cached_method_name)
|
cached_method = getattr(obj, self.cached_method_name)
|
||||||
cache = cached_method.cache
|
cache = cached_method.cache # type: DeferredCache[CacheKey, Any]
|
||||||
num_args = cached_method.num_args
|
num_args = cached_method.num_args
|
||||||
|
|
||||||
@functools.wraps(self.orig)
|
@functools.wraps(self.orig)
|
||||||
|
@ -326,14 +313,11 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||||
for arg in list_args:
|
for arg in list_args:
|
||||||
try:
|
try:
|
||||||
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
|
res = cache.get(arg_to_cache_key(arg), callback=invalidate_callback)
|
||||||
if not isinstance(res, ObservableDeferred):
|
if not res.called:
|
||||||
results[arg] = res
|
|
||||||
elif not res.has_succeeded():
|
|
||||||
res = res.observe()
|
|
||||||
res.addCallback(update_results_dict, arg)
|
res.addCallback(update_results_dict, arg)
|
||||||
cached_defers.append(res)
|
cached_defers.append(res)
|
||||||
else:
|
else:
|
||||||
results[arg] = res.get_result()
|
results[arg] = res.result
|
||||||
except KeyError:
|
except KeyError:
|
||||||
missing.add(arg)
|
missing.add(arg)
|
||||||
|
|
||||||
|
|
|
@ -12,15 +12,14 @@
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
import enum
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
|
|
||||||
from . import register_cache
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@ -40,24 +39,25 @@ class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "va
|
||||||
return len(self.value)
|
return len(self.value)
|
||||||
|
|
||||||
|
|
||||||
|
class _Sentinel(enum.Enum):
|
||||||
|
# defining a sentinel in this way allows mypy to correctly handle the
|
||||||
|
# type of a dictionary lookup.
|
||||||
|
sentinel = object()
|
||||||
|
|
||||||
|
|
||||||
class DictionaryCache:
|
class DictionaryCache:
|
||||||
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
|
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
|
||||||
fetching a subset of dictionary keys for a particular key.
|
fetching a subset of dictionary keys for a particular key.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, name, max_entries=1000):
|
def __init__(self, name, max_entries=1000):
|
||||||
self.cache = LruCache(max_size=max_entries, size_callback=len)
|
self.cache = LruCache(
|
||||||
|
max_size=max_entries, cache_name=name, size_callback=len
|
||||||
|
) # type: LruCache[Any, DictionaryEntry]
|
||||||
|
|
||||||
self.name = name
|
self.name = name
|
||||||
self.sequence = 0
|
self.sequence = 0
|
||||||
self.thread = None
|
self.thread = None
|
||||||
# caches_by_name[name] = self.cache
|
|
||||||
|
|
||||||
class Sentinel:
|
|
||||||
__slots__ = []
|
|
||||||
|
|
||||||
self.sentinel = Sentinel()
|
|
||||||
self.metrics = register_cache("dictionary", name, self.cache)
|
|
||||||
|
|
||||||
def check_thread(self):
|
def check_thread(self):
|
||||||
expected_thread = self.thread
|
expected_thread = self.thread
|
||||||
|
@ -80,10 +80,8 @@ class DictionaryCache:
|
||||||
Returns:
|
Returns:
|
||||||
DictionaryEntry
|
DictionaryEntry
|
||||||
"""
|
"""
|
||||||
entry = self.cache.get(key, self.sentinel)
|
entry = self.cache.get(key, _Sentinel.sentinel)
|
||||||
if entry is not self.sentinel:
|
if entry is not _Sentinel.sentinel:
|
||||||
self.metrics.inc_hits()
|
|
||||||
|
|
||||||
if dict_keys is None:
|
if dict_keys is None:
|
||||||
return DictionaryEntry(
|
return DictionaryEntry(
|
||||||
entry.full, entry.known_absent, dict(entry.value)
|
entry.full, entry.known_absent, dict(entry.value)
|
||||||
|
@ -95,7 +93,6 @@ class DictionaryCache:
|
||||||
{k: entry.value[k] for k in dict_keys if k in entry.value},
|
{k: entry.value[k] for k in dict_keys if k in entry.value},
|
||||||
)
|
)
|
||||||
|
|
||||||
self.metrics.inc_misses()
|
|
||||||
return DictionaryEntry(False, set(), {})
|
return DictionaryEntry(False, set(), {})
|
||||||
|
|
||||||
def invalidate(self, key):
|
def invalidate(self, key):
|
||||||
|
|
|
@ -15,11 +15,35 @@
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Callable, Optional, Type, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Generic,
|
||||||
|
Iterable,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
cast,
|
||||||
|
overload,
|
||||||
|
)
|
||||||
|
|
||||||
|
from typing_extensions import Literal
|
||||||
|
|
||||||
from synapse.config import cache as cache_config
|
from synapse.config import cache as cache_config
|
||||||
|
from synapse.util.caches import CacheMetric, register_cache
|
||||||
from synapse.util.caches.treecache import TreeCache
|
from synapse.util.caches.treecache import TreeCache
|
||||||
|
|
||||||
|
# Function type: the type used for invalidation callbacks
|
||||||
|
FT = TypeVar("FT", bound=Callable[..., Any])
|
||||||
|
|
||||||
|
# Key and Value type for the cache
|
||||||
|
KT = TypeVar("KT")
|
||||||
|
VT = TypeVar("VT")
|
||||||
|
|
||||||
|
# a general type var, distinct from either KT or VT
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
def enumerate_leaves(node, depth):
|
def enumerate_leaves(node, depth):
|
||||||
if depth == 0:
|
if depth == 0:
|
||||||
|
@ -41,29 +65,31 @@ class _Node:
|
||||||
self.callbacks = callbacks
|
self.callbacks = callbacks
|
||||||
|
|
||||||
|
|
||||||
class LruCache:
|
class LruCache(Generic[KT, VT]):
|
||||||
"""
|
"""
|
||||||
Least-recently-used cache.
|
Least-recently-used cache, supporting prometheus metrics and invalidation callbacks.
|
||||||
|
|
||||||
Supports del_multi only if cache_type=TreeCache
|
Supports del_multi only if cache_type=TreeCache
|
||||||
If cache_type=TreeCache, all keys must be tuples.
|
If cache_type=TreeCache, all keys must be tuples.
|
||||||
|
|
||||||
Can also set callbacks on objects when getting/setting which are fired
|
|
||||||
when that key gets invalidated/evicted.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
max_size: int,
|
max_size: int,
|
||||||
|
cache_name: Optional[str] = None,
|
||||||
keylen: int = 1,
|
keylen: int = 1,
|
||||||
cache_type: Type[Union[dict, TreeCache]] = dict,
|
cache_type: Type[Union[dict, TreeCache]] = dict,
|
||||||
size_callback: Optional[Callable] = None,
|
size_callback: Optional[Callable] = None,
|
||||||
evicted_callback: Optional[Callable] = None,
|
metrics_collection_callback: Optional[Callable[[], None]] = None,
|
||||||
apply_cache_factor_from_config: bool = True,
|
apply_cache_factor_from_config: bool = True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
max_size: The maximum amount of entries the cache can hold
|
max_size: The maximum amount of entries the cache can hold
|
||||||
|
|
||||||
|
cache_name: The name of this cache, for the prometheus metrics. If unset,
|
||||||
|
no metrics will be reported on this cache.
|
||||||
|
|
||||||
keylen: The length of the tuple used as the cache key. Ignored unless
|
keylen: The length of the tuple used as the cache key. Ignored unless
|
||||||
cache_type is `TreeCache`.
|
cache_type is `TreeCache`.
|
||||||
|
|
||||||
|
@ -73,9 +99,13 @@ class LruCache:
|
||||||
|
|
||||||
size_callback (func(V) -> int | None):
|
size_callback (func(V) -> int | None):
|
||||||
|
|
||||||
evicted_callback (func(int)|None):
|
metrics_collection_callback:
|
||||||
if not None, called on eviction with the size of the evicted
|
metrics collection callback. This is called early in the metrics
|
||||||
entry
|
collection process, before any of the metrics registered with the
|
||||||
|
prometheus Registry are collected, so can be used to update any dynamic
|
||||||
|
metrics.
|
||||||
|
|
||||||
|
Ignored if cache_name is None.
|
||||||
|
|
||||||
apply_cache_factor_from_config (bool): If true, `max_size` will be
|
apply_cache_factor_from_config (bool): If true, `max_size` will be
|
||||||
multiplied by a cache factor derived from the homeserver config
|
multiplied by a cache factor derived from the homeserver config
|
||||||
|
@ -94,6 +124,23 @@ class LruCache:
|
||||||
else:
|
else:
|
||||||
self.max_size = int(max_size)
|
self.max_size = int(max_size)
|
||||||
|
|
||||||
|
# register_cache might call our "set_cache_factor" callback; there's nothing to
|
||||||
|
# do yet when we get resized.
|
||||||
|
self._on_resize = None # type: Optional[Callable[[],None]]
|
||||||
|
|
||||||
|
if cache_name is not None:
|
||||||
|
metrics = register_cache(
|
||||||
|
"lru_cache",
|
||||||
|
cache_name,
|
||||||
|
self,
|
||||||
|
collect_callback=metrics_collection_callback,
|
||||||
|
) # type: Optional[CacheMetric]
|
||||||
|
else:
|
||||||
|
metrics = None
|
||||||
|
|
||||||
|
# this is exposed for access from outside this class
|
||||||
|
self.metrics = metrics
|
||||||
|
|
||||||
list_root = _Node(None, None, None, None)
|
list_root = _Node(None, None, None, None)
|
||||||
list_root.next_node = list_root
|
list_root.next_node = list_root
|
||||||
list_root.prev_node = list_root
|
list_root.prev_node = list_root
|
||||||
|
@ -105,16 +152,16 @@ class LruCache:
|
||||||
todelete = list_root.prev_node
|
todelete = list_root.prev_node
|
||||||
evicted_len = delete_node(todelete)
|
evicted_len = delete_node(todelete)
|
||||||
cache.pop(todelete.key, None)
|
cache.pop(todelete.key, None)
|
||||||
if evicted_callback:
|
if metrics:
|
||||||
evicted_callback(evicted_len)
|
metrics.inc_evictions(evicted_len)
|
||||||
|
|
||||||
def synchronized(f):
|
def synchronized(f: FT) -> FT:
|
||||||
@wraps(f)
|
@wraps(f)
|
||||||
def inner(*args, **kwargs):
|
def inner(*args, **kwargs):
|
||||||
with lock:
|
with lock:
|
||||||
return f(*args, **kwargs)
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
return inner
|
return cast(FT, inner)
|
||||||
|
|
||||||
cached_cache_len = [0]
|
cached_cache_len = [0]
|
||||||
if size_callback is not None:
|
if size_callback is not None:
|
||||||
|
@ -168,18 +215,45 @@ class LruCache:
|
||||||
node.callbacks.clear()
|
node.callbacks.clear()
|
||||||
return deleted_len
|
return deleted_len
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def cache_get(
|
||||||
|
key: KT,
|
||||||
|
default: Literal[None] = None,
|
||||||
|
callbacks: Iterable[Callable[[], None]] = ...,
|
||||||
|
update_metrics: bool = ...,
|
||||||
|
) -> Optional[VT]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def cache_get(
|
||||||
|
key: KT,
|
||||||
|
default: T,
|
||||||
|
callbacks: Iterable[Callable[[], None]] = ...,
|
||||||
|
update_metrics: bool = ...,
|
||||||
|
) -> Union[T, VT]:
|
||||||
|
...
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_get(key, default=None, callbacks=[]):
|
def cache_get(
|
||||||
|
key: KT,
|
||||||
|
default: Optional[T] = None,
|
||||||
|
callbacks: Iterable[Callable[[], None]] = [],
|
||||||
|
update_metrics: bool = True,
|
||||||
|
):
|
||||||
node = cache.get(key, None)
|
node = cache.get(key, None)
|
||||||
if node is not None:
|
if node is not None:
|
||||||
move_node_to_front(node)
|
move_node_to_front(node)
|
||||||
node.callbacks.update(callbacks)
|
node.callbacks.update(callbacks)
|
||||||
|
if update_metrics and metrics:
|
||||||
|
metrics.inc_hits()
|
||||||
return node.value
|
return node.value
|
||||||
else:
|
else:
|
||||||
|
if update_metrics and metrics:
|
||||||
|
metrics.inc_misses()
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_set(key, value, callbacks=[]):
|
def cache_set(key: KT, value: VT, callbacks: Iterable[Callable[[], None]] = []):
|
||||||
node = cache.get(key, None)
|
node = cache.get(key, None)
|
||||||
if node is not None:
|
if node is not None:
|
||||||
# We sometimes store large objects, e.g. dicts, which cause
|
# We sometimes store large objects, e.g. dicts, which cause
|
||||||
|
@ -208,7 +282,7 @@ class LruCache:
|
||||||
evict()
|
evict()
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_set_default(key, value):
|
def cache_set_default(key: KT, value: VT) -> VT:
|
||||||
node = cache.get(key, None)
|
node = cache.get(key, None)
|
||||||
if node is not None:
|
if node is not None:
|
||||||
return node.value
|
return node.value
|
||||||
|
@ -217,8 +291,16 @@ class LruCache:
|
||||||
evict()
|
evict()
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def cache_pop(key: KT, default: Literal[None] = None) -> Optional[VT]:
|
||||||
|
...
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def cache_pop(key: KT, default: T) -> Union[T, VT]:
|
||||||
|
...
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_pop(key, default=None):
|
def cache_pop(key: KT, default: Optional[T] = None):
|
||||||
node = cache.get(key, None)
|
node = cache.get(key, None)
|
||||||
if node:
|
if node:
|
||||||
delete_node(node)
|
delete_node(node)
|
||||||
|
@ -228,18 +310,18 @@ class LruCache:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_del_multi(key):
|
def cache_del_multi(key: KT) -> None:
|
||||||
"""
|
"""
|
||||||
This will only work if constructed with cache_type=TreeCache
|
This will only work if constructed with cache_type=TreeCache
|
||||||
"""
|
"""
|
||||||
popped = cache.pop(key)
|
popped = cache.pop(key)
|
||||||
if popped is None:
|
if popped is None:
|
||||||
return
|
return
|
||||||
for leaf in enumerate_leaves(popped, keylen - len(key)):
|
for leaf in enumerate_leaves(popped, keylen - len(cast(tuple, key))):
|
||||||
delete_node(leaf)
|
delete_node(leaf)
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_clear():
|
def cache_clear() -> None:
|
||||||
list_root.next_node = list_root
|
list_root.next_node = list_root
|
||||||
list_root.prev_node = list_root
|
list_root.prev_node = list_root
|
||||||
for node in cache.values():
|
for node in cache.values():
|
||||||
|
@ -250,15 +332,21 @@ class LruCache:
|
||||||
cached_cache_len[0] = 0
|
cached_cache_len[0] = 0
|
||||||
|
|
||||||
@synchronized
|
@synchronized
|
||||||
def cache_contains(key):
|
def cache_contains(key: KT) -> bool:
|
||||||
return key in cache
|
return key in cache
|
||||||
|
|
||||||
self.sentinel = object()
|
self.sentinel = object()
|
||||||
|
|
||||||
|
# make sure that we clear out any excess entries after we get resized.
|
||||||
self._on_resize = evict
|
self._on_resize = evict
|
||||||
|
|
||||||
self.get = cache_get
|
self.get = cache_get
|
||||||
self.set = cache_set
|
self.set = cache_set
|
||||||
self.setdefault = cache_set_default
|
self.setdefault = cache_set_default
|
||||||
self.pop = cache_pop
|
self.pop = cache_pop
|
||||||
|
# `invalidate` is exposed for consistency with DeferredCache, so that it can be
|
||||||
|
# invalidated by the cache invalidation replication stream.
|
||||||
|
self.invalidate = cache_pop
|
||||||
if cache_type is TreeCache:
|
if cache_type is TreeCache:
|
||||||
self.del_multi = cache_del_multi
|
self.del_multi = cache_del_multi
|
||||||
self.len = synchronized(cache_len)
|
self.len = synchronized(cache_len)
|
||||||
|
@ -302,6 +390,7 @@ class LruCache:
|
||||||
new_size = int(self._original_max_size * factor)
|
new_size = int(self._original_max_size * factor)
|
||||||
if new_size != self.max_size:
|
if new_size != self.max_size:
|
||||||
self.max_size = new_size
|
self.max_size = new_size
|
||||||
self._on_resize()
|
if self._on_resize:
|
||||||
|
self._on_resize()
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
|
@ -15,7 +15,10 @@
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
from twisted.internet import epollreactor
|
try:
|
||||||
|
from twisted.internet.epollreactor import EPollReactor as Reactor
|
||||||
|
except ImportError:
|
||||||
|
from twisted.internet.pollreactor import PollReactor as Reactor
|
||||||
from twisted.internet.main import installReactor
|
from twisted.internet.main import installReactor
|
||||||
|
|
||||||
from synapse.config.homeserver import HomeServerConfig
|
from synapse.config.homeserver import HomeServerConfig
|
||||||
|
@ -41,7 +44,7 @@ async def make_homeserver(reactor, config=None):
|
||||||
config_obj = HomeServerConfig()
|
config_obj = HomeServerConfig()
|
||||||
config_obj.parse_config_dict(config, "", "")
|
config_obj.parse_config_dict(config, "", "")
|
||||||
|
|
||||||
hs = await setup_test_homeserver(
|
hs = setup_test_homeserver(
|
||||||
cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
|
cleanup_tasks.append, config=config_obj, reactor=reactor, clock=clock
|
||||||
)
|
)
|
||||||
stor = hs.get_datastore()
|
stor = hs.get_datastore()
|
||||||
|
@ -63,7 +66,7 @@ def make_reactor():
|
||||||
Instantiate and install a Twisted reactor suitable for testing (i.e. not the
|
Instantiate and install a Twisted reactor suitable for testing (i.e. not the
|
||||||
default global one).
|
default global one).
|
||||||
"""
|
"""
|
||||||
reactor = epollreactor.EPollReactor()
|
reactor = Reactor()
|
||||||
|
|
||||||
if "twisted.internet.reactor" in sys.modules:
|
if "twisted.internet.reactor" in sys.modules:
|
||||||
del sys.modules["twisted.internet.reactor"]
|
del sys.modules["twisted.internet.reactor"]
|
||||||
|
|
|
@ -260,6 +260,31 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
||||||
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
|
self.txn_ctrl.send.assert_called_with(srv2, [srv_2_event2], [])
|
||||||
self.assertEquals(3, self.txn_ctrl.send.call_count)
|
self.assertEquals(3, self.txn_ctrl.send.call_count)
|
||||||
|
|
||||||
|
def test_send_large_txns(self):
|
||||||
|
srv_1_defer = defer.Deferred()
|
||||||
|
srv_2_defer = defer.Deferred()
|
||||||
|
send_return_list = [srv_1_defer, srv_2_defer]
|
||||||
|
|
||||||
|
def do_send(x, y, z):
|
||||||
|
return make_deferred_yieldable(send_return_list.pop(0))
|
||||||
|
|
||||||
|
self.txn_ctrl.send = Mock(side_effect=do_send)
|
||||||
|
|
||||||
|
service = Mock(id=4, name="service")
|
||||||
|
event_list = [Mock(name="event%i" % (i + 1)) for i in range(200)]
|
||||||
|
for event in event_list:
|
||||||
|
self.queuer.enqueue_event(service, event)
|
||||||
|
|
||||||
|
# Expect the first event to be sent immediately.
|
||||||
|
self.txn_ctrl.send.assert_called_with(service, [event_list[0]], [])
|
||||||
|
srv_1_defer.callback(service)
|
||||||
|
# Then send the next 100 events
|
||||||
|
self.txn_ctrl.send.assert_called_with(service, event_list[1:101], [])
|
||||||
|
srv_2_defer.callback(service)
|
||||||
|
# Then the final 99 events
|
||||||
|
self.txn_ctrl.send.assert_called_with(service, event_list[101:], [])
|
||||||
|
self.assertEquals(3, self.txn_ctrl.send.call_count)
|
||||||
|
|
||||||
def test_send_single_ephemeral_no_queue(self):
|
def test_send_single_ephemeral_no_queue(self):
|
||||||
# Expect the event to be sent immediately.
|
# Expect the event to be sent immediately.
|
||||||
service = Mock(id=4, name="service")
|
service = Mock(id=4, name="service")
|
||||||
|
@ -296,3 +321,19 @@ class ApplicationServiceSchedulerQueuerTestCase(unittest.TestCase):
|
||||||
# Expect the queued events to be sent
|
# Expect the queued events to be sent
|
||||||
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
|
self.txn_ctrl.send.assert_called_with(service, [], event_list_2 + event_list_3)
|
||||||
self.assertEquals(2, self.txn_ctrl.send.call_count)
|
self.assertEquals(2, self.txn_ctrl.send.call_count)
|
||||||
|
|
||||||
|
def test_send_large_txns_ephemeral(self):
|
||||||
|
d = defer.Deferred()
|
||||||
|
self.txn_ctrl.send = Mock(
|
||||||
|
side_effect=lambda x, y, z: make_deferred_yieldable(d)
|
||||||
|
)
|
||||||
|
# Expect the event to be sent immediately.
|
||||||
|
service = Mock(id=4, name="service")
|
||||||
|
first_chunk = [Mock(name="event%i" % (i + 1)) for i in range(100)]
|
||||||
|
second_chunk = [Mock(name="event%i" % (i + 101)) for i in range(50)]
|
||||||
|
event_list = first_chunk + second_chunk
|
||||||
|
self.queuer.enqueue_ephemeral(service, event_list)
|
||||||
|
self.txn_ctrl.send.assert_called_once_with(service, [], first_chunk)
|
||||||
|
d.callback(service)
|
||||||
|
self.txn_ctrl.send.assert_called_with(service, [], second_chunk)
|
||||||
|
self.assertEquals(2, self.txn_ctrl.send.call_count)
|
||||||
|
|
|
@ -78,7 +78,7 @@ class TerseJSONTCPTestCase(StructuredLoggingTestBase, HomeserverTestCase):
|
||||||
"server_name",
|
"server_name",
|
||||||
"name",
|
"name",
|
||||||
]
|
]
|
||||||
self.assertEqual(set(log.keys()), set(expected_log_keys))
|
self.assertCountEqual(log.keys(), expected_log_keys)
|
||||||
|
|
||||||
# It contains the data we expect.
|
# It contains the data we expect.
|
||||||
self.assertEqual(log["name"], "wally")
|
self.assertEqual(log["name"], "wally")
|
||||||
|
|
|
@ -158,8 +158,21 @@ class EmailPusherTests(HomeserverTestCase):
|
||||||
# We should get emailed about those messages
|
# We should get emailed about those messages
|
||||||
self._check_for_mail()
|
self._check_for_mail()
|
||||||
|
|
||||||
|
def test_encrypted_message(self):
|
||||||
|
room = self.helper.create_room_as(self.user_id, tok=self.access_token)
|
||||||
|
self.helper.invite(
|
||||||
|
room=room, src=self.user_id, tok=self.access_token, targ=self.others[0].id
|
||||||
|
)
|
||||||
|
self.helper.join(room=room, user=self.others[0].id, tok=self.others[0].token)
|
||||||
|
|
||||||
|
# The other user sends some messages
|
||||||
|
self.helper.send_event(room, "m.room.encrypted", {}, tok=self.others[0].token)
|
||||||
|
|
||||||
|
# We should get emailed about that message
|
||||||
|
self._check_for_mail()
|
||||||
|
|
||||||
def _check_for_mail(self):
|
def _check_for_mail(self):
|
||||||
"Check that the user receives an email notification"
|
"""Check that the user receives an email notification"""
|
||||||
|
|
||||||
# Get the stream ordering before it gets sent
|
# Get the stream ordering before it gets sent
|
||||||
pushers = self.get_success(
|
pushers = self.get_success(
|
||||||
|
|
|
@ -352,7 +352,6 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
|
||||||
self.render(request)
|
self.render(request)
|
||||||
self.assertEqual(request.code, 401)
|
self.assertEqual(request.code, 401)
|
||||||
|
|
||||||
@unittest.INFO
|
|
||||||
def test_pending_invites(self):
|
def test_pending_invites(self):
|
||||||
"""Tests that deactivating a user rejects every pending invite for them."""
|
"""Tests that deactivating a user rejects every pending invite for them."""
|
||||||
store = self.hs.get_datastore()
|
store = self.hs.get_datastore()
|
||||||
|
|
|
@ -104,7 +104,6 @@ class FallbackAuthTests(unittest.HomeserverTestCase):
|
||||||
self.assertEqual(len(attempts), 1)
|
self.assertEqual(len(attempts), 1)
|
||||||
self.assertEqual(attempts[0][0]["response"], "a")
|
self.assertEqual(attempts[0][0]["response"], "a")
|
||||||
|
|
||||||
@unittest.INFO
|
|
||||||
def test_fallback_captcha(self):
|
def test_fallback_captcha(self):
|
||||||
"""Ensure that fallback auth via a captcha works."""
|
"""Ensure that fallback auth via a captcha works."""
|
||||||
# Returns a 401 as per the spec
|
# Returns a 401 as per the spec
|
||||||
|
|
|
@ -15,237 +15,9 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
from mock import Mock
|
|
||||||
|
|
||||||
from twisted.internet import defer
|
|
||||||
|
|
||||||
from synapse.util.async_helpers import ObservableDeferred
|
|
||||||
from synapse.util.caches.descriptors import cached
|
|
||||||
|
|
||||||
from tests import unittest
|
from tests import unittest
|
||||||
|
|
||||||
|
|
||||||
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_passthrough(self):
|
|
||||||
class A:
|
|
||||||
@cached()
|
|
||||||
def func(self, key):
|
|
||||||
return key
|
|
||||||
|
|
||||||
a = A()
|
|
||||||
|
|
||||||
self.assertEquals((yield a.func("foo")), "foo")
|
|
||||||
self.assertEquals((yield a.func("bar")), "bar")
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_hit(self):
|
|
||||||
callcount = [0]
|
|
||||||
|
|
||||||
class A:
|
|
||||||
@cached()
|
|
||||||
def func(self, key):
|
|
||||||
callcount[0] += 1
|
|
||||||
return key
|
|
||||||
|
|
||||||
a = A()
|
|
||||||
yield a.func("foo")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 1)
|
|
||||||
|
|
||||||
self.assertEquals((yield a.func("foo")), "foo")
|
|
||||||
self.assertEquals(callcount[0], 1)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_invalidate(self):
|
|
||||||
callcount = [0]
|
|
||||||
|
|
||||||
class A:
|
|
||||||
@cached()
|
|
||||||
def func(self, key):
|
|
||||||
callcount[0] += 1
|
|
||||||
return key
|
|
||||||
|
|
||||||
a = A()
|
|
||||||
yield a.func("foo")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 1)
|
|
||||||
|
|
||||||
a.func.invalidate(("foo",))
|
|
||||||
|
|
||||||
yield a.func("foo")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 2)
|
|
||||||
|
|
||||||
def test_invalidate_missing(self):
|
|
||||||
class A:
|
|
||||||
@cached()
|
|
||||||
def func(self, key):
|
|
||||||
return key
|
|
||||||
|
|
||||||
A().func.invalidate(("what",))
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_max_entries(self):
|
|
||||||
callcount = [0]
|
|
||||||
|
|
||||||
class A:
|
|
||||||
@cached(max_entries=10)
|
|
||||||
def func(self, key):
|
|
||||||
callcount[0] += 1
|
|
||||||
return key
|
|
||||||
|
|
||||||
a = A()
|
|
||||||
|
|
||||||
for k in range(0, 12):
|
|
||||||
yield a.func(k)
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 12)
|
|
||||||
|
|
||||||
# There must have been at least 2 evictions, meaning if we calculate
|
|
||||||
# all 12 values again, we must get called at least 2 more times
|
|
||||||
for k in range(0, 12):
|
|
||||||
yield a.func(k)
|
|
||||||
|
|
||||||
self.assertTrue(
|
|
||||||
callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_prefill(self):
|
|
||||||
callcount = [0]
|
|
||||||
|
|
||||||
d = defer.succeed(123)
|
|
||||||
|
|
||||||
class A:
|
|
||||||
@cached()
|
|
||||||
def func(self, key):
|
|
||||||
callcount[0] += 1
|
|
||||||
return d
|
|
||||||
|
|
||||||
a = A()
|
|
||||||
|
|
||||||
a.func.prefill(("foo",), ObservableDeferred(d))
|
|
||||||
|
|
||||||
self.assertEquals(a.func("foo").result, d.result)
|
|
||||||
self.assertEquals(callcount[0], 0)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_invalidate_context(self):
|
|
||||||
callcount = [0]
|
|
||||||
callcount2 = [0]
|
|
||||||
|
|
||||||
class A:
|
|
||||||
@cached()
|
|
||||||
def func(self, key):
|
|
||||||
callcount[0] += 1
|
|
||||||
return key
|
|
||||||
|
|
||||||
@cached(cache_context=True)
|
|
||||||
def func2(self, key, cache_context):
|
|
||||||
callcount2[0] += 1
|
|
||||||
return self.func(key, on_invalidate=cache_context.invalidate)
|
|
||||||
|
|
||||||
a = A()
|
|
||||||
yield a.func2("foo")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 1)
|
|
||||||
self.assertEquals(callcount2[0], 1)
|
|
||||||
|
|
||||||
a.func.invalidate(("foo",))
|
|
||||||
yield a.func("foo")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 2)
|
|
||||||
self.assertEquals(callcount2[0], 1)
|
|
||||||
|
|
||||||
yield a.func2("foo")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 2)
|
|
||||||
self.assertEquals(callcount2[0], 2)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_eviction_context(self):
|
|
||||||
callcount = [0]
|
|
||||||
callcount2 = [0]
|
|
||||||
|
|
||||||
class A:
|
|
||||||
@cached(max_entries=2)
|
|
||||||
def func(self, key):
|
|
||||||
callcount[0] += 1
|
|
||||||
return key
|
|
||||||
|
|
||||||
@cached(cache_context=True)
|
|
||||||
def func2(self, key, cache_context):
|
|
||||||
callcount2[0] += 1
|
|
||||||
return self.func(key, on_invalidate=cache_context.invalidate)
|
|
||||||
|
|
||||||
a = A()
|
|
||||||
yield a.func2("foo")
|
|
||||||
yield a.func2("foo2")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 2)
|
|
||||||
self.assertEquals(callcount2[0], 2)
|
|
||||||
|
|
||||||
yield a.func2("foo")
|
|
||||||
self.assertEquals(callcount[0], 2)
|
|
||||||
self.assertEquals(callcount2[0], 2)
|
|
||||||
|
|
||||||
yield a.func("foo3")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 3)
|
|
||||||
self.assertEquals(callcount2[0], 2)
|
|
||||||
|
|
||||||
yield a.func2("foo")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 4)
|
|
||||||
self.assertEquals(callcount2[0], 3)
|
|
||||||
|
|
||||||
@defer.inlineCallbacks
|
|
||||||
def test_double_get(self):
|
|
||||||
callcount = [0]
|
|
||||||
callcount2 = [0]
|
|
||||||
|
|
||||||
class A:
|
|
||||||
@cached()
|
|
||||||
def func(self, key):
|
|
||||||
callcount[0] += 1
|
|
||||||
return key
|
|
||||||
|
|
||||||
@cached(cache_context=True)
|
|
||||||
def func2(self, key, cache_context):
|
|
||||||
callcount2[0] += 1
|
|
||||||
return self.func(key, on_invalidate=cache_context.invalidate)
|
|
||||||
|
|
||||||
a = A()
|
|
||||||
a.func2.cache.cache = Mock(wraps=a.func2.cache.cache)
|
|
||||||
|
|
||||||
yield a.func2("foo")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 1)
|
|
||||||
self.assertEquals(callcount2[0], 1)
|
|
||||||
|
|
||||||
a.func2.invalidate(("foo",))
|
|
||||||
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
|
|
||||||
|
|
||||||
yield a.func2("foo")
|
|
||||||
a.func2.invalidate(("foo",))
|
|
||||||
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 1)
|
|
||||||
self.assertEquals(callcount2[0], 2)
|
|
||||||
|
|
||||||
a.func.invalidate(("foo",))
|
|
||||||
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
|
|
||||||
yield a.func("foo")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 2)
|
|
||||||
self.assertEquals(callcount2[0], 2)
|
|
||||||
|
|
||||||
yield a.func2("foo")
|
|
||||||
|
|
||||||
self.assertEquals(callcount[0], 2)
|
|
||||||
self.assertEquals(callcount2[0], 3)
|
|
||||||
|
|
||||||
|
|
||||||
class UpsertManyTests(unittest.HomeserverTestCase):
|
class UpsertManyTests(unittest.HomeserverTestCase):
|
||||||
def prepare(self, reactor, clock, hs):
|
def prepare(self, reactor, clock, hs):
|
||||||
self.storage = hs.get_datastore()
|
self.storage = hs.get_datastore()
|
||||||
|
|
|
@ -13,15 +13,16 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import unittest
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
|
||||||
from twisted.internet import defer
|
from twisted.internet import defer
|
||||||
|
|
||||||
from synapse.util.caches.deferred_cache import DeferredCache
|
from synapse.util.caches.deferred_cache import DeferredCache
|
||||||
|
|
||||||
|
from tests.unittest import TestCase
|
||||||
|
|
||||||
class DeferredCacheTestCase(unittest.TestCase):
|
|
||||||
|
class DeferredCacheTestCase(TestCase):
|
||||||
def test_empty(self):
|
def test_empty(self):
|
||||||
cache = DeferredCache("test")
|
cache = DeferredCache("test")
|
||||||
failed = False
|
failed = False
|
||||||
|
@ -36,7 +37,118 @@ class DeferredCacheTestCase(unittest.TestCase):
|
||||||
cache = DeferredCache("test")
|
cache = DeferredCache("test")
|
||||||
cache.prefill("foo", 123)
|
cache.prefill("foo", 123)
|
||||||
|
|
||||||
self.assertEquals(cache.get("foo"), 123)
|
self.assertEquals(self.successResultOf(cache.get("foo")), 123)
|
||||||
|
|
||||||
|
def test_hit_deferred(self):
|
||||||
|
cache = DeferredCache("test")
|
||||||
|
origin_d = defer.Deferred()
|
||||||
|
set_d = cache.set("k1", origin_d)
|
||||||
|
|
||||||
|
# get should return an incomplete deferred
|
||||||
|
get_d = cache.get("k1")
|
||||||
|
self.assertFalse(get_d.called)
|
||||||
|
|
||||||
|
# add a callback that will make sure that the set_d gets called before the get_d
|
||||||
|
def check1(r):
|
||||||
|
self.assertTrue(set_d.called)
|
||||||
|
return r
|
||||||
|
|
||||||
|
# TODO: Actually ObservableDeferred *doesn't* run its tests in order on py3.8.
|
||||||
|
# maybe we should fix that?
|
||||||
|
# get_d.addCallback(check1)
|
||||||
|
|
||||||
|
# now fire off all the deferreds
|
||||||
|
origin_d.callback(99)
|
||||||
|
self.assertEqual(self.successResultOf(origin_d), 99)
|
||||||
|
self.assertEqual(self.successResultOf(set_d), 99)
|
||||||
|
self.assertEqual(self.successResultOf(get_d), 99)
|
||||||
|
|
||||||
|
def test_callbacks(self):
|
||||||
|
"""Invalidation callbacks are called at the right time"""
|
||||||
|
cache = DeferredCache("test")
|
||||||
|
callbacks = set()
|
||||||
|
|
||||||
|
# start with an entry, with a callback
|
||||||
|
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
|
||||||
|
|
||||||
|
# now replace that entry with a pending result
|
||||||
|
origin_d = defer.Deferred()
|
||||||
|
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
|
||||||
|
|
||||||
|
# ... and also make a get request
|
||||||
|
get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
|
||||||
|
|
||||||
|
# we don't expect the invalidation callback for the original value to have
|
||||||
|
# been called yet, even though get() will now return a different result.
|
||||||
|
# I'm not sure if that is by design or not.
|
||||||
|
self.assertEqual(callbacks, set())
|
||||||
|
|
||||||
|
# now fire off all the deferreds
|
||||||
|
origin_d.callback(20)
|
||||||
|
self.assertEqual(self.successResultOf(set_d), 20)
|
||||||
|
self.assertEqual(self.successResultOf(get_d), 20)
|
||||||
|
|
||||||
|
# now the original invalidation callback should have been called, but none of
|
||||||
|
# the others
|
||||||
|
self.assertEqual(callbacks, {"prefill"})
|
||||||
|
callbacks.clear()
|
||||||
|
|
||||||
|
# another update should invalidate both the previous results
|
||||||
|
cache.prefill("k1", 30)
|
||||||
|
self.assertEqual(callbacks, {"set", "get"})
|
||||||
|
|
||||||
|
def test_set_fail(self):
|
||||||
|
cache = DeferredCache("test")
|
||||||
|
callbacks = set()
|
||||||
|
|
||||||
|
# start with an entry, with a callback
|
||||||
|
cache.prefill("k1", 10, callback=lambda: callbacks.add("prefill"))
|
||||||
|
|
||||||
|
# now replace that entry with a pending result
|
||||||
|
origin_d = defer.Deferred()
|
||||||
|
set_d = cache.set("k1", origin_d, callback=lambda: callbacks.add("set"))
|
||||||
|
|
||||||
|
# ... and also make a get request
|
||||||
|
get_d = cache.get("k1", callback=lambda: callbacks.add("get"))
|
||||||
|
|
||||||
|
# none of the callbacks should have been called yet
|
||||||
|
self.assertEqual(callbacks, set())
|
||||||
|
|
||||||
|
# oh noes! fails!
|
||||||
|
e = Exception("oops")
|
||||||
|
origin_d.errback(e)
|
||||||
|
self.assertIs(self.failureResultOf(set_d, Exception).value, e)
|
||||||
|
self.assertIs(self.failureResultOf(get_d, Exception).value, e)
|
||||||
|
|
||||||
|
# the callbacks for the failed requests should have been called.
|
||||||
|
# I'm not sure if this is deliberate or not.
|
||||||
|
self.assertEqual(callbacks, {"get", "set"})
|
||||||
|
callbacks.clear()
|
||||||
|
|
||||||
|
# the old value should still be returned now?
|
||||||
|
get_d2 = cache.get("k1", callback=lambda: callbacks.add("get2"))
|
||||||
|
self.assertEqual(self.successResultOf(get_d2), 10)
|
||||||
|
|
||||||
|
# replacing the value now should run the callbacks for those requests
|
||||||
|
# which got the original result
|
||||||
|
cache.prefill("k1", 30)
|
||||||
|
self.assertEqual(callbacks, {"prefill", "get2"})
|
||||||
|
|
||||||
|
def test_get_immediate(self):
|
||||||
|
cache = DeferredCache("test")
|
||||||
|
d1 = defer.Deferred()
|
||||||
|
cache.set("key1", d1)
|
||||||
|
|
||||||
|
# get_immediate should return default
|
||||||
|
v = cache.get_immediate("key1", 1)
|
||||||
|
self.assertEqual(v, 1)
|
||||||
|
|
||||||
|
# now complete the set
|
||||||
|
d1.callback(2)
|
||||||
|
|
||||||
|
# get_immediate should return result
|
||||||
|
v = cache.get_immediate("key1", 1)
|
||||||
|
self.assertEqual(v, 2)
|
||||||
|
|
||||||
def test_invalidate(self):
|
def test_invalidate(self):
|
||||||
cache = DeferredCache("test")
|
cache = DeferredCache("test")
|
||||||
|
@ -66,23 +178,24 @@ class DeferredCacheTestCase(unittest.TestCase):
|
||||||
d2 = defer.Deferred()
|
d2 = defer.Deferred()
|
||||||
cache.set("key2", d2, partial(record_callback, 1))
|
cache.set("key2", d2, partial(record_callback, 1))
|
||||||
|
|
||||||
# lookup should return observable deferreds
|
# lookup should return pending deferreds
|
||||||
self.assertFalse(cache.get("key1").has_called())
|
self.assertFalse(cache.get("key1").called)
|
||||||
self.assertFalse(cache.get("key2").has_called())
|
self.assertFalse(cache.get("key2").called)
|
||||||
|
|
||||||
# let one of the lookups complete
|
# let one of the lookups complete
|
||||||
d2.callback("result2")
|
d2.callback("result2")
|
||||||
|
|
||||||
# for now at least, the cache will return real results rather than an
|
# now the cache will return a completed deferred
|
||||||
# observabledeferred
|
self.assertEqual(self.successResultOf(cache.get("key2")), "result2")
|
||||||
self.assertEqual(cache.get("key2"), "result2")
|
|
||||||
|
|
||||||
# now do the invalidation
|
# now do the invalidation
|
||||||
cache.invalidate_all()
|
cache.invalidate_all()
|
||||||
|
|
||||||
# lookup should return none
|
# lookup should fail
|
||||||
self.assertIsNone(cache.get("key1", None))
|
with self.assertRaises(KeyError):
|
||||||
self.assertIsNone(cache.get("key2", None))
|
cache.get("key1")
|
||||||
|
with self.assertRaises(KeyError):
|
||||||
|
cache.get("key2")
|
||||||
|
|
||||||
# both callbacks should have been callbacked
|
# both callbacks should have been callbacked
|
||||||
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
|
self.assertTrue(callback_record[0], "Invalidation callback for key1 not called")
|
||||||
|
@ -90,7 +203,8 @@ class DeferredCacheTestCase(unittest.TestCase):
|
||||||
|
|
||||||
# letting the other lookup complete should do nothing
|
# letting the other lookup complete should do nothing
|
||||||
d1.callback("result1")
|
d1.callback("result1")
|
||||||
self.assertIsNone(cache.get("key1", None))
|
with self.assertRaises(KeyError):
|
||||||
|
cache.get("key1", None)
|
||||||
|
|
||||||
def test_eviction(self):
|
def test_eviction(self):
|
||||||
cache = DeferredCache(
|
cache = DeferredCache(
|
||||||
|
|
|
@ -14,6 +14,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import logging
|
import logging
|
||||||
|
from typing import Set
|
||||||
|
|
||||||
import mock
|
import mock
|
||||||
|
|
||||||
|
@ -130,6 +131,57 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
d = obj.fn(1)
|
d = obj.fn(1)
|
||||||
self.failureResultOf(d, SynapseError)
|
self.failureResultOf(d, SynapseError)
|
||||||
|
|
||||||
|
def test_cache_with_async_exception(self):
|
||||||
|
"""The wrapped function returns a failure
|
||||||
|
"""
|
||||||
|
|
||||||
|
class Cls:
|
||||||
|
result = None
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
@cached()
|
||||||
|
def fn(self, arg1):
|
||||||
|
self.call_count += 1
|
||||||
|
return self.result
|
||||||
|
|
||||||
|
obj = Cls()
|
||||||
|
callbacks = set() # type: Set[str]
|
||||||
|
|
||||||
|
# set off an asynchronous request
|
||||||
|
obj.result = origin_d = defer.Deferred()
|
||||||
|
|
||||||
|
d1 = obj.fn(1, on_invalidate=lambda: callbacks.add("d1"))
|
||||||
|
self.assertFalse(d1.called)
|
||||||
|
|
||||||
|
# a second request should also return a deferred, but should not call the
|
||||||
|
# function itself.
|
||||||
|
d2 = obj.fn(1, on_invalidate=lambda: callbacks.add("d2"))
|
||||||
|
self.assertFalse(d2.called)
|
||||||
|
self.assertEqual(obj.call_count, 1)
|
||||||
|
|
||||||
|
# no callbacks yet
|
||||||
|
self.assertEqual(callbacks, set())
|
||||||
|
|
||||||
|
# the original request fails
|
||||||
|
e = Exception("bzz")
|
||||||
|
origin_d.errback(e)
|
||||||
|
|
||||||
|
# ... which should cause the lookups to fail similarly
|
||||||
|
self.assertIs(self.failureResultOf(d1, Exception).value, e)
|
||||||
|
self.assertIs(self.failureResultOf(d2, Exception).value, e)
|
||||||
|
|
||||||
|
# ... and the callbacks to have been, uh, called.
|
||||||
|
self.assertEqual(callbacks, {"d1", "d2"})
|
||||||
|
|
||||||
|
# ... leaving the cache empty
|
||||||
|
self.assertEqual(len(obj.fn.cache.cache), 0)
|
||||||
|
|
||||||
|
# and a second call should work as normal
|
||||||
|
obj.result = defer.succeed(100)
|
||||||
|
d3 = obj.fn(1)
|
||||||
|
self.assertEqual(self.successResultOf(d3), 100)
|
||||||
|
self.assertEqual(obj.call_count, 2)
|
||||||
|
|
||||||
def test_cache_logcontexts(self):
|
def test_cache_logcontexts(self):
|
||||||
"""Check that logcontexts are set and restored correctly when
|
"""Check that logcontexts are set and restored correctly when
|
||||||
using the cache."""
|
using the cache."""
|
||||||
|
@ -311,6 +363,235 @@ class DescriptorTestCase(unittest.TestCase):
|
||||||
self.failureResultOf(d, SynapseError)
|
self.failureResultOf(d, SynapseError)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheDecoratorTestCase(unittest.HomeserverTestCase):
|
||||||
|
"""More tests for @cached
|
||||||
|
|
||||||
|
The following is a set of tests that got lost in a different file for a while.
|
||||||
|
|
||||||
|
There are probably duplicates of the tests in DescriptorTestCase. Ideally the
|
||||||
|
duplicates would be removed and the two sets of classes combined.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_passthrough(self):
|
||||||
|
class A:
|
||||||
|
@cached()
|
||||||
|
def func(self, key):
|
||||||
|
return key
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
|
||||||
|
self.assertEquals((yield a.func("foo")), "foo")
|
||||||
|
self.assertEquals((yield a.func("bar")), "bar")
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_hit(self):
|
||||||
|
callcount = [0]
|
||||||
|
|
||||||
|
class A:
|
||||||
|
@cached()
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return key
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
yield a.func("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
|
||||||
|
self.assertEquals((yield a.func("foo")), "foo")
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_invalidate(self):
|
||||||
|
callcount = [0]
|
||||||
|
|
||||||
|
class A:
|
||||||
|
@cached()
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return key
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
yield a.func("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
|
||||||
|
a.func.invalidate(("foo",))
|
||||||
|
|
||||||
|
yield a.func("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
|
||||||
|
def test_invalidate_missing(self):
|
||||||
|
class A:
|
||||||
|
@cached()
|
||||||
|
def func(self, key):
|
||||||
|
return key
|
||||||
|
|
||||||
|
A().func.invalidate(("what",))
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_max_entries(self):
|
||||||
|
callcount = [0]
|
||||||
|
|
||||||
|
class A:
|
||||||
|
@cached(max_entries=10)
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return key
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
|
||||||
|
for k in range(0, 12):
|
||||||
|
yield a.func(k)
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 12)
|
||||||
|
|
||||||
|
# There must have been at least 2 evictions, meaning if we calculate
|
||||||
|
# all 12 values again, we must get called at least 2 more times
|
||||||
|
for k in range(0, 12):
|
||||||
|
yield a.func(k)
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
callcount[0] >= 14, msg="Expected callcount >= 14, got %d" % (callcount[0])
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_prefill(self):
|
||||||
|
callcount = [0]
|
||||||
|
|
||||||
|
d = defer.succeed(123)
|
||||||
|
|
||||||
|
class A:
|
||||||
|
@cached()
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return d
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
|
||||||
|
a.func.prefill(("foo",), 456)
|
||||||
|
|
||||||
|
self.assertEquals(a.func("foo").result, 456)
|
||||||
|
self.assertEquals(callcount[0], 0)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_invalidate_context(self):
|
||||||
|
callcount = [0]
|
||||||
|
callcount2 = [0]
|
||||||
|
|
||||||
|
class A:
|
||||||
|
@cached()
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return key
|
||||||
|
|
||||||
|
@cached(cache_context=True)
|
||||||
|
def func2(self, key, cache_context):
|
||||||
|
callcount2[0] += 1
|
||||||
|
return self.func(key, on_invalidate=cache_context.invalidate)
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
self.assertEquals(callcount2[0], 1)
|
||||||
|
|
||||||
|
a.func.invalidate(("foo",))
|
||||||
|
yield a.func("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 1)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_eviction_context(self):
|
||||||
|
callcount = [0]
|
||||||
|
callcount2 = [0]
|
||||||
|
|
||||||
|
class A:
|
||||||
|
@cached(max_entries=2)
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return key
|
||||||
|
|
||||||
|
@cached(cache_context=True)
|
||||||
|
def func2(self, key, cache_context):
|
||||||
|
callcount2[0] += 1
|
||||||
|
return self.func(key, on_invalidate=cache_context.invalidate)
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
yield a.func2("foo")
|
||||||
|
yield a.func2("foo2")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
yield a.func("foo3")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 3)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 4)
|
||||||
|
self.assertEquals(callcount2[0], 3)
|
||||||
|
|
||||||
|
@defer.inlineCallbacks
|
||||||
|
def test_double_get(self):
|
||||||
|
callcount = [0]
|
||||||
|
callcount2 = [0]
|
||||||
|
|
||||||
|
class A:
|
||||||
|
@cached()
|
||||||
|
def func(self, key):
|
||||||
|
callcount[0] += 1
|
||||||
|
return key
|
||||||
|
|
||||||
|
@cached(cache_context=True)
|
||||||
|
def func2(self, key, cache_context):
|
||||||
|
callcount2[0] += 1
|
||||||
|
return self.func(key, on_invalidate=cache_context.invalidate)
|
||||||
|
|
||||||
|
a = A()
|
||||||
|
a.func2.cache.cache = mock.Mock(wraps=a.func2.cache.cache)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
self.assertEquals(callcount2[0], 1)
|
||||||
|
|
||||||
|
a.func2.invalidate(("foo",))
|
||||||
|
self.assertEquals(a.func2.cache.cache.pop.call_count, 1)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
a.func2.invalidate(("foo",))
|
||||||
|
self.assertEquals(a.func2.cache.cache.pop.call_count, 2)
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 1)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
a.func.invalidate(("foo",))
|
||||||
|
self.assertEquals(a.func2.cache.cache.pop.call_count, 3)
|
||||||
|
yield a.func("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 2)
|
||||||
|
|
||||||
|
yield a.func2("foo")
|
||||||
|
|
||||||
|
self.assertEquals(callcount[0], 2)
|
||||||
|
self.assertEquals(callcount2[0], 3)
|
||||||
|
|
||||||
|
|
||||||
class CachedListDescriptorTestCase(unittest.TestCase):
|
class CachedListDescriptorTestCase(unittest.TestCase):
|
||||||
@defer.inlineCallbacks
|
@defer.inlineCallbacks
|
||||||
def test_cache(self):
|
def test_cache(self):
|
||||||
|
|
|
@ -19,7 +19,8 @@ from mock import Mock
|
||||||
from synapse.util.caches.lrucache import LruCache
|
from synapse.util.caches.lrucache import LruCache
|
||||||
from synapse.util.caches.treecache import TreeCache
|
from synapse.util.caches.treecache import TreeCache
|
||||||
|
|
||||||
from .. import unittest
|
from tests import unittest
|
||||||
|
from tests.unittest import override_config
|
||||||
|
|
||||||
|
|
||||||
class LruCacheTestCase(unittest.HomeserverTestCase):
|
class LruCacheTestCase(unittest.HomeserverTestCase):
|
||||||
|
@ -59,7 +60,7 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
|
||||||
self.assertEquals(cache.pop("key"), None)
|
self.assertEquals(cache.pop("key"), None)
|
||||||
|
|
||||||
def test_del_multi(self):
|
def test_del_multi(self):
|
||||||
cache = LruCache(4, 2, cache_type=TreeCache)
|
cache = LruCache(4, keylen=2, cache_type=TreeCache)
|
||||||
cache[("animal", "cat")] = "mew"
|
cache[("animal", "cat")] = "mew"
|
||||||
cache[("animal", "dog")] = "woof"
|
cache[("animal", "dog")] = "woof"
|
||||||
cache[("vehicles", "car")] = "vroom"
|
cache[("vehicles", "car")] = "vroom"
|
||||||
|
@ -83,6 +84,11 @@ class LruCacheTestCase(unittest.HomeserverTestCase):
|
||||||
cache.clear()
|
cache.clear()
|
||||||
self.assertEquals(len(cache), 0)
|
self.assertEquals(len(cache), 0)
|
||||||
|
|
||||||
|
@override_config({"caches": {"per_cache_factors": {"mycache": 10}}})
|
||||||
|
def test_special_size(self):
|
||||||
|
cache = LruCache(10, "mycache")
|
||||||
|
self.assertEqual(cache.max_size, 100)
|
||||||
|
|
||||||
|
|
||||||
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||||
def test_get(self):
|
def test_get(self):
|
||||||
|
@ -160,7 +166,7 @@ class LruCacheCallbacksTestCase(unittest.HomeserverTestCase):
|
||||||
m2 = Mock()
|
m2 = Mock()
|
||||||
m3 = Mock()
|
m3 = Mock()
|
||||||
m4 = Mock()
|
m4 = Mock()
|
||||||
cache = LruCache(4, 2, cache_type=TreeCache)
|
cache = LruCache(4, keylen=2, cache_type=TreeCache)
|
||||||
|
|
||||||
cache.set(("a", "1"), "value", callbacks=[m1])
|
cache.set(("a", "1"), "value", callbacks=[m1])
|
||||||
cache.set(("a", "2"), "value", callbacks=[m2])
|
cache.set(("a", "2"), "value", callbacks=[m2])
|
||||||
|
|
5
tox.ini
5
tox.ini
|
@ -158,12 +158,9 @@ commands=
|
||||||
coverage html
|
coverage html
|
||||||
|
|
||||||
[testenv:mypy]
|
[testenv:mypy]
|
||||||
skip_install = True
|
|
||||||
deps =
|
deps =
|
||||||
{[base]deps}
|
{[base]deps}
|
||||||
mypy==0.782
|
extras = all,mypy
|
||||||
mypy-zope
|
|
||||||
extras = all
|
|
||||||
commands = mypy
|
commands = mypy
|
||||||
|
|
||||||
# To find all folders that pass mypy you run:
|
# To find all folders that pass mypy you run:
|
||||||
|
|
Loading…
Reference in a new issue