Fixup synapse.replication to pass mypy checks (#6667)

This commit is contained in:
Erik Johnston 2020-01-14 14:08:06 +00:00 committed by GitHub
parent 1177d3f3a3
commit e8b68a4e4b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 105 additions and 86 deletions

1
changelog.d/6667.misc Normal file
View file

@ -0,0 +1 @@
Fixup `synapse.replication` to pass mypy checks.

View file

@ -16,6 +16,7 @@
import abc import abc
import logging import logging
import re import re
from typing import Dict, List, Tuple
from six import raise_from from six import raise_from
from six.moves import urllib from six.moves import urllib
@ -78,9 +79,8 @@ class ReplicationEndpoint(object):
__metaclass__ = abc.ABCMeta __metaclass__ = abc.ABCMeta
NAME = abc.abstractproperty() NAME = abc.abstractproperty() # type: str # type: ignore
PATH_ARGS = abc.abstractproperty() PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
METHOD = "POST" METHOD = "POST"
CACHE = True CACHE = True
RETRY_ON_TIMEOUT = True RETRY_ON_TIMEOUT = True
@ -171,7 +171,7 @@ class ReplicationEndpoint(object):
# have a good idea that the request has either succeeded or failed on # have a good idea that the request has either succeeded or failed on
# the master, and so whether we should clean up or not. # the master, and so whether we should clean up or not.
while True: while True:
headers = {} headers = {} # type: Dict[bytes, List[bytes]]
inject_active_span_byte_dict(headers, None, check_destination=False) inject_active_span_byte_dict(headers, None, check_destination=False)
try: try:
result = yield request_func(uri, data, headers=headers) result = yield request_func(uri, data, headers=headers)
@ -207,7 +207,7 @@ class ReplicationEndpoint(object):
method = self.METHOD method = self.METHOD
if self.CACHE: if self.CACHE:
handler = self._cached_handler handler = self._cached_handler # type: ignore
url_args.append("txn_id") url_args.append("txn_id")
args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args) args = "/".join("(?P<%s>[^/]+)" % (arg,) for arg in url_args)

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict from typing import Dict, Optional
import six import six
@ -41,7 +41,7 @@ class BaseSlavedStore(SQLBaseStore):
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
self._cache_id_gen = SlavedIdTracker( self._cache_id_gen = SlavedIdTracker(
db_conn, "cache_invalidation_stream", "stream_id" db_conn, "cache_invalidation_stream", "stream_id"
) ) # type: Optional[SlavedIdTracker]
else: else:
self._cache_id_gen = None self._cache_id_gen = None
@ -62,7 +62,8 @@ class BaseSlavedStore(SQLBaseStore):
def process_replication_rows(self, stream_name, token, rows): def process_replication_rows(self, stream_name, token, rows):
if stream_name == "caches": if stream_name == "caches":
self._cache_id_gen.advance(token) if self._cache_id_gen:
self._cache_id_gen.advance(token)
for row in rows: for row in rows:
if row.cache_func == CURRENT_STATE_CACHE_NAME: if row.cache_func == CURRENT_STATE_CACHE_NAME:
room_id = row.keys[0] room_id = row.keys[0]

View file

@ -29,7 +29,7 @@ class SlavedPresenceStore(BaseSlavedStore):
self._presence_on_startup = self._get_active_presence(db_conn) self._presence_on_startup = self._get_active_presence(db_conn)
self.presence_stream_cache = self.presence_stream_cache = StreamChangeCache( self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", self._presence_id_gen.get_current_token() "PresenceStreamChangeCache", self._presence_id_gen.get_current_token()
) )

View file

@ -16,7 +16,7 @@
""" """
import logging import logging
from typing import Dict from typing import Dict, List, Optional
from twisted.internet import defer from twisted.internet import defer
from twisted.internet.protocol import ReconnectingClientFactory from twisted.internet.protocol import ReconnectingClientFactory
@ -28,6 +28,7 @@ from synapse.replication.tcp.protocol import (
) )
from .commands import ( from .commands import (
Command,
FederationAckCommand, FederationAckCommand,
InvalidateCacheCommand, InvalidateCacheCommand,
RemovePusherCommand, RemovePusherCommand,
@ -89,15 +90,15 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
# Any pending commands to be sent once a new connection has been # Any pending commands to be sent once a new connection has been
# established # established
self.pending_commands = [] self.pending_commands = [] # type: List[Command]
# Map from string -> deferred, to wake up when receiveing a SYNC with # Map from string -> deferred, to wake up when receiveing a SYNC with
# the given string. # the given string.
# Used for tests. # Used for tests.
self.awaiting_syncs = {} self.awaiting_syncs = {} # type: Dict[str, defer.Deferred]
# The factory used to create connections. # The factory used to create connections.
self.factory = None self.factory = None # type: Optional[ReplicationClientFactory]
def start_replication(self, hs): def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server """Helper method to start a replication connection to the remote server
@ -235,4 +236,5 @@ class ReplicationClientHandler(AbstractReplicationClientHandler):
# We don't reset the delay any earlier as otherwise if there is a # We don't reset the delay any earlier as otherwise if there is a
# problem during start up we'll end up tight looping connecting to the # problem during start up we'll end up tight looping connecting to the
# server. # server.
self.factory.resetDelay() if self.factory:
self.factory.resetDelay()

View file

@ -20,15 +20,16 @@ allowed to be sent by which side.
import logging import logging
import platform import platform
from typing import Tuple, Type
if platform.python_implementation() == "PyPy": if platform.python_implementation() == "PyPy":
import json import json
_json_encoder = json.JSONEncoder() _json_encoder = json.JSONEncoder()
else: else:
import simplejson as json import simplejson as json # type: ignore[no-redef] # noqa: F821
_json_encoder = json.JSONEncoder(namedtuple_as_object=False) _json_encoder = json.JSONEncoder(namedtuple_as_object=False) # type: ignore[call-arg] # noqa: F821
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -44,7 +45,7 @@ class Command(object):
The default implementation creates a command of form `<NAME> <data>` The default implementation creates a command of form `<NAME> <data>`
""" """
NAME = None NAME = None # type: str
def __init__(self, data): def __init__(self, data):
self.data = data self.data = data
@ -386,25 +387,24 @@ class UserIpCommand(Command):
) )
_COMMANDS = (
ServerCommand,
RdataCommand,
PositionCommand,
ErrorCommand,
PingCommand,
NameCommand,
ReplicateCommand,
UserSyncCommand,
FederationAckCommand,
SyncCommand,
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
) # type: Tuple[Type[Command], ...]
# Map of command name to command type. # Map of command name to command type.
COMMAND_MAP = { COMMAND_MAP = {cmd.NAME: cmd for cmd in _COMMANDS}
cmd.NAME: cmd
for cmd in (
ServerCommand,
RdataCommand,
PositionCommand,
ErrorCommand,
PingCommand,
NameCommand,
ReplicateCommand,
UserSyncCommand,
FederationAckCommand,
SyncCommand,
RemovePusherCommand,
InvalidateCacheCommand,
UserIpCommand,
)
}
# The commands the server is allowed to send # The commands the server is allowed to send
VALID_SERVER_COMMANDS = ( VALID_SERVER_COMMANDS = (

View file

@ -53,6 +53,7 @@ import fcntl
import logging import logging
import struct import struct
from collections import defaultdict from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Tuple
from six import iteritems, iterkeys from six import iteritems, iterkeys
@ -65,13 +66,11 @@ from twisted.python.failure import Failure
from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics import LaterGauge from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.util import Clock from synapse.replication.tcp.commands import (
from synapse.util.stringutils import random_string
from .commands import (
COMMAND_MAP, COMMAND_MAP,
VALID_CLIENT_COMMANDS, VALID_CLIENT_COMMANDS,
VALID_SERVER_COMMANDS, VALID_SERVER_COMMANDS,
Command,
ErrorCommand, ErrorCommand,
NameCommand, NameCommand,
PingCommand, PingCommand,
@ -82,6 +81,10 @@ from .commands import (
SyncCommand, SyncCommand,
UserSyncCommand, UserSyncCommand,
) )
from synapse.types import Collection
from synapse.util import Clock
from synapse.util.stringutils import random_string
from .streams import STREAMS_MAP from .streams import STREAMS_MAP
connection_close_counter = Counter( connection_close_counter = Counter(
@ -124,8 +127,11 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
delimiter = b"\n" delimiter = b"\n"
VALID_INBOUND_COMMANDS = [] # Valid commands we expect to receive # Valid commands we expect to receive
VALID_OUTBOUND_COMMANDS = [] # Valid commans we can send VALID_INBOUND_COMMANDS = [] # type: Collection[str]
# Valid commands we can send
VALID_OUTBOUND_COMMANDS = [] # type: Collection[str]
max_line_buffer = 10000 max_line_buffer = 10000
@ -144,13 +150,13 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver):
self.conn_id = random_string(5) # To dedupe in case of name clashes. self.conn_id = random_string(5) # To dedupe in case of name clashes.
# List of pending commands to send once we've established the connection # List of pending commands to send once we've established the connection
self.pending_commands = [] self.pending_commands = [] # type: List[Command]
# The LoopingCall for sending pings. # The LoopingCall for sending pings.
self._send_ping_loop = None self._send_ping_loop = None
self.inbound_commands_counter = defaultdict(int) self.inbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
self.outbound_commands_counter = defaultdict(int) self.outbound_commands_counter = defaultdict(int) # type: DefaultDict[str, int]
def connectionMade(self): def connectionMade(self):
logger.info("[%s] Connection established", self.id()) logger.info("[%s] Connection established", self.id())
@ -409,14 +415,14 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol):
self.streamer = streamer self.streamer = streamer
# The streams the client has subscribed to and is up to date with # The streams the client has subscribed to and is up to date with
self.replication_streams = set() self.replication_streams = set() # type: Set[str]
# The streams the client is currently subscribing to. # The streams the client is currently subscribing to.
self.connecting_streams = set() self.connecting_streams = set() # type: Set[str]
# Map from stream name to list of updates to send once we've finished # Map from stream name to list of updates to send once we've finished
# subscribing the client to the stream. # subscribing the client to the stream.
self.pending_rdata = {} self.pending_rdata = {} # type: Dict[str, List[Tuple[int, Any]]]
def connectionMade(self): def connectionMade(self):
self.send_command(ServerCommand(self.server_name)) self.send_command(ServerCommand(self.server_name))
@ -642,11 +648,11 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
# Set of stream names that have been subscribe to, but haven't yet # Set of stream names that have been subscribe to, but haven't yet
# caught up with. This is used to track when the client has been fully # caught up with. This is used to track when the client has been fully
# connected to the remote. # connected to the remote.
self.streams_connecting = set() self.streams_connecting = set() # type: Set[str]
# Map of stream to batched updates. See RdataCommand for info on how # Map of stream to batched updates. See RdataCommand for info on how
# batching works. # batching works.
self.pending_batches = {} self.pending_batches = {} # type: Dict[str, Any]
def connectionMade(self): def connectionMade(self):
self.send_command(NameCommand(self.client_name)) self.send_command(NameCommand(self.client_name))
@ -766,7 +772,7 @@ def transport_kernel_read_buffer_size(protocol, read=True):
op = SIOCINQ op = SIOCINQ
else: else:
op = SIOCOUTQ op = SIOCOUTQ
size = struct.unpack("I", fcntl.ioctl(fileno, op, "\0\0\0\0"))[0] size = struct.unpack("I", fcntl.ioctl(fileno, op, b"\0\0\0\0"))[0]
return size return size
return 0 return 0

View file

@ -17,6 +17,7 @@
import logging import logging
import random import random
from typing import List
from six import itervalues from six import itervalues
@ -79,7 +80,7 @@ class ReplicationStreamer(object):
self._replication_torture_level = hs.config.replication_torture_level self._replication_torture_level = hs.config.replication_torture_level
# Current connections. # Current connections.
self.connections = [] self.connections = [] # type: List[ServerReplicationStreamProtocol]
LaterGauge( LaterGauge(
"synapse_replication_tcp_resource_total_connections", "synapse_replication_tcp_resource_total_connections",

View file

@ -14,10 +14,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import itertools import itertools
import logging import logging
from collections import namedtuple from collections import namedtuple
from typing import Any
from twisted.internet import defer from twisted.internet import defer
@ -104,8 +104,9 @@ class Stream(object):
time it was called up until the point `advance_current_token` was called. time it was called up until the point `advance_current_token` was called.
""" """
NAME = None # The name of the stream NAME = None # type: str # The name of the stream
ROW_TYPE = None # The type of the row. Used by the default impl of parse_row. # The type of the row. Used by the default impl of parse_row.
ROW_TYPE = None # type: Any
_LIMITED = True # Whether the update function takes a limit _LIMITED = True # Whether the update function takes a limit
@classmethod @classmethod
@ -231,8 +232,8 @@ class BackfillStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_current_backfill_token self.current_token = store.get_current_backfill_token # type: ignore
self.update_function = store.get_all_new_backfill_event_rows self.update_function = store.get_all_new_backfill_event_rows # type: ignore
super(BackfillStream, self).__init__(hs) super(BackfillStream, self).__init__(hs)
@ -246,8 +247,8 @@ class PresenceStream(Stream):
store = hs.get_datastore() store = hs.get_datastore()
presence_handler = hs.get_presence_handler() presence_handler = hs.get_presence_handler()
self.current_token = store.get_current_presence_token self.current_token = store.get_current_presence_token # type: ignore
self.update_function = presence_handler.get_all_presence_updates self.update_function = presence_handler.get_all_presence_updates # type: ignore
super(PresenceStream, self).__init__(hs) super(PresenceStream, self).__init__(hs)
@ -260,8 +261,8 @@ class TypingStream(Stream):
def __init__(self, hs): def __init__(self, hs):
typing_handler = hs.get_typing_handler() typing_handler = hs.get_typing_handler()
self.current_token = typing_handler.get_current_token self.current_token = typing_handler.get_current_token # type: ignore
self.update_function = typing_handler.get_all_typing_updates self.update_function = typing_handler.get_all_typing_updates # type: ignore
super(TypingStream, self).__init__(hs) super(TypingStream, self).__init__(hs)
@ -273,8 +274,8 @@ class ReceiptsStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_max_receipt_stream_id self.current_token = store.get_max_receipt_stream_id # type: ignore
self.update_function = store.get_all_updated_receipts self.update_function = store.get_all_updated_receipts # type: ignore
super(ReceiptsStream, self).__init__(hs) super(ReceiptsStream, self).__init__(hs)
@ -310,8 +311,8 @@ class PushersStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_pushers_stream_token self.current_token = store.get_pushers_stream_token # type: ignore
self.update_function = store.get_all_updated_pushers_rows self.update_function = store.get_all_updated_pushers_rows # type: ignore
super(PushersStream, self).__init__(hs) super(PushersStream, self).__init__(hs)
@ -327,8 +328,8 @@ class CachesStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_cache_stream_token self.current_token = store.get_cache_stream_token # type: ignore
self.update_function = store.get_all_updated_caches self.update_function = store.get_all_updated_caches # type: ignore
super(CachesStream, self).__init__(hs) super(CachesStream, self).__init__(hs)
@ -343,8 +344,8 @@ class PublicRoomsStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_current_public_room_stream_id self.current_token = store.get_current_public_room_stream_id # type: ignore
self.update_function = store.get_all_new_public_rooms self.update_function = store.get_all_new_public_rooms # type: ignore
super(PublicRoomsStream, self).__init__(hs) super(PublicRoomsStream, self).__init__(hs)
@ -360,8 +361,8 @@ class DeviceListsStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_device_stream_token self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_device_list_changes_for_remotes self.update_function = store.get_all_device_list_changes_for_remotes # type: ignore
super(DeviceListsStream, self).__init__(hs) super(DeviceListsStream, self).__init__(hs)
@ -376,8 +377,8 @@ class ToDeviceStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_to_device_stream_token self.current_token = store.get_to_device_stream_token # type: ignore
self.update_function = store.get_all_new_device_messages self.update_function = store.get_all_new_device_messages # type: ignore
super(ToDeviceStream, self).__init__(hs) super(ToDeviceStream, self).__init__(hs)
@ -392,8 +393,8 @@ class TagAccountDataStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_max_account_data_stream_id self.current_token = store.get_max_account_data_stream_id # type: ignore
self.update_function = store.get_all_updated_tags self.update_function = store.get_all_updated_tags # type: ignore
super(TagAccountDataStream, self).__init__(hs) super(TagAccountDataStream, self).__init__(hs)
@ -408,7 +409,7 @@ class AccountDataStream(Stream):
def __init__(self, hs): def __init__(self, hs):
self.store = hs.get_datastore() self.store = hs.get_datastore()
self.current_token = self.store.get_max_account_data_stream_id self.current_token = self.store.get_max_account_data_stream_id # type: ignore
super(AccountDataStream, self).__init__(hs) super(AccountDataStream, self).__init__(hs)
@ -434,8 +435,8 @@ class GroupServerStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_group_stream_token self.current_token = store.get_group_stream_token # type: ignore
self.update_function = store.get_all_groups_changes self.update_function = store.get_all_groups_changes # type: ignore
super(GroupServerStream, self).__init__(hs) super(GroupServerStream, self).__init__(hs)
@ -451,7 +452,7 @@ class UserSignatureStream(Stream):
def __init__(self, hs): def __init__(self, hs):
store = hs.get_datastore() store = hs.get_datastore()
self.current_token = store.get_device_stream_token self.current_token = store.get_device_stream_token # type: ignore
self.update_function = store.get_all_user_signature_changes_for_remotes self.update_function = store.get_all_user_signature_changes_for_remotes # type: ignore
super(UserSignatureStream, self).__init__(hs) super(UserSignatureStream, self).__init__(hs)

View file

@ -13,7 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import heapq import heapq
from typing import Tuple, Type
import attr import attr
@ -63,7 +65,8 @@ class BaseEventsStreamRow(object):
Specifies how to identify, serialize and deserialize the different types. Specifies how to identify, serialize and deserialize the different types.
""" """
TypeId = None # Unique string that ids the type. Must be overriden in sub classes. # Unique string that ids the type. Must be overriden in sub classes.
TypeId = None # type: str
@classmethod @classmethod
def from_data(cls, data): def from_data(cls, data):
@ -99,9 +102,12 @@ class EventsStreamCurrentStateRow(BaseEventsStreamRow):
event_id = attr.ib() # str, optional event_id = attr.ib() # str, optional
TypeToRow = { _EventRows = (
Row.TypeId: Row for Row in (EventsStreamEventRow, EventsStreamCurrentStateRow) EventsStreamEventRow,
} EventsStreamCurrentStateRow,
) # type: Tuple[Type[BaseEventsStreamRow], ...]
TypeToRow = {Row.TypeId: Row for Row in _EventRows}
class EventsStream(Stream): class EventsStream(Stream):
@ -112,7 +118,7 @@ class EventsStream(Stream):
def __init__(self, hs): def __init__(self, hs):
self._store = hs.get_datastore() self._store = hs.get_datastore()
self.current_token = self._store.get_current_events_token self.current_token = self._store.get_current_events_token # type: ignore
super(EventsStream, self).__init__(hs) super(EventsStream, self).__init__(hs)

View file

@ -37,7 +37,7 @@ class FederationStream(Stream):
def __init__(self, hs): def __init__(self, hs):
federation_sender = hs.get_federation_sender() federation_sender = hs.get_federation_sender()
self.current_token = federation_sender.get_current_token self.current_token = federation_sender.get_current_token # type: ignore
self.update_function = federation_sender.get_replication_rows self.update_function = federation_sender.get_replication_rows # type: ignore
super(FederationStream, self).__init__(hs) super(FederationStream, self).__init__(hs)

View file

@ -181,6 +181,7 @@ commands = mypy \
synapse/handlers/ui_auth \ synapse/handlers/ui_auth \
synapse/logging/ \ synapse/logging/ \
synapse/module_api \ synapse/module_api \
synapse/replication \
synapse/rest/consent \ synapse/rest/consent \
synapse/rest/saml2 \ synapse/rest/saml2 \
synapse/spam_checker_api \ synapse/spam_checker_api \