Merge remote-tracking branch 'origin/develop' into matrix-org-hotfixes

This commit is contained in:
Erik Johnston 2021-03-03 10:59:10 +00:00
commit 81c7b0515d
49 changed files with 466 additions and 278 deletions

3
.gitignore vendored
View file

@ -6,13 +6,14 @@
*.egg *.egg
*.egg-info *.egg-info
*.lock *.lock
*.pyc *.py[cod]
*.snap *.snap
*.tac *.tac
_trial_temp/ _trial_temp/
_trial_temp*/ _trial_temp*/
/out /out
.DS_Store .DS_Store
__pycache__/
# stuff that is likely to exist when you run a server locally # stuff that is likely to exist when you run a server locally
/*.db /*.db

View file

@ -1 +0,0 @@
Added a fix that invalidates cache for empty timed-out sync responses.

1
changelog.d/9372.feature Normal file
View file

@ -0,0 +1 @@
The `no_proxy` and `NO_PROXY` environment variables are now respected in proxied HTTP clients, with the lowercase form taking precedence if both are present. Additionally, the lowercase `https_proxy` environment variable is now respected in proxied HTTP clients, on top of the existing support for the uppercase `HTTPS_PROXY` form, and takes precedence if both are present. Contributed by Timothy Leung.
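For reference, a minimal sketch (not part of this commit) of how the standard library helper behind this change resolves proxy environment variables; the proxy URLs and hostnames are made up for illustration:

    import os
    from urllib.request import getproxies_environment

    # Both cases set: per CPython's two-pass scan, the lowercase form wins.
    os.environ["HTTPS_PROXY"] = "http://proxy-a.example.com:8080"
    os.environ["https_proxy"] = "http://proxy-b.example.com:8080"
    os.environ["no_proxy"] = "localhost,.internal.example.com"

    proxies = getproxies_environment()
    print(proxies["https"])  # http://proxy-b.example.com:8080
    print(proxies["no"])     # localhost,.internal.example.com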

1
changelog.d/9497.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a long-standing bug where the media repository could leak file descriptors while previewing media.

1
changelog.d/9502.misc Normal file
View file

@ -0,0 +1 @@
Allow Python to generate bytecode for Synapse.

1
changelog.d/9503.bugfix Normal file
View file

@ -0,0 +1 @@
Fix missing chain cover index due to a schema delta not being applied correctly. Only affected servers that ran development versions.

1
changelog.d/9506.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug introduced in v1.25.0 where `/_synapse/admin/join/` would fail when given a room alias.

1
changelog.d/9515.misc Normal file
View file

@ -0,0 +1 @@
Fix incorrect type hints.

1
changelog.d/9516.bugfix Normal file
View file

@ -0,0 +1 @@
Fix a bug where users' pushers were not all deleted when they deactivated their account.

1
changelog.d/9519.misc Normal file
View file

@ -0,0 +1 @@
Add type hints to the device and event report admin APIs.

1
changelog.d/9530.bugfix Normal file
View file

@ -0,0 +1 @@
Prevent presence background jobs from running when presence is disabled.

View file

@ -58,10 +58,10 @@ trap "rm -r $tmpdir" EXIT
cp -r tests "$tmpdir" cp -r tests "$tmpdir"
PYTHONPATH="$tmpdir" \ PYTHONPATH="$tmpdir" \
"${TARGET_PYTHON}" -B -m twisted.trial --reporter=text -j2 tests "${TARGET_PYTHON}" -m twisted.trial --reporter=text -j2 tests
# build the config file # build the config file
"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_config" \ "${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_config" \
--config-dir="/etc/matrix-synapse" \ --config-dir="/etc/matrix-synapse" \
--data-dir="/var/lib/matrix-synapse" | --data-dir="/var/lib/matrix-synapse" |
perl -pe ' perl -pe '
@ -87,7 +87,7 @@ PYTHONPATH="$tmpdir" \
' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml" ' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml"
# build the log config file # build the log config file
"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_log_config" \ "${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_log_config" \
--output-file="${PACKAGE_BUILD_DIR}/etc/matrix-synapse/log.yaml" --output-file="${PACKAGE_BUILD_DIR}/etc/matrix-synapse/log.yaml"
# add a dependency on the right version of python to substvars. # add a dependency on the right version of python to substvars.

7
debian/changelog vendored
View file

@ -1,3 +1,10 @@
matrix-synapse-py3 (1.29.0) UNRELEASED; urgency=medium
[ Jonathan de Jong ]
* Remove the python -B flag (don't generate bytecode) in scripts and documentation.
-- Synapse Packaging team <packages@matrix.org> Fri, 26 Feb 2021 14:41:31 +0100
matrix-synapse-py3 (1.28.0) stable; urgency=medium matrix-synapse-py3 (1.28.0) stable; urgency=medium
* New synapse release 1.28.0. * New synapse release 1.28.0.

2
debian/synctl.1 vendored
View file

@ -44,7 +44,7 @@ Configuration file may be generated as follows:
. .
.nf .nf
$ python \-B \-m synapse\.app\.homeserver \-c config\.yaml \-\-generate\-config \-\-server\-name=<server name> $ python \-m synapse\.app\.homeserver \-c config\.yaml \-\-generate\-config \-\-server\-name=<server name>
. .
.fi .fi
. .

2
debian/synctl.ronn vendored
View file

@ -41,7 +41,7 @@ process.
Configuration file may be generated as follows: Configuration file may be generated as follows:
$ python -B -m synapse.app.homeserver -c config.yaml --generate-config --server-name=<server name> $ python -m synapse.app.homeserver -c config.yaml --generate-config --server-name=<server name>
## ENVIRONMENT ## ENVIRONMENT

View file

@ -47,6 +47,7 @@ from synapse.storage.databases.main.events_bg_updates import (
from synapse.storage.databases.main.media_repository import ( from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore, MediaRepositoryBackgroundUpdateStore,
) )
from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.registration import ( from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore, RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart, find_max_generated_user_id_localpart,
@ -177,6 +178,7 @@ class Store(
UserDirectoryBackgroundUpdateStore, UserDirectoryBackgroundUpdateStore,
EndToEndKeyBackgroundStore, EndToEndKeyBackgroundStore,
StatsStore, StatsStore,
PusherWorkerStore,
): ):
def execute(self, f, *args, **kwargs): def execute(self, f, *args, **kwargs):
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs) return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)

View file

@ -17,8 +17,6 @@ import sys
from synapse import python_dependencies # noqa: E402 from synapse import python_dependencies # noqa: E402
sys.dont_write_bytecode = True
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
try: try:

View file

@ -36,7 +36,7 @@ import attr
import bcrypt import bcrypt
import pymacaroons import pymacaroons
from twisted.web.http import Request from twisted.web.server import Request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import ( from synapse.api.errors import (
@ -481,7 +481,7 @@ class AuthHandler(BaseHandler):
sid = authdict["session"] sid = authdict["session"]
# Convert the URI and method to strings. # Convert the URI and method to strings.
uri = request.uri.decode("utf-8") uri = request.uri.decode("utf-8") # type: ignore
method = request.method.decode("utf-8") method = request.method.decode("utf-8")
# If there's no session ID, create a new session. # If there's no session ID, create a new session.

View file

@ -274,22 +274,25 @@ class PresenceHandler(BasePresenceHandler):
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer") self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
# Start a LoopingCall in 30s that fires every 5s. if self._presence_enabled:
# The initial delay is to allow disconnected clients a chance to # Start a LoopingCall in 30s that fires every 5s.
# reconnect before we treat them as offline. # The initial delay is to allow disconnected clients a chance to
def run_timeout_handler(): # reconnect before we treat them as offline.
return run_as_background_process( def run_timeout_handler():
"handle_presence_timeouts", self._handle_timeouts return run_as_background_process(
"handle_presence_timeouts", self._handle_timeouts
)
self.clock.call_later(
30, self.clock.looping_call, run_timeout_handler, 5000
) )
self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000) def run_persister():
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
def run_persister(): self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
LaterGauge( LaterGauge(
"synapse_handlers_presence_wheel_timer_size", "synapse_handlers_presence_wheel_timer_size",
@ -299,7 +302,7 @@ class PresenceHandler(BasePresenceHandler):
) )
# Used to handle sending of presence to newly joined users/servers # Used to handle sending of presence to newly joined users/servers
if hs.config.use_presence: if self._presence_enabled:
self.notifier.add_replication_callback(self.notify_new_event) self.notifier.add_replication_callback(self.notify_new_event)
# Presence is best effort and quickly heals itself, so let's just always # Presence is best effort and quickly heals itself, so let's just always

View file

@ -31,8 +31,8 @@ from urllib.parse import urlencode
import attr import attr
from typing_extensions import NoReturn, Protocol from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request
from twisted.web.iweb import IRequest from twisted.web.iweb import IRequest
from twisted.web.server import Request
from synapse.api.constants import LoginType from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError

View file

@ -278,9 +278,8 @@ class SyncHandler:
user_id = sync_config.user.to_string() user_id = sync_config.user.to_string()
await self.auth.check_auth_blocking(requester=requester) await self.auth.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap_conditional( res = await self.response_cache.wrap(
sync_config.request_key, sync_config.request_key,
lambda result: since_token != result.next_batch,
self._wait_for_sync_for_user, self._wait_for_sync_for_user,
sync_config, sync_config,
since_token, since_token,

View file

@ -289,8 +289,7 @@ class SimpleHttpClient:
treq_args: Dict[str, Any] = {}, treq_args: Dict[str, Any] = {},
ip_whitelist: Optional[IPSet] = None, ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None, ip_blacklist: Optional[IPSet] = None,
http_proxy: Optional[bytes] = None, use_proxy: bool = False,
https_proxy: Optional[bytes] = None,
): ):
""" """
Args: Args:
@ -300,8 +299,8 @@ class SimpleHttpClient:
we may not request. we may not request.
ip_whitelist: The whitelisted IP addresses, that we can ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist. request if it were otherwise caught in a blacklist.
http_proxy: proxy server to use for http connections. host[:port] use_proxy: Whether proxy settings should be discovered and used
https_proxy: proxy server to use for https connections. host[:port] from conventional environment variables.
""" """
self.hs = hs self.hs = hs
@ -345,8 +344,7 @@ class SimpleHttpClient:
connectTimeout=15, connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(), contextFactory=self.hs.get_http_client_context_factory(),
pool=pool, pool=pool,
http_proxy=http_proxy, use_proxy=use_proxy,
https_proxy=https_proxy,
) )
if self._ip_blacklist: if self._ip_blacklist:
@ -750,7 +748,32 @@ class BodyExceededMaxSize(Exception):
"""The maximum allowed size of the HTTP body was exceeded.""" """The maximum allowed size of the HTTP body was exceeded."""
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
def _maybe_fail(self):
"""
Report a max-size-exceeded error and disconnect the first time this is called.
"""
if not self.deferred.called:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
self.transport.abortConnection()
def dataReceived(self, data: bytes) -> None:
self._maybe_fail()
def connectionLost(self, reason: Failure) -> None:
self._maybe_fail()
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol): class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
def __init__( def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
): ):
@ -807,13 +830,15 @@ def read_body_with_max_size(
Returns: Returns:
A Deferred which resolves to the length of the read body. A Deferred which resolves to the length of the read body.
""" """
d = defer.Deferred()
# If the Content-Length header gives a size larger than the maximum allowed # If the Content-Length header gives a size larger than the maximum allowed
# size, do not bother downloading the body. # size, do not bother downloading the body.
if max_size is not None and response.length != UNKNOWN_LENGTH: if max_size is not None and response.length != UNKNOWN_LENGTH:
if response.length > max_size: if response.length > max_size:
return defer.fail(BodyExceededMaxSize()) response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
return d
d = defer.Deferred()
response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size)) response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
return d return d

View file

@ -14,6 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re import re
from urllib.request import getproxies_environment, proxy_bypass_environment
from zope.interface import implementer from zope.interface import implementer
@ -58,6 +59,9 @@ class ProxyAgent(_AgentBase):
pool (HTTPConnectionPool|None): connection pool to be used. If None, a pool (HTTPConnectionPool|None): connection pool to be used. If None, a
non-persistent pool instance will be created. non-persistent pool instance will be created.
use_proxy (bool): Whether proxy settings should be discovered and used
from conventional environment variables.
""" """
def __init__( def __init__(
@ -68,8 +72,7 @@ class ProxyAgent(_AgentBase):
connectTimeout=None, connectTimeout=None,
bindAddress=None, bindAddress=None,
pool=None, pool=None,
http_proxy=None, use_proxy=False,
https_proxy=None,
): ):
_AgentBase.__init__(self, reactor, pool) _AgentBase.__init__(self, reactor, pool)
@ -84,6 +87,15 @@ class ProxyAgent(_AgentBase):
if bindAddress is not None: if bindAddress is not None:
self._endpoint_kwargs["bindAddress"] = bindAddress self._endpoint_kwargs["bindAddress"] = bindAddress
http_proxy = None
https_proxy = None
no_proxy = None
if use_proxy:
proxies = getproxies_environment()
http_proxy = proxies["http"].encode() if "http" in proxies else None
https_proxy = proxies["https"].encode() if "https" in proxies else None
no_proxy = proxies["no"] if "no" in proxies else None
self.http_proxy_endpoint = _http_proxy_endpoint( self.http_proxy_endpoint = _http_proxy_endpoint(
http_proxy, self.proxy_reactor, **self._endpoint_kwargs http_proxy, self.proxy_reactor, **self._endpoint_kwargs
) )
@ -92,6 +104,8 @@ class ProxyAgent(_AgentBase):
https_proxy, self.proxy_reactor, **self._endpoint_kwargs https_proxy, self.proxy_reactor, **self._endpoint_kwargs
) )
self.no_proxy = no_proxy
self._policy_for_https = contextFactory self._policy_for_https = contextFactory
self._reactor = reactor self._reactor = reactor
@ -139,13 +153,28 @@ class ProxyAgent(_AgentBase):
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port) pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
request_path = parsed_uri.originForm request_path = parsed_uri.originForm
if parsed_uri.scheme == b"http" and self.http_proxy_endpoint: should_skip_proxy = False
if self.no_proxy is not None:
should_skip_proxy = proxy_bypass_environment(
parsed_uri.host.decode(),
proxies={"no": self.no_proxy},
)
if (
parsed_uri.scheme == b"http"
and self.http_proxy_endpoint
and not should_skip_proxy
):
# Cache *all* connections under the same key, since we are only # Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy: # connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.http_proxy_endpoint) pool_key = ("http-proxy", self.http_proxy_endpoint)
endpoint = self.http_proxy_endpoint endpoint = self.http_proxy_endpoint
request_path = uri request_path = uri
elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint: elif (
parsed_uri.scheme == b"https"
and self.https_proxy_endpoint
and not should_skip_proxy
):
endpoint = HTTPConnectProxyEndpoint( endpoint = HTTPConnectProxyEndpoint(
self.proxy_reactor, self.proxy_reactor,
self.https_proxy_endpoint, self.https_proxy_endpoint,
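A rough illustration of the bypass check added above, calling urllib's proxy_bypass_environment directly; the hosts and no_proxy value here are assumptions for the example:

    from urllib.request import proxy_bypass_environment

    no_proxy = "localhost,.internal.example.com"

    for host in ("matrix.org", "localhost", "db.internal.example.com"):
        bypass = proxy_bypass_environment(host, proxies={"no": no_proxy})
        print(host, "-> direct" if bypass else "-> via proxy")
    # matrix.org goes via the proxy; the other two match a no_proxy entry
    # (exact host or domain suffix) and connect directly.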

View file

@ -15,9 +15,10 @@
import logging import logging
from typing import TYPE_CHECKING, List, Optional, Tuple from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.http import Request from twisted.web.server import Request
from synapse.http.servlet import parse_json_object_from_request from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.replication.http._base import ReplicationEndpoint from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID from synapse.types import JsonDict, Requester, UserID
from synapse.util.distributor import user_left_room from synapse.util.distributor import user_left_room
@ -78,7 +79,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
} }
async def _handle_request( # type: ignore async def _handle_request( # type: ignore
self, request: Request, room_id: str, user_id: str self, request: SynapseRequest, room_id: str, user_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -86,7 +87,6 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
event_content = content["content"] event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"]) requester = Requester.deserialize(self.store, content["requester"])
request.requester = requester request.requester = requester
logger.info("remote_join: %s into room: %s", user_id, room_id) logger.info("remote_join: %s into room: %s", user_id, room_id)
@ -147,7 +147,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
} }
async def _handle_request( # type: ignore async def _handle_request( # type: ignore
self, request: Request, invite_event_id: str self, request: SynapseRequest, invite_event_id: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request) content = parse_json_object_from_request(request)
@ -155,7 +155,6 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
event_content = content["content"] event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"]) requester = Requester.deserialize(self.store, content["requester"])
request.requester = requester request.requester = requester
# hopefully we're now on the master, so this won't recurse! # hopefully we're now on the master, so this won't recurse!

View file

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.servlet import ( from synapse.http.servlet import (
@ -20,8 +21,12 @@ from synapse.http.servlet import (
assert_params_in_dict, assert_params_in_dict,
parse_json_object_from_request, parse_json_object_from_request,
) )
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import UserID from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -35,14 +40,16 @@ class DeviceRestServlet(RestServlet):
"/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2" "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
) )
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, user_id, device_id): async def on_GET(
self, request: SynapseRequest, user_id, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -58,7 +65,9 @@ class DeviceRestServlet(RestServlet):
) )
return 200, device return 200, device
async def on_DELETE(self, request, user_id, device_id): async def on_DELETE(
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -72,7 +81,9 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.delete_device(target_user.to_string(), device_id) await self.device_handler.delete_device(target_user.to_string(), device_id)
return 200, {} return 200, {}
async def on_PUT(self, request, user_id, device_id): async def on_PUT(
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -97,7 +108,7 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
""" """
Args: Args:
hs (synapse.server.HomeServer): server hs (synapse.server.HomeServer): server
@ -107,7 +118,9 @@ class DevicesRestServlet(RestServlet):
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, user_id): async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)
@ -130,13 +143,15 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2") PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler() self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_POST(self, request, user_id): async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id) target_user = UserID.from_string(user_id)

View file

@ -14,10 +14,16 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -45,12 +51,12 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$") PATTERNS = admin_patterns("/event_reports$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0) start = parse_integer(request, "from", default=0)
@ -106,26 +112,28 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$") PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def on_GET(self, request, report_id): async def on_GET(
self, request: SynapseRequest, report_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request) await assert_requester_is_admin(self.auth, request)
message = ( message = (
"The report_id parameter must be a string representing a positive integer." "The report_id parameter must be a string representing a positive integer."
) )
try: try:
report_id = int(report_id) resolved_report_id = int(report_id)
except ValueError: except ValueError:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
if report_id < 0: if resolved_report_id < 0:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM) raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
ret = await self.store.get_event_report(report_id) ret = await self.store.get_event_report(resolved_report_id)
if not ret: if not ret:
raise NotFoundError("Event report not found") raise NotFoundError("Event report not found")

View file

@ -17,7 +17,7 @@
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from twisted.web.http import Request from twisted.web.server import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer from synapse.http.servlet import RestServlet, parse_boolean, parse_integer

View file

@ -44,6 +44,48 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ResolveRoomIdMixin:
def __init__(self, hs: "HomeServer"):
self.room_member_handler = hs.get_room_member_handler()
async def resolve_room_id(
self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None
) -> Tuple[str, Optional[List[str]]]:
"""
Resolve a room identifier to a room ID, if necessary.
This also performs checks to ensure the room ID is of the proper form.
Args:
room_identifier: The room ID or alias.
remote_room_hosts: The potential remote room hosts to use.
Returns:
The resolved room ID and the remote room hosts to use, if any.
Raises:
SynapseError if the room ID is of the wrong form.
"""
if RoomID.is_valid(room_identifier):
resolved_room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
(
room_id,
remote_room_hosts,
) = await self.room_member_handler.lookup_room_alias(room_alias)
resolved_room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
if not resolved_room_id:
raise SynapseError(
400, "Unknown room ID or room alias %s" % room_identifier
)
return resolved_room_id, remote_room_hosts
class ShutdownRoomRestServlet(RestServlet): class ShutdownRoomRestServlet(RestServlet):
"""Shuts down a room by removing all local users from the room and blocking """Shuts down a room by removing all local users from the room and blocking
all future invites and joins to the room. Any local aliases will be repointed all future invites and joins to the room. Any local aliases will be repointed
@ -334,14 +376,14 @@ class RoomStateRestServlet(RestServlet):
return 200, ret return 200, ret
class JoinRoomAliasServlet(RestServlet): class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)") PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_admin_handler() self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
@ -362,22 +404,16 @@ class JoinRoomAliasServlet(RestServlet):
if not await self.admin_handler.get_user(target_user): if not await self.admin_handler.get_user(target_user):
raise NotFoundError("User not found") raise NotFoundError("User not found")
if RoomID.is_valid(room_identifier): # Get the room ID from the identifier.
room_id = room_identifier try:
try: remote_room_hosts = [
remote_room_hosts = [ x.decode("ascii") for x in request.args[b"server_name"]
x.decode("ascii") for x in request.args[b"server_name"] ] # type: Optional[List[str]]
] # type: Optional[List[str]] except Exception:
except Exception: remote_room_hosts = None
remote_room_hosts = None room_id, remote_room_hosts = await self.resolve_room_id(
elif RoomAlias.is_valid(room_identifier): room_identifier, remote_room_hosts
handler = self.room_member_handler )
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
fake_requester = create_requester( fake_requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity target_user, authenticated_entity=requester.authenticated_entity
@ -412,7 +448,7 @@ class JoinRoomAliasServlet(RestServlet):
return 200, {"room_id": room_id} return 200, {"room_id": room_id}
class MakeRoomAdminRestServlet(RestServlet): class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
"""Allows a server admin to get power in a room if a local user has power in """Allows a server admin to get power in a room if a local user has power in
a room. Will also invite the user if they're not in the room and it's a a room. Will also invite the user if they're not in the room and it's a
private room. Can specify another user (rather than the admin user) to be private room. Can specify another user (rather than the admin user) to be
@ -427,29 +463,21 @@ class MakeRoomAdminRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin") PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.event_creation_handler = hs.get_event_creation_handler() self.event_creation_handler = hs.get_event_creation_handler()
self.state_handler = hs.get_state_handler() self.state_handler = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id self.is_mine_id = hs.is_mine_id
async def on_POST(self, request, room_identifier): async def on_POST(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request, allow_empty_body=True) content = parse_json_object_from_request(request, allow_empty_body=True)
# Resolve to a room ID, if necessary. room_id, _ = await self.resolve_room_id(room_identifier)
if RoomID.is_valid(room_identifier):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
# Which user to grant room admin rights to. # Which user to grant room admin rights to.
user_to_add = content.get("user_id", requester.user.to_string()) user_to_add = content.get("user_id", requester.user.to_string())
@ -556,7 +584,7 @@ class MakeRoomAdminRestServlet(RestServlet):
return 200, {} return 200, {}
class ForwardExtremitiesRestServlet(RestServlet): class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
"""Allows a server admin to get or clear forward extremities. """Allows a server admin to get or clear forward extremities.
Clearing does not require restarting the server. Clearing does not require restarting the server.
@ -571,43 +599,29 @@ class ForwardExtremitiesRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities") PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs self.hs = hs
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.store = hs.get_datastore() self.store = hs.get_datastore()
async def resolve_room_id(self, room_identifier: str) -> str: async def on_DELETE(
"""Resolve to a room ID, if necessary.""" self, request: SynapseRequest, room_identifier: str
if RoomID.is_valid(room_identifier): ) -> Tuple[int, JsonDict]:
resolved_room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
resolved_room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
if not resolved_room_id:
raise SynapseError(
400, "Unknown room ID or room alias %s" % room_identifier
)
return resolved_room_id
async def on_DELETE(self, request, room_identifier):
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
room_id = await self.resolve_room_id(room_identifier) room_id, _ = await self.resolve_room_id(room_identifier)
deleted_count = await self.store.delete_forward_extremities_for_room(room_id) deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
return 200, {"deleted": deleted_count} return 200, {"deleted": deleted_count}
async def on_GET(self, request, room_identifier): async def on_GET(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)
room_id = await self.resolve_room_id(room_identifier) room_id, _ = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id) extremities = await self.store.get_forward_extremities_for_room(room_id)
return 200, {"count": len(extremities), "results": extremities} return 200, {"count": len(extremities), "results": extremities}
@ -623,14 +637,16 @@ class RoomEventContextServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$") PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
def __init__(self, hs): def __init__(self, hs: "HomeServer"):
super().__init__() super().__init__()
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler() self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer() self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth() self.auth = hs.get_auth()
async def on_GET(self, request, room_id, event_id): async def on_GET(
self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False) requester = await self.auth.get_user_by_req(request, allow_guest=False)
await assert_user_is_admin(self.auth, requester.user) await assert_user_is_admin(self.auth, requester.user)

View file

@ -18,7 +18,7 @@ import logging
from functools import wraps from functools import wraps
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.http import Request from twisted.web.server import Request
from synapse.api.constants import ( from synapse.api.constants import (
MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_CATEGORYID_LENGTH,

View file

@ -21,7 +21,7 @@ from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender from twisted.protocols.basic import FileSender
from twisted.web.http import Request from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json from synapse.http.server import finish_request, respond_with_json
@ -49,18 +49,20 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]: def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try: try:
# The type on postpath seems incorrect in Twisted 21.2.0.
postpath = request.postpath # type: List[bytes] # type: ignore
assert postpath
# This allows users to append e.g. /test.png to the URL. Useful for # This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type. # clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2] server_name_bytes, media_id_bytes = postpath[:2]
server_name = server_name_bytes.decode("utf-8")
if isinstance(server_name, bytes): media_id = media_id_bytes.decode("utf8")
server_name = server_name.decode("utf-8")
media_id = media_id.decode("utf8")
file_name = None file_name = None
if len(request.postpath) > 2: if len(postpath) > 2:
try: try:
file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8")) file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
except UnicodeDecodeError: except UnicodeDecodeError:
pass pass
return server_name, media_id, file_name return server_name, media_id, file_name

View file

@ -17,7 +17,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.http import Request from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json

View file

@ -16,7 +16,7 @@
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.http import Request from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean from synapse.http.servlet import parse_boolean

View file

@ -22,8 +22,8 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error import twisted.internet.error
import twisted.web.http import twisted.web.http
from twisted.web.http import Request
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.api.errors import ( from synapse.api.errors import (
FederationDeniedError, FederationDeniedError,

View file

@ -29,7 +29,7 @@ from urllib import parse as urlparse
import attr import attr
from twisted.internet.error import DNSLookupError from twisted.internet.error import DNSLookupError
from twisted.web.http import Request from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient from synapse.http.client import SimpleHttpClient
@ -149,8 +149,7 @@ class PreviewUrlResource(DirectServeJsonResource):
treq_args={"browser_like_redirects": True}, treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist, ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist, ip_blacklist=hs.config.url_preview_ip_range_blacklist,
http_proxy=os.getenvb(b"http_proxy"), use_proxy=True,
https_proxy=os.getenvb(b"HTTPS_PROXY"),
) )
self.media_repo = media_repo self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path self.primary_base_path = media_repo.primary_base_path

View file

@ -18,7 +18,7 @@
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.web.http import Request from twisted.web.server import Request
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers from synapse.http.server import DirectServeJsonResource, set_cors_headers

View file

@ -15,9 +15,9 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING from typing import IO, TYPE_CHECKING
from twisted.web.http import Request from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json from synapse.http.server import DirectServeJsonResource, respond_with_json
@ -79,7 +79,9 @@ class UploadResource(DirectServeJsonResource):
headers = request.requestHeaders headers = request.requestHeaders
if headers.hasHeader(b"Content-Type"): if headers.hasHeader(b"Content-Type"):
media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii") content_type_headers = headers.getRawHeaders(b"Content-Type")
assert content_type_headers # for mypy
media_type = content_type_headers[0].decode("ascii")
else: else:
raise SynapseError(msg="Upload request missing 'Content-Type'", code=400) raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
@ -88,8 +90,9 @@ class UploadResource(DirectServeJsonResource):
# TODO(markjh): parse content-disposition # TODO(markjh): parse content-disposition
try: try:
content = request.content # type: IO # type: ignore
content_uri = await self.media_repo.create_content( content_uri = await self.media_repo.create_content(
media_type, upload_name, request.content, content_length, requester.user media_type, upload_name, content, content_length, requester.user
) )
except SpamMediaException: except SpamMediaException:
# For uploading of media we want to respond with a 400, instead of # For uploading of media we want to respond with a 400, instead of

View file

@ -15,7 +15,7 @@
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.http import Request from twisted.web.server import Request
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

View file

@ -15,7 +15,7 @@
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Tuple
from twisted.web.http import Request from twisted.web.server import Request
from synapse.api.errors import ThreepidValidationError from synapse.api.errors import ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour from synapse.config.emailconfig import ThreepidBehaviour

View file

@ -16,8 +16,8 @@
import logging import logging
from typing import TYPE_CHECKING, List from typing import TYPE_CHECKING, List
from twisted.web.http import Request
from twisted.web.resource import Resource from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

View file

@ -16,7 +16,7 @@
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from twisted.web.http import Request from twisted.web.server import Request
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

View file

@ -24,7 +24,6 @@
import abc import abc
import functools import functools
import logging import logging
import os
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
@ -370,11 +369,7 @@ class HomeServer(metaclass=abc.ABCMeta):
""" """
An HTTP client that uses configured HTTP(S) proxies. An HTTP client that uses configured HTTP(S) proxies.
""" """
return SimpleHttpClient( return SimpleHttpClient(self, use_proxy=True)
self,
http_proxy=os.getenvb(b"http_proxy"),
https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
@cache_in_self @cache_in_self
def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient: def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
@ -386,8 +381,7 @@ class HomeServer(metaclass=abc.ABCMeta):
self, self,
ip_whitelist=self.config.ip_range_whitelist, ip_whitelist=self.config.ip_range_whitelist,
ip_blacklist=self.config.ip_range_blacklist, ip_blacklist=self.config.ip_range_blacklist,
http_proxy=os.getenvb(b"http_proxy"), use_proxy=True,
https_proxy=os.getenvb(b"HTTPS_PROXY"),
) )
@cache_in_self @cache_in_self

View file

@ -39,6 +39,11 @@ class PusherWorkerStore(SQLBaseStore):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")] db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
) )
self.db_pool.updates.register_background_update_handler(
"remove_deactivated_pushers",
self._remove_deactivated_pushers,
)
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]: def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table """JSON-decode the data in the rows returned from the `pushers` table
@ -284,6 +289,54 @@ class PusherWorkerStore(SQLBaseStore):
lock=False, lock=False,
) )
async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int:
"""A background update that deletes all pushers for deactivated users.
Note that we don't proactively tell the pusherpool that we've deleted
these (just because it's a bit of a faff to do from here), but they will
get cleaned up at the next restart.
"""
last_user = progress.get("last_user", "")
def _delete_pushers(txn) -> int:
sql = """
SELECT name FROM users
WHERE deactivated = ? and name > ?
ORDER BY name ASC
LIMIT ?
"""
txn.execute(sql, (1, last_user, batch_size))
users = [row[0] for row in txn]
self.db_pool.simple_delete_many_txn(
txn,
table="pushers",
column="user_name",
iterable=users,
keyvalues={},
)
if users:
self.db_pool.updates._background_update_progress_txn(
txn, "remove_deactivated_pushers", {"last_user": users[-1]}
)
return len(users)
number_deleted = await self.db_pool.runInteraction(
"_remove_deactivated_pushers", _delete_pushers
)
if number_deleted < batch_size:
await self.db_pool.updates._end_background_update(
"remove_deactivated_pushers"
)
return number_deleted
class PusherStore(PusherWorkerStore): class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self) -> int: def get_pushers_stream_token(self) -> int:
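To make the background-update contract above easier to follow, here is a condensed sketch under assumed names ("example_update", a bare users table); handlers receive the stored progress and a batch size, return how many rows they processed, and unregister themselves once a short batch signals completion:

    async def _example_background_update(self, progress: dict, batch_size: int) -> int:
        last_user = progress.get("last_user", "")

        def _process(txn) -> int:
            txn.execute(
                "SELECT name FROM users WHERE name > ? ORDER BY name ASC LIMIT ?",
                (last_user, batch_size),
            )
            users = [row[0] for row in txn]
            if users:
                # Persist a resume point so the job survives restarts.
                self.db_pool.updates._background_update_progress_txn(
                    txn, "example_update", {"last_user": users[-1]}
                )
            return len(users)

        processed = await self.db_pool.runInteraction("example_update", _process)
        if processed < batch_size:
            # Fewer rows than requested means we've reached the end.
            await self.db_pool.updates._end_background_update("example_update")
        return processed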

View file

@ -14,8 +14,7 @@
*/ */
-- We may not have deleted all pushers for deactivated accounts. Do so now. -- We may not have deleted all pushers for deactivated accounts, so we set up a
-- -- background job to delete them.
-- Note: We don't bother updating the `deleted_pushers` table as it's just used INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
-- to stop pushers on workers, and that will happen when they get next restarted. (5908, 'remove_deactivated_pushers', '{}');
DELETE FROM pushers WHERE user_name IN (SELECT name FROM users WHERE deactivated = 1);

View file

@ -13,5 +13,14 @@
* limitations under the License. * limitations under the License.
*/ */
-- This originally was in 58/, but landed after 59/ was created, and so some
-- servers running develop didn't run this delta. Running it again should be
-- safe.
--
-- We first delete any in-progress `rejected_events_metadata` background update,
-- to ensure that we don't conflict when trying to insert the new one. (We could
-- alternatively do an ON CONFLICT DO NOTHING, but that syntax isn't supported
-- by older SQLite versions. Plus, this should be a rare case).
DELETE FROM background_updates WHERE update_name = 'rejected_events_metadata';
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5828, 'rejected_events_metadata', '{}'); (5828, 'rejected_events_metadata', '{}');

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Set, TypeVar from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer from twisted.internet import defer
@ -40,7 +40,6 @@ class ResponseCache(Generic[T]):
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0): def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
# Requests that haven't finished yet. # Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred] self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
self.pending_conditionals = {} # type: Dict[T, Set[Callable[[Any], bool]]]
self.clock = hs.get_clock() self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.0 self.timeout_sec = timeout_ms / 1000.0
@ -102,11 +101,7 @@ class ResponseCache(Generic[T]):
self.pending_result_cache[key] = result self.pending_result_cache[key] = result
def remove(r): def remove(r):
should_cache = all( if self.timeout_sec:
func(r) for func in self.pending_conditionals.pop(key, [])
)
if self.timeout_sec and should_cache:
self.clock.call_later( self.clock.call_later(
self.timeout_sec, self.pending_result_cache.pop, key, None self.timeout_sec, self.pending_result_cache.pop, key, None
) )
@ -117,31 +112,6 @@ class ResponseCache(Generic[T]):
result.addBoth(remove) result.addBoth(remove)
return result.observe() return result.observe()
def add_conditional(self, key: T, conditional: Callable[[Any], bool]):
self.pending_conditionals.setdefault(key, set()).add(conditional)
def wrap_conditional(
self,
key: T,
should_cache: Callable[[Any], bool],
callback: "Callable[..., Any]",
*args: Any,
**kwargs: Any
) -> defer.Deferred:
"""The same as wrap(), but adds a conditional to the final execution.
When the final execution completes, *all* conditionals need to return True for it to properly cache,
else it'll not be cached in a timed fashion.
"""
# See if there's already a result on this key that hasn't yet completed. Due to the single-threaded nature of
# python, adding a key immediately in the same execution thread will not cause a race condition.
result = self.get(key)
if not result or isinstance(result, defer.Deferred) and not result.called:
self.add_conditional(key, should_cache)
return self.wrap(key, callback, *args, **kwargs)
def wrap( def wrap(
self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
) -> defer.Deferred: ) -> defer.Deferred:
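Illustrative usage of the surviving wrap() API, mirroring the sync handler change earlier in this commit; get_thing and _do_expensive_lookup are stand-in names:

    async def get_thing(self, key):
        # Concurrent callers with the same key attach to the in-flight
        # Deferred instead of recomputing; the cached result then lingers
        # for timeout_ms before being evicted.
        return await self.response_cache.wrap(
            key,
            self._do_expensive_lookup,
            key,
        )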

3
synctl
View file

@ -30,7 +30,7 @@ import yaml
from synapse.config import find_config_files from synapse.config import find_config_files
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"] SYNAPSE = [sys.executable, "-m", "synapse.app.homeserver"]
GREEN = "\x1b[1;32m" GREEN = "\x1b[1;32m"
YELLOW = "\x1b[1;33m" YELLOW = "\x1b[1;33m"
@ -117,7 +117,6 @@ def start_worker(app: str, configfile: str, worker_configfile: str) -> bool:
args = [ args = [
sys.executable, sys.executable,
"-B",
"-m", "-m",
app, app,
"-c", "-c",


@@ -26,77 +26,96 @@ from tests.unittest import TestCase

 class ReadBodyWithMaxSizeTests(TestCase):
-    def setUp(self):
+    def _build_response(self, length=UNKNOWN_LENGTH):
         """Start reading the body, returns the response, result and proto"""
-        response = Mock(length=UNKNOWN_LENGTH)
-        self.result = BytesIO()
-        self.deferred = read_body_with_max_size(response, self.result, 6)
+        response = Mock(length=length)
+        result = BytesIO()
+        deferred = read_body_with_max_size(response, result, 6)

         # Fish the protocol out of the response.
-        self.protocol = response.deliverBody.call_args[0][0]
-        self.protocol.transport = Mock()
+        protocol = response.deliverBody.call_args[0][0]
+        protocol.transport = Mock()

-    def _cleanup_error(self):
+        return result, deferred, protocol
+
+    def _assert_error(self, deferred, protocol):
+        """Ensure that the expected error is received."""
+        self.assertIsInstance(deferred.result, Failure)
+        self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
+        protocol.transport.abortConnection.assert_called_once()
+
+    def _cleanup_error(self, deferred):
         """Ensure that the error in the Deferred is handled gracefully."""
         called = [False]

         def errback(f):
             called[0] = True

-        self.deferred.addErrback(errback)
+        deferred.addErrback(errback)
         self.assertTrue(called[0])

     def test_no_error(self):
         """A response that is NOT too large."""
+        result, deferred, protocol = self._build_response()

         # Start sending data.
-        self.protocol.dataReceived(b"12345")
+        protocol.dataReceived(b"12345")

         # Close the connection.
-        self.protocol.connectionLost(Failure(ResponseDone()))
+        protocol.connectionLost(Failure(ResponseDone()))

-        self.assertEqual(self.result.getvalue(), b"12345")
-        self.assertEqual(self.deferred.result, 5)
+        self.assertEqual(result.getvalue(), b"12345")
+        self.assertEqual(deferred.result, 5)

     def test_too_large(self):
         """A response which is too large raises an exception."""
+        result, deferred, protocol = self._build_response()

         # Start sending data.
-        self.protocol.dataReceived(b"1234567890")
+        protocol.dataReceived(b"1234567890")

-        # Close the connection.
-        self.protocol.connectionLost(Failure(ResponseDone()))

-        self.assertEqual(self.result.getvalue(), b"1234567890")
-        self.assertIsInstance(self.deferred.result, Failure)
-        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
-        self._cleanup_error()
+        self.assertEqual(result.getvalue(), b"1234567890")
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)

     def test_multiple_packets(self):
-        """Data should be accummulated through mutliple packets."""
+        """Data should be accumulated through multiple packets."""
+        result, deferred, protocol = self._build_response()

         # Start sending data.
-        self.protocol.dataReceived(b"12")
-        self.protocol.dataReceived(b"34")
+        protocol.dataReceived(b"12")
+        protocol.dataReceived(b"34")

         # Close the connection.
-        self.protocol.connectionLost(Failure(ResponseDone()))
+        protocol.connectionLost(Failure(ResponseDone()))

-        self.assertEqual(self.result.getvalue(), b"1234")
-        self.assertEqual(self.deferred.result, 4)
+        self.assertEqual(result.getvalue(), b"1234")
+        self.assertEqual(deferred.result, 4)

     def test_additional_data(self):
         """A connection can receive data after being closed."""
+        result, deferred, protocol = self._build_response()

         # Start sending data.
-        self.protocol.dataReceived(b"1234567890")
-        self.assertIsInstance(self.deferred.result, Failure)
-        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
-        self.protocol.transport.abortConnection.assert_called_once()
+        protocol.dataReceived(b"1234567890")
+        self._assert_error(deferred, protocol)

         # More data might have come in.
-        self.protocol.dataReceived(b"1234567890")
+        protocol.dataReceived(b"1234567890")

-        # Close the connection.
-        self.protocol.connectionLost(Failure(ResponseDone()))

-        self.assertEqual(self.result.getvalue(), b"1234567890")
-        self.assertIsInstance(self.deferred.result, Failure)
-        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
-        self._cleanup_error()
+        self.assertEqual(result.getvalue(), b"1234567890")
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)
+
+    def test_content_length(self):
+        """The body shouldn't be read (at all) if the Content-Length header is too large."""
+        result, deferred, protocol = self._build_response(length=10)
+
+        # Deferred shouldn't be called yet.
+        self.assertFalse(deferred.called)
+
+        # Start sending data.
+        protocol.dataReceived(b"12345")
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)
+
+        # The data is never consumed.
+        self.assertEqual(result.getvalue(), b"")
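As these tests exercise it, read_body_with_max_size streams an IResponse body into a writable buffer (the tests fish the delivery protocol out of the mocked response's deliverBody call) and errbacks with BodyExceededMaxSize once the limit, or an oversized Content-Length, trips. A hedged usage sketch follows; the import path is an assumption based on where Synapse's HTTP client helpers live, matching the names these tests use:

    from io import BytesIO

    from twisted.internet import defer

    # Import path assumed; the tests above use the same names.
    from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size


    @defer.inlineCallbacks
    def read_capped(response, max_size):
        """Stream an IResponse body into memory, failing once max_size is passed."""
        out = BytesIO()
        try:
            length = yield read_body_with_max_size(response, out, max_size)
        except BodyExceededMaxSize:
            # By now the helper has aborted the transport; `out` holds whatever
            # arrived before the limit tripped.
            return None
        return (length, out.getvalue())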

@@ -13,6 +13,8 @@

 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
+from unittest.mock import patch

 import treq
 from netaddr import IPSet
@@ -100,62 +102,36 @@ class MatrixFederationAgentTests(TestCase):
         return http_protocol

-    def test_http_request(self):
-        agent = ProxyAgent(self.reactor)
-        self.reactor.lookups["test.com"] = "1.2.3.4"
-        d = agent.request(b"GET", b"http://test.com")
+    def _test_request_direct_connection(self, agent, scheme, hostname, path):
+        """Runs a test case for a direct connection not going through a proxy.
+
+        Args:
+            agent (ProxyAgent): the proxy agent being tested
+            scheme (bytes): expected to be either "http" or "https"
+            hostname (bytes): the hostname to connect to in the test
+            path (bytes): the path to connect to in the test
+        """
+        is_https = scheme == b"https"
+
+        self.reactor.lookups[hostname.decode()] = "1.2.3.4"
+        d = agent.request(b"GET", scheme + b"://" + hostname + b"/" + path)

         # there should be a pending TCP connection
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
         self.assertEqual(host, "1.2.3.4")
-        self.assertEqual(port, 80)
-
-        # make a test server, and wire up the client
-        http_server = self._make_connection(
-            client_factory, _get_test_protocol_factory()
-        )
-
-        # the FakeTransport is async, so we need to pump the reactor
-        self.reactor.advance(0)
-
-        # now there should be a pending request
-        self.assertEqual(len(http_server.requests), 1)
-
-        request = http_server.requests[0]
-        self.assertEqual(request.method, b"GET")
-        self.assertEqual(request.path, b"/")
-        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
-        request.write(b"result")
-        request.finish()
-
-        self.reactor.advance(0)
-
-        resp = self.successResultOf(d)
-        body = self.successResultOf(treq.content(resp))
-        self.assertEqual(body, b"result")
-
-    def test_https_request(self):
-        agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
-        self.reactor.lookups["test.com"] = "1.2.3.4"
-        d = agent.request(b"GET", b"https://test.com/abc")
-
-        # there should be a pending TCP connection
-        clients = self.reactor.tcpClients
-        self.assertEqual(len(clients), 1)
-        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, "1.2.3.4")
-        self.assertEqual(port, 443)
+        self.assertEqual(port, 443 if is_https else 80)

         # make a test server, and wire up the client
         http_server = self._make_connection(
             client_factory,
             _get_test_protocol_factory(),
-            ssl=True,
-            expected_sni=b"test.com",
+            ssl=is_https,
+            expected_sni=hostname if is_https else None,
         )

         # the FakeTransport is async, so we need to pump the reactor
@@ -166,8 +142,8 @@ class MatrixFederationAgentTests(TestCase):
         request = http_server.requests[0]
         self.assertEqual(request.method, b"GET")
-        self.assertEqual(request.path, b"/abc")
-        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        self.assertEqual(request.path, b"/" + path)
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [hostname])
         request.write(b"result")
         request.finish()

@@ -177,8 +153,58 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")

+    def test_http_request(self):
+        agent = ProxyAgent(self.reactor)
+        self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+    def test_https_request(self):
+        agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
+
+    def test_http_request_use_proxy_empty_environment(self):
+        agent = ProxyAgent(self.reactor, use_proxy=True)
+        self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
+    def test_http_request_via_uppercase_no_proxy(self):
+        agent = ProxyAgent(self.reactor, use_proxy=True)
+        self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+    @patch.dict(
+        os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
+    )
+    def test_http_request_via_no_proxy(self):
+        agent = ProxyAgent(self.reactor, use_proxy=True)
+        self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+    @patch.dict(
+        os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
+    )
+    def test_https_request_via_no_proxy(self):
+        agent = ProxyAgent(
+            self.reactor,
+            contextFactory=get_test_https_policy(),
+            use_proxy=True,
+        )
+        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
+
+    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
+    def test_http_request_via_no_proxy_star(self):
+        agent = ProxyAgent(self.reactor, use_proxy=True)
+        self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+    @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
+    def test_https_request_via_no_proxy_star(self):
+        agent = ProxyAgent(
+            self.reactor,
+            contextFactory=get_test_https_policy(),
+            use_proxy=True,
+        )
+        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
+
+    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
     def test_http_request_via_proxy(self):
-        agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
+        agent = ProxyAgent(self.reactor, use_proxy=True)
         self.reactor.lookups["proxy.com"] = "1.2.3.5"
         d = agent.request(b"GET", b"http://test.com")

@@ -214,11 +240,12 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")

+    @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
     def test_https_request_via_proxy(self):
         agent = ProxyAgent(
             self.reactor,
             contextFactory=get_test_https_policy(),
-            https_proxy=b"proxy.com",
+            use_proxy=True,
         )
         self.reactor.lookups["proxy.com"] = "1.2.3.5"

@@ -294,6 +321,7 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")

+    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
     def test_http_request_via_proxy_with_blacklist(self):
         # The blacklist includes the configured proxy IP.
         agent = ProxyAgent(
@@ -301,7 +329,7 @@ class MatrixFederationAgentTests(TestCase):
                 self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
             ),
             self.reactor,
-            http_proxy=b"proxy.com:8888",
+            use_proxy=True,
         )
         self.reactor.lookups["proxy.com"] = "1.2.3.5"

@@ -338,7 +366,8 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")

-    def test_https_request_via_proxy_with_blacklist(self):
+    @patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
+    def test_https_request_via_uppercase_proxy_with_blacklist(self):
         # The blacklist includes the configured proxy IP.
         agent = ProxyAgent(
             BlacklistingReactorWrapper(
@@ -346,7 +375,7 @@ class MatrixFederationAgentTests(TestCase):
             ),
             self.reactor,
             contextFactory=get_test_https_policy(),
-            https_proxy=b"proxy.com",
+            use_proxy=True,
         )
         self.reactor.lookups["proxy.com"] = "1.2.3.5"
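The decorators above pin down the environment handling these tests expect: both lowercase and uppercase forms of http_proxy, https_proxy, and no_proxy are consulted, and no_proxy accepts a comma-separated host list or a bare * wildcard. A rough sketch of that lookup, using exact host matching for brevity (real implementations also match domain suffixes and ports); this is an illustration of the tested semantics, not Synapse's code:

    import os
    from typing import Optional


    def effective_proxy(scheme: str, host: str) -> Optional[str]:
        """Illustrative lookup of the proxy env vars exercised above."""
        # Check the lowercase form first, then the uppercase twin (a common
        # convention; assumed here, not taken from Synapse's implementation).
        proxy = os.environ.get(scheme.lower() + "_proxy") or os.environ.get(
            scheme.upper() + "_PROXY"
        )
        if not proxy:
            return None

        no_proxy = os.environ.get("no_proxy") or os.environ.get("NO_PROXY") or ""
        bypass = [h.strip() for h in no_proxy.split(",") if h.strip()]
        if "*" in bypass or host in bypass:
            return None  # connect directly, skipping the proxy
        return proxy


    # e.g. with http_proxy=proxy.com:8888 and no_proxy=test.com set:
    #   effective_proxy("http", "test.com")   -> None (direct connection)
    #   effective_proxy("http", "other.com")  -> "proxy.com:8888"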

@@ -189,5 +189,7 @@ commands=

 [testenv:mypy]
 deps =
     {[base]deps}
+    # Type hints are broken with Twisted > 20.3.0, see https://github.com/matrix-org/synapse/issues/9513
+    twisted==20.3.0
 extras = all,mypy
 commands = mypy