Add missing type hints to synapse.replication. (#11938)

Authored by Patrick Cloke on 2022-02-08 11:03:08 -05:00; committed by GitHub.
parent 8c94b3abe9
commit d0e78af35e
19 changed files with 209 additions and 147 deletions

changelog.d/11938.misc (new file)
View file

@@ -0,0 +1 @@
+Add missing type hints to replication code.

View file

@@ -169,6 +169,9 @@ disallow_untyped_defs = True
 [mypy-synapse.push.*]
 disallow_untyped_defs = True

+[mypy-synapse.replication.*]
+disallow_untyped_defs = True
+
 [mypy-synapse.rest.*]
 disallow_untyped_defs = True
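For context, the new section turns on mypy's strictest per-module check for the replication tree: with disallow_untyped_defs = True, any def lacking full annotations is rejected. A minimal sketch of the effect (module and function names invented for illustration):

# illustrative only; not part of the commit
def advance(instance_name, new_id):  # mypy: error: Function is missing a type annotation
    return new_id

def advance_typed(instance_name: str, new_id: int) -> int:  # accepted
    return new_id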

View file

@@ -40,7 +40,7 @@ class SlavedIdTracker(AbstractStreamIdTracker):
         for table, column in extra_tables:
             self.advance(None, _load_current_id(db_conn, table, column))

-    def advance(self, instance_name: Optional[str], new_id: int):
+    def advance(self, instance_name: Optional[str], new_id: int) -> None:
         self._current = (max if self.step > 0 else min)(self._current, new_id)

     def get_current_token(self) -> int:
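A side note on the one-liner in advance: SlavedIdTracker supports id streams that move in either direction, so it keeps the largest id seen for a positive step and the smallest for a negative one. The selection logic in isolation, outside Synapse:

def advance(current: int, new_id: int, step: int) -> int:
    # Forward streams (step > 0) keep the largest id seen;
    # backward streams (step < 0) keep the smallest.
    return (max if step > 0 else min)(current, new_id)

assert advance(10, 7, step=1) == 10
assert advance(-10, -7, step=-1) == -10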

View file

@@ -37,7 +37,9 @@ class SlavedClientIpStore(BaseSlavedStore):
             cache_name="client_ip_last_seen", max_size=50000
         )

-    async def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id):
+    async def insert_client_ip(
+        self, user_id: str, access_token: str, ip: str, user_agent: str, device_id: str
+    ) -> None:
         now = int(self._clock.time_msec())
         key = (user_id, access_token, ip)

View file

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable

 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -60,7 +60,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
     def get_device_stream_token(self) -> int:
         return self._device_list_id_gen.get_current_token()

-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == DeviceListsStream.NAME:
             self._device_list_id_gen.advance(instance_name, token)
             self._invalidate_caches_for_devices(token, rows)
@@ -70,7 +72,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedStore):
             self._user_signature_stream_cache.entity_has_changed(row.user_id, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)

-    def _invalidate_caches_for_devices(self, token, rows):
+    def _invalidate_caches_for_devices(
+        self, token: int, rows: Iterable[DeviceListsStream.DeviceListsStreamRow]
+    ) -> None:
         for row in rows:
             # The entities are either user IDs (starting with '@') whose devices
             # have changed, or remote servers that we need to tell about
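The same process_replication_rows override recurs in several slaved stores below. rows is typed Iterable[Any] because each replication stream carries a different row type; the method can only narrow it after matching stream_name. A stripped-down sketch of the pattern, with an invented store and stream name:

from typing import Any, Iterable

class ExampleSlavedStore:  # invented, for illustration only
    def process_replication_rows(
        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
    ) -> None:
        # Only after matching the stream name is the row type known,
        # so the shared signature stays Iterable[Any].
        if stream_name == "example_stream":  # invented stream name
            for row in rows:
                ...  # invalidate caches keyed on stream-specific row fields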

View file

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable

 from synapse.replication.slave.storage._base import BaseSlavedStore
 from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
@@ -44,10 +44,12 @@ class SlavedGroupServerStore(GroupServerWorkerStore, BaseSlavedStore):
             self._group_updates_id_gen.get_current_token(),
         )

-    def get_group_stream_token(self):
+    def get_group_stream_token(self) -> int:
         return self._group_updates_id_gen.get_current_token()

-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == GroupServerStream.NAME:
             self._group_updates_id_gen.advance(instance_name, token)
             for row in rows:

View file

@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Any, Iterable

 from synapse.replication.tcp.streams import PushRulesStream
 from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -20,10 +21,12 @@ from .events import SlavedEventStore

 class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
-    def get_max_push_rules_stream_id(self):
+    def get_max_push_rules_stream_id(self) -> int:
         return self._push_rules_stream_id_gen.get_current_token()

-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
+    ) -> None:
         if stream_name == PushRulesStream.NAME:
             self._push_rules_stream_id_gen.advance(instance_name, token)
             for row in rows:

View file

@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Iterable

 from synapse.replication.tcp.streams import PushersStream
 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
@@ -41,8 +41,8 @@ class SlavedPusherStore(PusherWorkerStore, BaseSlavedStore):
         return self._pushers_id_gen.get_current_token()

     def process_replication_rows(
-        self, stream_name: str, instance_name: str, token, rows
+        self, stream_name: str, instance_name: str, token: int, rows: Iterable[Any]
     ) -> None:
         if stream_name == PushersStream.NAME:
-            self._pushers_id_gen.advance(instance_name, token)  # type: ignore
+            self._pushers_id_gen.advance(instance_name, token)
         return super().process_replication_rows(stream_name, instance_name, token, rows)

View file

@@ -14,10 +14,12 @@
 """A replication client for use by synapse workers.
 """
 import logging
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple

 from twisted.internet.defer import Deferred
+from twisted.internet.interfaces import IAddress, IConnector
 from twisted.internet.protocol import ReconnectingClientFactory
+from twisted.python.failure import Failure

 from synapse.api.constants import EventTypes
 from synapse.federation import send_queue
@@ -79,10 +81,10 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
         hs.get_reactor().addSystemEventTrigger("before", "shutdown", self.stopTrying)

-    def startedConnecting(self, connector):
+    def startedConnecting(self, connector: IConnector) -> None:
         logger.info("Connecting to replication: %r", connector.getDestination())

-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> ClientReplicationStreamProtocol:
         logger.info("Connected to replication: %r", addr)
         return ClientReplicationStreamProtocol(
             self.hs,
@@ -92,11 +94,11 @@ class DirectTcpReplicationClientFactory(ReconnectingClientFactory):
             self.command_handler,
         )

-    def clientConnectionLost(self, connector, reason):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.error("Lost replication conn: %r", reason)
         ReconnectingClientFactory.clientConnectionLost(self, connector, reason)

-    def clientConnectionFailed(self, connector, reason):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.error("Failed to connect to replication: %r", reason)
         ReconnectingClientFactory.clientConnectionFailed(self, connector, reason)
@@ -131,7 +133,7 @@ class ReplicationDataHandler:

     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
-    ):
+    ) -> None:
         """Called to handle a batch of replication data with a given stream token.

         By default this just pokes the slave store. Can be overridden in subclasses to
@@ -252,14 +254,16 @@ class ReplicationDataHandler:
                 # loop. (This maintains the order so no need to resort)
                 waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]

-    async def on_position(self, stream_name: str, instance_name: str, token: int):
+    async def on_position(
+        self, stream_name: str, instance_name: str, token: int
+    ) -> None:
         await self.on_rdata(stream_name, instance_name, token, [])

         # We poke the generic "replication" notifier to wake anything up that
         # may be streaming.
         self.notifier.notify_replication()

-    def on_remote_server_up(self, server: str):
+    def on_remote_server_up(self, server: str) -> None:
         """Called when get a new REMOTE_SERVER_UP command."""

         # Let's wake up the transaction queue for the server in case we have
@@ -269,7 +273,7 @@ class ReplicationDataHandler:

     async def wait_for_stream_position(
         self, instance_name: str, stream_name: str, position: int
-    ):
+    ) -> None:
         """Wait until this instance has received updates up to and including
         the given stream position.
         """
@@ -304,7 +308,7 @@ class ReplicationDataHandler:
             "Finished waiting for repl stream %r to reach %s", stream_name, position
         )

-    def stop_pusher(self, user_id, app_id, pushkey):
+    def stop_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
         if not self._notify_pushers:
             return

@@ -316,13 +320,13 @@ class ReplicationDataHandler:
         logger.info("Stopping pusher %r / %r", user_id, key)
         pusher.on_stop()

-    async def start_pusher(self, user_id, app_id, pushkey):
+    async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
         if not self._notify_pushers:
             return

         key = "%s:%s" % (app_id, pushkey)
         logger.info("Starting pusher %r / %r", user_id, key)
-        return await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+        await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)


 class FederationSenderHandler:
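One subtlety in the start_pusher hunk above: once the method is annotated -> None, writing return await ... would hand back the pool's result and mypy would reject it, hence the switch to a bare await. The rule in isolation:

import asyncio

async def fetch() -> int:
    return 42

async def fire_and_forget() -> None:
    # "return await fetch()" here would fail type checking:
    # error: No return value expected
    await fetch()

asyncio.run(fire_and_forget())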
@@ -353,10 +357,12 @@ class FederationSenderHandler:

         self._fed_position_linearizer = Linearizer(name="_fed_position_linearizer")

-    def wake_destination(self, server: str):
+    def wake_destination(self, server: str) -> None:
         self.federation_sender.wake_destination(server)

-    async def process_replication_rows(self, stream_name, token, rows):
+    async def process_replication_rows(
+        self, stream_name: str, token: int, rows: list
+    ) -> None:
         # The federation stream contains things that we want to send out, e.g.
         # presence, typing, etc.
         if stream_name == "federation":
@@ -384,11 +390,12 @@ class FederationSenderHandler:
             for host in hosts:
                 self.federation_sender.send_device_messages(host)

-    async def _on_new_receipts(self, rows):
+    async def _on_new_receipts(
+        self, rows: Iterable[ReceiptsStream.ReceiptsStreamRow]
+    ) -> None:
         """
         Args:
-            rows (Iterable[synapse.replication.tcp.streams.ReceiptsStream.ReceiptsStreamRow]):
-                new receipts to be processed
+            rows: new receipts to be processed
         """
         for receipt in rows:
             # we only want to send on receipts for our own users
@@ -408,7 +415,7 @@ class FederationSenderHandler:
             )
             await self.federation_sender.send_read_receipt(receipt_info)

-    async def update_token(self, token):
+    async def update_token(self, token: int) -> None:
         """Update the record of where we have processed to in the federation stream.

         Called after we have processed a an update received over replication. Sends
@@ -428,7 +435,7 @@ class FederationSenderHandler:

             run_as_background_process("_save_and_send_ack", self._save_and_send_ack)

-    async def _save_and_send_ack(self):
+    async def _save_and_send_ack(self) -> None:
         """Save the current federation position in the database and send an ACK
         to master with where we're up to.
         """

View file

@@ -18,12 +18,15 @@ allowed to be sent by which side.
 """
 import abc
 import logging
-from typing import Tuple, Type
+from typing import Optional, Tuple, Type, TypeVar

+from synapse.replication.tcp.streams._base import StreamRow
 from synapse.util import json_decoder, json_encoder

 logger = logging.getLogger(__name__)

+T = TypeVar("T", bound="Command")
+

 class Command(metaclass=abc.ABCMeta):
     """The base command class.
@@ -38,7 +41,7 @@ class Command(metaclass=abc.ABCMeta):

     @classmethod
     @abc.abstractmethod
-    def from_line(cls, line):
+    def from_line(cls: Type[T], line: str) -> T:
         """Deserialises a line from the wire into this command. `line` does not
         include the command.
         """
@@ -49,21 +52,24 @@ class Command(metaclass=abc.ABCMeta):
         prefix.
         """

-    def get_logcontext_id(self):
+    def get_logcontext_id(self) -> str:
         """Get a suitable string for the logcontext when processing this command"""
         # by default, we just use the command name.
         return self.NAME


+SC = TypeVar("SC", bound="_SimpleCommand")
+
+
 class _SimpleCommand(Command):
     """An implementation of Command whose argument is just a 'data' string."""

-    def __init__(self, data):
+    def __init__(self, data: str):
         self.data = data

     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type[SC], line: str) -> SC:
         return cls(line)

     def to_line(self) -> str:
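The cls: Type[T] pattern introduced here is what lets mypy infer the concrete Command subclass from a from_line call, rather than the base class. A stripped-down sketch of the idea (class names invented, not Synapse's):

from typing import Type, TypeVar

T = TypeVar("T", bound="Message")

class Message:
    @classmethod
    def from_line(cls: Type[T], line: str) -> T:
        return cls()

class Ping(Message):
    pass

ping = Ping.from_line("")  # inferred as Ping, not Message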
@@ -109,14 +115,16 @@ class RdataCommand(Command):

     NAME = "RDATA"

-    def __init__(self, stream_name, instance_name, token, row):
+    def __init__(
+        self, stream_name: str, instance_name: str, token: Optional[int], row: StreamRow
+    ):
         self.stream_name = stream_name
         self.instance_name = instance_name
         self.token = token
         self.row = row

     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["RdataCommand"], line: str) -> "RdataCommand":
         stream_name, instance_name, token, row_json = line.split(" ", 3)
         return cls(
             stream_name,
@@ -125,7 +133,7 @@ class RdataCommand(Command):
             json_decoder.decode(row_json),
         )

-    def to_line(self):
+    def to_line(self) -> str:
         return " ".join(
             (
                 self.stream_name,
@@ -135,7 +143,7 @@ class RdataCommand(Command):
             )
         )

-    def get_logcontext_id(self):
+    def get_logcontext_id(self) -> str:
         return "RDATA-" + self.stream_name
@@ -164,18 +172,20 @@ class PositionCommand(Command):

     NAME = "POSITION"

-    def __init__(self, stream_name, instance_name, prev_token, new_token):
+    def __init__(
+        self, stream_name: str, instance_name: str, prev_token: int, new_token: int
+    ):
         self.stream_name = stream_name
         self.instance_name = instance_name
         self.prev_token = prev_token
         self.new_token = new_token

     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["PositionCommand"], line: str) -> "PositionCommand":
         stream_name, instance_name, prev_token, new_token = line.split(" ", 3)
         return cls(stream_name, instance_name, int(prev_token), int(new_token))

-    def to_line(self):
+    def to_line(self) -> str:
         return " ".join(
             (
                 self.stream_name,
@@ -218,14 +228,14 @@ class ReplicateCommand(Command):

     NAME = "REPLICATE"

-    def __init__(self):
+    def __init__(self) -> None:
         pass

     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type[T], line: str) -> T:
         return cls()

-    def to_line(self):
+    def to_line(self) -> str:
         return ""
@@ -247,14 +257,16 @@ class UserSyncCommand(Command):

     NAME = "USER_SYNC"

-    def __init__(self, instance_id, user_id, is_syncing, last_sync_ms):
+    def __init__(
+        self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
+    ):
         self.instance_id = instance_id
         self.user_id = user_id
         self.is_syncing = is_syncing
         self.last_sync_ms = last_sync_ms

     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["UserSyncCommand"], line: str) -> "UserSyncCommand":
         instance_id, user_id, state, last_sync_ms = line.split(" ", 3)

         if state not in ("start", "end"):
@@ -262,7 +274,7 @@ class UserSyncCommand(Command):

         return cls(instance_id, user_id, state == "start", int(last_sync_ms))

-    def to_line(self):
+    def to_line(self) -> str:
         return " ".join(
             (
                 self.instance_id,
@@ -286,14 +298,16 @@ class ClearUserSyncsCommand(Command):

     NAME = "CLEAR_USER_SYNC"

-    def __init__(self, instance_id):
+    def __init__(self, instance_id: str):
         self.instance_id = instance_id

     @classmethod
-    def from_line(cls, line):
+    def from_line(
+        cls: Type["ClearUserSyncsCommand"], line: str
+    ) -> "ClearUserSyncsCommand":
         return cls(line)

-    def to_line(self):
+    def to_line(self) -> str:
         return self.instance_id
@@ -316,7 +330,9 @@ class FederationAckCommand(Command):
         self.token = token

     @classmethod
-    def from_line(cls, line: str) -> "FederationAckCommand":
+    def from_line(
+        cls: Type["FederationAckCommand"], line: str
+    ) -> "FederationAckCommand":
         instance_name, token = line.split(" ")
         return cls(instance_name, int(token))
@@ -334,7 +350,15 @@ class UserIpCommand(Command):

     NAME = "USER_IP"

-    def __init__(self, user_id, access_token, ip, user_agent, device_id, last_seen):
+    def __init__(
+        self,
+        user_id: str,
+        access_token: str,
+        ip: str,
+        user_agent: str,
+        device_id: str,
+        last_seen: int,
+    ):
         self.user_id = user_id
         self.access_token = access_token
         self.ip = ip
@@ -343,14 +367,14 @@ class UserIpCommand(Command):
         self.last_seen = last_seen

     @classmethod
-    def from_line(cls, line):
+    def from_line(cls: Type["UserIpCommand"], line: str) -> "UserIpCommand":
         user_id, jsn = line.split(" ", 1)

         access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)

         return cls(user_id, access_token, ip, user_agent, device_id, last_seen)

-    def to_line(self):
+    def to_line(self) -> str:
         return (
             self.user_id
             + " "

View file

@@ -261,7 +261,7 @@ class ReplicationCommandHandler:
             "process-replication-data", self._unsafe_process_queue, stream_name
         )

-    async def _unsafe_process_queue(self, stream_name: str):
+    async def _unsafe_process_queue(self, stream_name: str) -> None:
         """Processes the command queue for the given stream, until it is empty

         Does not check if there is already a thread processing the queue, hence "unsafe"
@@ -294,7 +294,7 @@ class ReplicationCommandHandler:
             # This shouldn't be possible
             raise Exception("Unrecognised command %s in stream queue", cmd.NAME)

-    def start_replication(self, hs: "HomeServer"):
+    def start_replication(self, hs: "HomeServer") -> None:
         """Helper method to start a replication connection to the remote server
         using TCP.
         """
@@ -345,10 +345,10 @@ class ReplicationCommandHandler:
         """Get a list of streams that this instances replicates."""
         return self._streams_to_replicate

-    def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand):
+    def on_REPLICATE(self, conn: IReplicationConnection, cmd: ReplicateCommand) -> None:
         self.send_positions_to_connection(conn)

-    def send_positions_to_connection(self, conn: IReplicationConnection):
+    def send_positions_to_connection(self, conn: IReplicationConnection) -> None:
         """Send current position of all streams this process is source of to
         the connection.
         """
@@ -392,7 +392,7 @@ class ReplicationCommandHandler:

     def on_FEDERATION_ACK(
         self, conn: IReplicationConnection, cmd: FederationAckCommand
-    ):
+    ) -> None:
         federation_ack_counter.inc()

         if self._federation_sender:
@@ -408,7 +408,7 @@ class ReplicationCommandHandler:
         else:
             return None

-    async def _handle_user_ip(self, cmd: UserIpCommand):
+    async def _handle_user_ip(self, cmd: UserIpCommand) -> None:
         await self._store.insert_client_ip(
             cmd.user_id,
             cmd.access_token,
@@ -421,7 +421,7 @@ class ReplicationCommandHandler:
         assert self._server_notices_sender is not None
         await self._server_notices_sender.on_user_ip(cmd.user_id)

-    def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand):
+    def on_RDATA(self, conn: IReplicationConnection, cmd: RdataCommand) -> None:
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes
             return
@@ -497,7 +497,7 @@ class ReplicationCommandHandler:

     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
-    ):
+    ) -> None:
         """Called to handle a batch of replication data with a given stream token.

         Args:
@@ -512,7 +512,7 @@ class ReplicationCommandHandler:
             stream_name, instance_name, token, rows
         )

-    def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand):
+    def on_POSITION(self, conn: IReplicationConnection, cmd: PositionCommand) -> None:
         if cmd.instance_name == self._instance_name:
             # Ignore POSITION that are just our own echoes
             return
@@ -581,7 +581,7 @@ class ReplicationCommandHandler:

     def on_REMOTE_SERVER_UP(
         self, conn: IReplicationConnection, cmd: RemoteServerUpCommand
-    ):
+    ) -> None:
         """Called when get a new REMOTE_SERVER_UP command."""
         self._replication_data_handler.on_remote_server_up(cmd.data)
@@ -604,7 +604,7 @@ class ReplicationCommandHandler:
         # between two instances, but that is not currently supported).
         self.send_command(cmd, ignore_conn=conn)

-    def new_connection(self, connection: IReplicationConnection):
+    def new_connection(self, connection: IReplicationConnection) -> None:
         """Called when we have a new connection."""
         self._connections.append(connection)
@@ -631,7 +631,7 @@ class ReplicationCommandHandler:
                 UserSyncCommand(self._instance_id, user_id, True, now)
             )

-    def lost_connection(self, connection: IReplicationConnection):
+    def lost_connection(self, connection: IReplicationConnection) -> None:
         """Called when a connection is closed/lost."""
         # we no longer need _streams_by_connection for this connection.
         streams = self._streams_by_connection.pop(connection, None)
@@ -653,7 +653,7 @@ class ReplicationCommandHandler:

     def send_command(
         self, cmd: Command, ignore_conn: Optional[IReplicationConnection] = None
-    ):
+    ) -> None:
         """Send a command to all connected connections.

         Args:
@@ -680,7 +680,7 @@ class ReplicationCommandHandler:
         else:
             logger.warning("Dropping command as not connected: %r", cmd.NAME)

-    def send_federation_ack(self, token: int):
+    def send_federation_ack(self, token: int) -> None:
         """Ack data for the federation stream. This allows the master to drop
         data stored purely in memory.
         """
@@ -688,7 +688,7 @@ class ReplicationCommandHandler:

     def send_user_sync(
         self, instance_id: str, user_id: str, is_syncing: bool, last_sync_ms: int
-    ):
+    ) -> None:
         """Poke the master that a user has started/stopped syncing."""
         self.send_command(
             UserSyncCommand(instance_id, user_id, is_syncing, last_sync_ms)
@@ -702,15 +702,15 @@ class ReplicationCommandHandler:
         user_agent: str,
         device_id: str,
         last_seen: int,
-    ):
+    ) -> None:
         """Tell the master that the user made a request."""
         cmd = UserIpCommand(user_id, access_token, ip, user_agent, device_id, last_seen)
         self.send_command(cmd)

-    def send_remote_server_up(self, server: str):
+    def send_remote_server_up(self, server: str) -> None:
         self.send_command(RemoteServerUpCommand(server))

-    def stream_update(self, stream_name: str, token: str, data: Any):
+    def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
         """Called when a new update is available to stream to clients.

         We need to check if the client is interested in the stream or not

View file

@@ -49,7 +49,7 @@ import fcntl
 import logging
 import struct
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Collection, List, Optional
+from typing import TYPE_CHECKING, Any, Collection, List, Optional

 from prometheus_client import Counter
 from zope.interface import Interface, implementer
@@ -123,7 +123,7 @@ class ConnectionStates:
 class IReplicationConnection(Interface):
     """An interface for replication connections."""

-    def send_command(cmd: Command):
+    def send_command(cmd: Command) -> None:
         """Send the command down the connection"""
@@ -190,7 +190,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             "replication-conn", self.conn_id
         )

-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.info("[%s] Connection established", self.id())
         self.state = ConnectionStates.ESTABLISHED
@@ -207,11 +207,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):

         # Always send the initial PING so that the other side knows that they
         # can time us out.
-        self.send_command(PingCommand(self.clock.time_msec()))
+        self.send_command(PingCommand(str(self.clock.time_msec())))

         self.command_handler.new_connection(self)

-    def send_ping(self):
+    def send_ping(self) -> None:
         """Periodically sends a ping and checks if we should close the connection
         due to the other side timing out.
         """
@@ -226,7 +226,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 self.transport.abortConnection()
         else:
             if now - self.last_sent_command >= PING_TIME:
-                self.send_command(PingCommand(now))
+                self.send_command(PingCommand(str(now)))

             if (
                 self.received_ping
@@ -239,12 +239,12 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
                 )
                 self.send_error("ping timeout")

-    def lineReceived(self, line: bytes):
+    def lineReceived(self, line: bytes) -> None:
         """Called when we've received a line"""
         with PreserveLoggingContext(self._logging_context):
             self._parse_and_dispatch_line(line)

-    def _parse_and_dispatch_line(self, line: bytes):
+    def _parse_and_dispatch_line(self, line: bytes) -> None:
         if line.strip() == "":
             # Ignore blank lines
             return
@@ -309,24 +309,24 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         if not handled:
             logger.warning("Unhandled command: %r", cmd)

-    def close(self):
+    def close(self) -> None:
         logger.warning("[%s] Closing connection", self.id())

         self.time_we_closed = self.clock.time_msec()
         assert self.transport is not None
         self.transport.loseConnection()
         self.on_connection_closed()

-    def send_error(self, error_string, *args):
+    def send_error(self, error_string: str, *args: Any) -> None:
         """Send an error to remote and close the connection."""
         self.send_command(ErrorCommand(error_string % args))
         self.close()

-    def send_command(self, cmd, do_buffer=True):
+    def send_command(self, cmd: Command, do_buffer: bool = True) -> None:
         """Send a command if connection has been established.

         Args:
-            cmd (Command)
-            do_buffer (bool): Whether to buffer the message or always attempt
+            cmd
+            do_buffer: Whether to buffer the message or always attempt
                 to send the command. This is mostly used to send an error
                 message if we're about to close the connection due our buffers
                 becoming full.
@@ -357,7 +357,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):

         self.last_sent_command = self.clock.time_msec()

-    def _queue_command(self, cmd):
+    def _queue_command(self, cmd: Command) -> None:
         """Queue the command until the connection is ready to write to again."""
         logger.debug("[%s] Queueing as conn %r, cmd: %r", self.id(), self.state, cmd)
         self.pending_commands.append(cmd)
@@ -370,20 +370,20 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             self.send_command(ErrorCommand("Failed to keep up"), do_buffer=False)
             self.close()

-    def _send_pending_commands(self):
+    def _send_pending_commands(self) -> None:
         """Send any queued commandes"""
         pending = self.pending_commands
         self.pending_commands = []
         for cmd in pending:
             self.send_command(cmd)

-    def on_PING(self, line):
+    def on_PING(self, cmd: PingCommand) -> None:
         self.received_ping = True

-    def on_ERROR(self, cmd):
+    def on_ERROR(self, cmd: ErrorCommand) -> None:
         logger.error("[%s] Remote reported error: %r", self.id(), cmd.data)

-    def pauseProducing(self):
+    def pauseProducing(self) -> None:
         """This is called when both the kernel send buffer and the twisted
         tcp connection send buffers have become full.
@@ -394,26 +394,26 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
         logger.info("[%s] Pause producing", self.id())
         self.state = ConnectionStates.PAUSED

-    def resumeProducing(self):
+    def resumeProducing(self) -> None:
         """The remote has caught up after we started buffering!"""
         logger.info("[%s] Resume producing", self.id())
         self.state = ConnectionStates.ESTABLISHED
         self._send_pending_commands()

-    def stopProducing(self):
+    def stopProducing(self) -> None:
         """We're never going to send any more data (normally because either
         we or the remote has closed the connection)
         """
         logger.info("[%s] Stop producing", self.id())
         self.on_connection_closed()

-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Failure) -> None:  # type: ignore[override]
         logger.info("[%s] Replication connection closed: %r", self.id(), reason)
         if isinstance(reason, Failure):
             assert reason.type is not None
             connection_close_counter.labels(reason.type.__name__).inc()
         else:
-            connection_close_counter.labels(reason.__class__.__name__).inc()
+            connection_close_counter.labels(reason.__class__.__name__).inc()  # type: ignore[unreachable]

         try:
             # Remove us from list of connections to be monitored
@@ -427,7 +427,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):

         self.on_connection_closed()

-    def on_connection_closed(self):
+    def on_connection_closed(self) -> None:
         logger.info("[%s] Connection was closed", self.id())

         self.state = ConnectionStates.CLOSED
@@ -445,7 +445,7 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             # the sentinel context is now active, which may not be correct.
             # PreserveLoggingContext() will restore the correct logging context.

-    def __str__(self):
+    def __str__(self) -> str:
         addr = None
         if self.transport:
             addr = str(self.transport.getPeer())
@@ -455,10 +455,10 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
             addr,
         )

-    def id(self):
+    def id(self) -> str:
         return "%s-%s" % (self.name, self.conn_id)

-    def lineLengthExceeded(self, line):
+    def lineLengthExceeded(self, line: str) -> None:
         """Called when we receive a line that is above the maximum line length"""
         self.send_error("Line length exceeded")
@@ -474,11 +474,11 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):

         self.server_name = server_name

-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.send_command(ServerCommand(self.server_name))
         super().connectionMade()

-    def on_NAME(self, cmd):
+    def on_NAME(self, cmd: NameCommand) -> None:
         logger.info("[%s] Renamed to %r", self.id(), cmd.data)
         self.name = cmd.data
@@ -500,19 +500,19 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
         self.client_name = client_name
         self.server_name = server_name

-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.send_command(NameCommand(self.client_name))
         super().connectionMade()

         # Once we've connected subscribe to the necessary streams
         self.replicate()

-    def on_SERVER(self, cmd):
+    def on_SERVER(self, cmd: ServerCommand) -> None:
         if cmd.data != self.server_name:
             logger.error("[%s] Connected to wrong remote: %r", self.id(), cmd.data)
             self.send_error("Wrong remote")

-    def replicate(self):
+    def replicate(self) -> None:
         """Send the subscription request to the server"""
         logger.info("[%s] Subscribing to replication streams", self.id())
@@ -529,7 +529,7 @@ pending_commands = LaterGauge(
 )


-def transport_buffer_size(protocol):
+def transport_buffer_size(protocol: BaseReplicationStreamProtocol) -> int:
     if protocol.transport:
         size = len(protocol.transport.dataBuffer) + protocol.transport._tempDataLen
         return size
@@ -544,7 +544,9 @@ transport_send_buffer = LaterGauge(
 )


-def transport_kernel_read_buffer_size(protocol, read=True):
+def transport_kernel_read_buffer_size(
+    protocol: BaseReplicationStreamProtocol, read: bool = True
+) -> int:
     SIOCINQ = 0x541B
     SIOCOUTQ = 0x5411
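The PingCommand(str(now)) changes above fall out of the new annotations rather than a behaviour fix: PingCommand is a _SimpleCommand, whose __init__ now declares data: str, so the millisecond timestamp must be stringified before it is framed (previously %s-formatting stringified it implicitly). A sketch with a stand-in class:

class SimpleCommand:  # stand-in for _SimpleCommand, illustration only
    def __init__(self, data: str):
        self.data = data

    def to_line(self) -> str:
        return self.data

now_ms = 1644336188000
ping = SimpleCommand(str(now_ms))  # SimpleCommand(now_ms) would no longer type-check
assert ping.to_line() == "1644336188000"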

View file

@@ -14,7 +14,7 @@
 import logging
 from inspect import isawaitable
-from typing import TYPE_CHECKING, Generic, Optional, Type, TypeVar, cast
+from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast

 import attr
 import txredisapi
@@ -62,7 +62,7 @@ class ConstantProperty(Generic[T, V]):
     def __get__(self, obj: Optional[T], objtype: Optional[Type[T]] = None) -> V:
         return self.constant

-    def __set__(self, obj: Optional[T], value: V):
+    def __set__(self, obj: Optional[T], value: V) -> None:
         pass
@@ -95,7 +95,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
     synapse_stream_name: str
     synapse_outbound_redis_connection: txredisapi.RedisProtocol

-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)

         # a logcontext which we use for processing incoming commands. We declare it as a
@@ -108,12 +108,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
             "replication_command_handler"
         )

-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.info("Connected to redis")
         super().connectionMade()
         run_as_background_process("subscribe-replication", self._send_subscribe)

-    async def _send_subscribe(self):
+    async def _send_subscribe(self) -> None:
         # it's important to make sure that we only send the REPLICATE command once we
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
@@ -131,12 +131,12 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
         # otherside won't know we've connected and so won't issue a REPLICATE.
         self.synapse_handler.send_positions_to_connection(self)

-    def messageReceived(self, pattern: str, channel: str, message: str):
+    def messageReceived(self, pattern: str, channel: str, message: str) -> None:
         """Received a message from redis."""
         with PreserveLoggingContext(self._logging_context):
             self._parse_and_dispatch_message(message)

-    def _parse_and_dispatch_message(self, message: str):
+    def _parse_and_dispatch_message(self, message: str) -> None:
         if message.strip() == "":
             # Ignore blank lines
             return
@@ -181,7 +181,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
             "replication-" + cmd.get_logcontext_id(), lambda: res
         )

-    def connectionLost(self, reason):
+    def connectionLost(self, reason: Failure) -> None:  # type: ignore[override]
         logger.info("Lost connection to redis")
         super().connectionLost(reason)
         self.synapse_handler.lost_connection(self)
@@ -193,17 +193,17 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
         # the sentinel context is now active, which may not be correct.
         # PreserveLoggingContext() will restore the correct logging context.

-    def send_command(self, cmd: Command):
+    def send_command(self, cmd: Command) -> None:
         """Send a command if connection has been established.

         Args:
-            cmd (Command)
+            cmd: The command to send
         """
         run_as_background_process(
             "send-cmd", self._async_send_command, cmd, bg_start_span=False
         )

-    async def _async_send_command(self, cmd: Command):
+    async def _async_send_command(self, cmd: Command) -> None:
         """Encode a replication command and send it over our outbound connection"""
         string = "%s %s" % (cmd.NAME, cmd.to_line())
         if "\n" in string:
@@ -259,7 +259,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
         hs.get_clock().looping_call(self._send_ping, 30 * 1000)

     @wrap_as_background_process("redis_ping")
-    async def _send_ping(self):
+    async def _send_ping(self) -> None:
         for connection in self.pool:
             try:
                 await make_deferred_yieldable(connection.ping())
@@ -269,13 +269,13 @@ class SynapseRedisFactory(txredisapi.RedisFactory):

     # ReconnectingClientFactory has some logging (if you enable `self.noisy`), but
     # it's rubbish. We add our own here.

-    def startedConnecting(self, connector: IConnector):
+    def startedConnecting(self, connector: IConnector) -> None:
         logger.info(
             "Connecting to redis server %s", format_address(connector.getDestination())
         )
         super().startedConnecting(connector)

-    def clientConnectionFailed(self, connector: IConnector, reason: Failure):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.info(
             "Connection to redis server %s failed: %s",
             format_address(connector.getDestination()),
@@ -283,7 +283,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
         )
         super().clientConnectionFailed(connector, reason)

-    def clientConnectionLost(self, connector: IConnector, reason: Failure):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.info(
             "Connection to redis server %s lost: %s",
             format_address(connector.getDestination()),
@@ -330,7 +330,7 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):

         self.synapse_outbound_redis_connection = outbound_redis_connection

-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> RedisSubscriber:
         p = super().buildProtocol(addr)
         p = cast(RedisSubscriber, p)

View file

@@ -16,16 +16,18 @@
 import logging
 import random
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, List, Optional, Tuple

 from prometheus_client import Counter

+from twisted.internet.interfaces import IAddress
 from twisted.internet.protocol import ServerFactory

 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.commands import PositionCommand
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
 from synapse.replication.tcp.streams import EventsStream
+from synapse.replication.tcp.streams._base import StreamRow, Token
 from synapse.util.metrics import Measure

 if TYPE_CHECKING:
@@ -56,7 +58,7 @@ class ReplicationStreamProtocolFactory(ServerFactory):
         # listener config again or always starting a `ReplicationStreamer`.)
         hs.get_replication_streamer()

-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol:
         return ServerReplicationStreamProtocol(
             self.server_name, self.clock, self.command_handler
         )
@@ -105,7 +107,7 @@ class ReplicationStreamer:
         if any(EventsStream.NAME == s.NAME for s in self.streams):
             self.clock.looping_call(self.on_notifier_poke, 1000)

-    def on_notifier_poke(self):
+    def on_notifier_poke(self) -> None:
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -137,7 +139,7 @@ class ReplicationStreamer:

         run_as_background_process("replication_notifier", self._run_notifier_loop)

-    async def _run_notifier_loop(self):
+    async def _run_notifier_loop(self) -> None:
         self.is_looping = True

         try:
@@ -238,7 +240,9 @@ class ReplicationStreamer:
             self.is_looping = False


-def _batch_updates(updates):
+def _batch_updates(
+    updates: List[Tuple[Token, StreamRow]]
+) -> List[Tuple[Optional[Token], StreamRow]]:
     """Takes a list of updates of form [(token, row)] and sets the token to
     None for all rows where the next row has the same token. This is used to
     implement batching.
@@ -254,7 +258,7 @@ def _batch_updates(updates):
     if not updates:
         return []

-    new_updates = []
+    new_updates: List[Tuple[Optional[Token], StreamRow]] = []
     for i, update in enumerate(updates[:-1]):
         if update[0] == updates[i + 1][0]:
             new_updates.append((None, update[1]))
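The _batch_updates docstring above pins down the batching contract; a standalone re-implementation makes the input/output shape concrete (this mirrors the documented behaviour, not necessarily the exact Synapse source):

from typing import List, Optional, Tuple

def batch_updates(
    updates: List[Tuple[int, list]]
) -> List[Tuple[Optional[int], list]]:
    # None means "same token as the next row", letting consumers treat
    # consecutive rows as one batch that ends at the next real token.
    if not updates:
        return []
    out: List[Tuple[Optional[int], list]] = []
    for i, (token, row) in enumerate(updates[:-1]):
        out.append((None, row) if token == updates[i + 1][0] else (token, row))
    out.append(updates[-1])
    return out

assert batch_updates([(1, ["a"]), (1, ["b"]), (2, ["c"])]) == [
    (None, ["a"]),
    (1, ["b"]),
    (2, ["c"]),
]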

View file

@@ -90,7 +90,7 @@ class Stream:
     ROW_TYPE: Any = None

     @classmethod
-    def parse_row(cls, row: StreamRow):
+    def parse_row(cls, row: StreamRow) -> Any:
         """Parse a row received over replication

         By default, assumes that the row data is an array object and passes its contents
@@ -139,7 +139,7 @@ class Stream:
         # The token from which we last asked for updates
         self.last_token = self.current_token(self.local_instance_name)

-    def discard_updates_and_advance(self):
+    def discard_updates_and_advance(self) -> None:
         """Called when the stream should advance but the updates would be discarded,
         e.g. when there are no currently connected workers.
         """
@@ -200,7 +200,7 @@ def current_token_without_instance(
     return lambda instance_name: current_token()


-def make_http_update_function(hs, stream_name: str) -> UpdateFunction:
+def make_http_update_function(hs: "HomeServer", stream_name: str) -> UpdateFunction:
     """Makes a suitable function for use as an `update_function` that queries
     the master process for updates.
     """

View file

@@ -13,12 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import heapq
-from collections.abc import Iterable
-from typing import TYPE_CHECKING, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Type, TypeVar, cast

 import attr

-from ._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.streams._base import (
+    Stream,
+    StreamRow,
+    StreamUpdateResult,
+    Token,
+)

 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -58,6 +62,9 @@ class EventsStreamRow:
     data: "BaseEventsStreamRow"


+T = TypeVar("T", bound="BaseEventsStreamRow")
+
+
 class BaseEventsStreamRow:
     """Base class for rows to be sent in the events stream.
@@ -68,7 +75,7 @@ class BaseEventsStreamRow:
     TypeId: str

     @classmethod
-    def from_data(cls, data):
+    def from_data(cls: Type[T], data: Iterable[Optional[str]]) -> T:
         """Parse the data from the replication stream into a row.

         By default we just call the constructor with the data list as arguments
@@ -221,7 +228,7 @@ class EventsStream(Stream):
         return updates, upper_limit, limited

     @classmethod
-    def parse_row(cls, row):
-        (typ, data) = row
-        data = TypeToRow[typ].from_data(data)
-        return EventsStreamRow(typ, data)
+    def parse_row(cls, row: StreamRow) -> "EventsStreamRow":
+        (typ, data) = cast(Tuple[str, Iterable[Optional[str]]], row)
+        event_stream_row_data = TypeToRow[typ].from_data(data)
+        return EventsStreamRow(typ, event_stream_row_data)

View file

@@ -16,8 +16,7 @@ import itertools
 import re
 import secrets
 import string
-from collections.abc import Iterable
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple

 from netaddr import valid_ipv6
@@ -197,7 +196,7 @@ def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
     """If iterable has maxitems or fewer, return the stringification of a list
     containing those items.

-    Otherwise, return the stringification of a a list with the first maxitems items,
+    Otherwise, return the stringification of a list with the first maxitems items,
     followed by "...".

     Args:
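Based on that docstring, the behaviour of shortstr can be pinned down with a quick sketch (the exact output format of the real helper may differ slightly; this follows the documented contract):

from itertools import islice
from typing import Iterable

def shortstr(iterable: Iterable, maxitems: int = 5) -> str:
    # Pull one extra item so "exactly maxitems" and "more" are distinguishable.
    items = list(islice(iterable, maxitems + 1))
    if len(items) <= maxitems:
        return str(items)
    return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"

print(shortstr(range(3)))   # [0, 1, 2]
print(shortstr(range(10)))  # [0, 1, 2, 3, 4, ...]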

View file

@@ -14,6 +14,7 @@
 import logging
 from typing import Any, Dict, List, Optional, Tuple

+from twisted.internet.address import IPv4Address
 from twisted.internet.protocol import Protocol
 from twisted.web.resource import Resource
@@ -53,7 +54,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
         self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
-            None
+            IPv4Address("TCP", "127.0.0.1", 0)
         )

         # Make a new HomeServer object for the worker
@@ -345,7 +346,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             self.clock,
             repl_handler,
         )

-        server = self.server_factory.buildProtocol(None)
+        server = self.server_factory.buildProtocol(
+            IPv4Address("TCP", "127.0.0.1", 0)
+        )

         client_transport = FakeTransport(server, self.reactor)
         client.makeConnection(client_transport)

View file

@@ -14,6 +14,7 @@

 from typing import Tuple

+from twisted.internet.address import IPv4Address
 from twisted.internet.interfaces import IProtocol
 from twisted.test.proto_helpers import StringTransport
@@ -29,7 +30,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):

     def _make_client(self) -> Tuple[IProtocol, StringTransport]:
         """Create a new direct TCP replication connection"""

-        proto = self.factory.buildProtocol(("127.0.0.1", 0))
+        proto = self.factory.buildProtocol(IPv4Address("TCP", "127.0.0.1", 0))

         transport = StringTransport()
         proto.makeConnection(transport)