Fix up some typechecking (#6150)

* type checking fixes

* changelog
This commit is contained in:
Amber Brown 2019-10-02 05:29:01 -07:00 committed by GitHub
parent 2a1470cd05
commit 864f144543
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
22 changed files with 104 additions and 40 deletions

1
.gitignore vendored
View file

@ -10,6 +10,7 @@
*.tac *.tac
_trial_temp/ _trial_temp/
_trial_temp*/ _trial_temp*/
/out
# stuff that is likely to exist when you run a server locally # stuff that is likely to exist when you run a server locally
/*.db /*.db

1
changelog.d/6150.misc Normal file
View file

@ -0,0 +1 @@
Expand type-checking on modules imported by synapse.config.

View file

@ -17,6 +17,7 @@
"""Contains exceptions and error codes.""" """Contains exceptions and error codes."""
import logging import logging
from typing import Dict
from six import iteritems from six import iteritems
from six.moves import http_client from six.moves import http_client
@ -111,7 +112,7 @@ class ProxiedRequestError(SynapseError):
def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None): def __init__(self, code, msg, errcode=Codes.UNKNOWN, additional_fields=None):
super(ProxiedRequestError, self).__init__(code, msg, errcode) super(ProxiedRequestError, self).__init__(code, msg, errcode)
if additional_fields is None: if additional_fields is None:
self._additional_fields = {} self._additional_fields = {} # type: Dict
else: else:
self._additional_fields = dict(additional_fields) self._additional_fields = dict(additional_fields)

View file

@ -12,6 +12,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict
import attr import attr
@ -102,4 +105,4 @@ KNOWN_ROOM_VERSIONS = {
RoomVersions.V4, RoomVersions.V4,
RoomVersions.V5, RoomVersions.V5,
) )
} # type: dict[str, RoomVersion] } # type: Dict[str, RoomVersion]

View file

@ -263,7 +263,9 @@ def start(hs, listeners=None):
refresh_certificate(hs) refresh_certificate(hs)
# Start the tracer # Start the tracer
synapse.logging.opentracing.init_tracer(hs.config) synapse.logging.opentracing.init_tracer( # type: ignore[attr-defined] # noqa
hs.config
)
# It is now safe to start your Synapse. # It is now safe to start your Synapse.
hs.start_listening(listeners) hs.start_listening(listeners)

View file

@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Dict
from six import string_types from six import string_types
from six.moves.urllib import parse as urlparse from six.moves.urllib import parse as urlparse
@ -56,8 +57,8 @@ def load_appservices(hostname, config_files):
return [] return []
# Dicts of value -> filename # Dicts of value -> filename
seen_as_tokens = {} seen_as_tokens = {} # type: Dict[str, str]
seen_ids = {} seen_ids = {} # type: Dict[str, str]
appservices = [] appservices = []

View file

@ -73,8 +73,8 @@ DEFAULT_CONFIG = """\
class ConsentConfig(Config): class ConsentConfig(Config):
def __init__(self): def __init__(self, *args):
super(ConsentConfig, self).__init__() super(ConsentConfig, self).__init__(*args)
self.user_consent_version = None self.user_consent_version = None
self.user_consent_template_dir = None self.user_consent_template_dir = None

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Any, List
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
from ._base import Config from ._base import Config
@ -22,7 +24,7 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config): class PasswordAuthProviderConfig(Config):
def read_config(self, config, **kwargs): def read_config(self, config, **kwargs):
self.password_providers = [] self.password_providers = [] # type: List[Any]
providers = [] providers = []
# We want to be backwards compatible with the old `ldap_config` # We want to be backwards compatible with the old `ldap_config`

View file

@ -15,6 +15,7 @@
import os import os
from collections import namedtuple from collections import namedtuple
from typing import Dict, List
from synapse.python_dependencies import DependencyException, check_requirements from synapse.python_dependencies import DependencyException, check_requirements
from synapse.util.module_loader import load_module from synapse.util.module_loader import load_module
@ -61,7 +62,7 @@ def parse_thumbnail_requirements(thumbnail_sizes):
Dictionary mapping from media type string to list of Dictionary mapping from media type string to list of
ThumbnailRequirement tuples. ThumbnailRequirement tuples.
""" """
requirements = {} requirements = {} # type: Dict[str, List]
for size in thumbnail_sizes: for size in thumbnail_sizes:
width = size["width"] width = size["width"]
height = size["height"] height = size["height"]
@ -130,7 +131,7 @@ class ContentRepositoryConfig(Config):
# #
# We don't create the storage providers here as not all workers need # We don't create the storage providers here as not all workers need
# them to be started. # them to be started.
self.media_storage_providers = [] self.media_storage_providers = [] # type: List[tuple]
for provider_config in storage_providers: for provider_config in storage_providers:
# We special case the module "file_system" so as not to need to # We special case the module "file_system" so as not to need to

View file

@ -19,6 +19,7 @@ import logging
import os.path import os.path
import re import re
from textwrap import indent from textwrap import indent
from typing import List
import attr import attr
import yaml import yaml
@ -243,7 +244,7 @@ class ServerConfig(Config):
# events with profile information that differ from the target's global profile. # events with profile information that differ from the target's global profile.
self.allow_per_room_profiles = config.get("allow_per_room_profiles", True) self.allow_per_room_profiles = config.get("allow_per_room_profiles", True)
self.listeners = [] self.listeners = [] # type: List[dict]
for listener in config.get("listeners", []): for listener in config.get("listeners", []):
if not isinstance(listener.get("port", None), int): if not isinstance(listener.get("port", None), int):
raise ConfigError( raise ConfigError(
@ -287,7 +288,10 @@ class ServerConfig(Config):
validator=attr.validators.instance_of(bool), default=False validator=attr.validators.instance_of(bool), default=False
) )
complexity = attr.ib( complexity = attr.ib(
validator=attr.validators.instance_of((int, float)), default=1.0 validator=attr.validators.instance_of(
(float, int) # type: ignore[arg-type] # noqa
),
default=1.0,
) )
complexity_error = attr.ib( complexity_error = attr.ib(
validator=attr.validators.instance_of(str), validator=attr.validators.instance_of(str),
@ -366,7 +370,7 @@ class ServerConfig(Config):
"cleanup_extremities_with_dummy_events", True "cleanup_extremities_with_dummy_events", True
) )
def has_tls_listener(self): def has_tls_listener(self) -> bool:
return any(l["tls"] for l in self.listeners) return any(l["tls"] for l in self.listeners)
def generate_config_section( def generate_config_section(

View file

@ -59,8 +59,8 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled. None if server notices are not enabled.
""" """
def __init__(self): def __init__(self, *args):
super(ServerNoticesConfig, self).__init__() super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None self.server_notices_mxid = None
self.server_notices_mxid_display_name = None self.server_notices_mxid_display_name = None
self.server_notices_mxid_avatar_url = None self.server_notices_mxid_avatar_url = None

View file

@ -170,6 +170,7 @@ import inspect
import logging import logging
import re import re
from functools import wraps from functools import wraps
from typing import Dict
from canonicaljson import json from canonicaljson import json
@ -547,7 +548,7 @@ def inject_active_span_twisted_headers(headers, destination, check_destination=T
return return
span = opentracing.tracer.active_span span = opentracing.tracer.active_span
carrier = {} carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items(): for key, value in carrier.items():
@ -584,7 +585,7 @@ def inject_active_span_byte_dict(headers, destination, check_destination=True):
span = opentracing.tracer.active_span span = opentracing.tracer.active_span
carrier = {} carrier = {} # type: Dict[str, str]
opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier) opentracing.tracer.inject(span, opentracing.Format.HTTP_HEADERS, carrier)
for key, value in carrier.items(): for key, value in carrier.items():
@ -639,7 +640,7 @@ def get_active_span_text_map(destination=None):
if destination and not whitelisted_homeserver(destination): if destination and not whitelisted_homeserver(destination):
return {} return {}
carrier = {} carrier = {} # type: Dict[str, str]
opentracing.tracer.inject( opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier
) )
@ -653,7 +654,7 @@ def active_span_context_as_string():
Returns: Returns:
The active span context encoded as a string. The active span context encoded as a string.
""" """
carrier = {} carrier = {} # type: Dict[str, str]
if opentracing: if opentracing:
opentracing.tracer.inject( opentracing.tracer.inject(
opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier opentracing.tracer.active_span, opentracing.Format.TEXT_MAP, carrier

View file

@ -119,7 +119,11 @@ def trace_function(f):
logger = logging.getLogger(name) logger = logging.getLogger(name)
level = logging.DEBUG level = logging.DEBUG
s = inspect.currentframe().f_back frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back
to_print = [ to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s" "\t%s:%s %s. Args: args=%s, kwargs=%s"
@ -144,7 +148,7 @@ def trace_function(f):
pathname=pathname, pathname=pathname,
lineno=lineno, lineno=lineno,
msg=msg, msg=msg,
args=None, args=tuple(),
exc_info=None, exc_info=None,
) )
@ -157,7 +161,12 @@ def trace_function(f):
def get_previous_frames(): def get_previous_frames():
s = inspect.currentframe().f_back.f_back
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
to_return = [] to_return = []
while s: while s:
if s.f_globals["__name__"].startswith("synapse"): if s.f_globals["__name__"].startswith("synapse"):
@ -174,7 +183,10 @@ def get_previous_frames():
def get_previous_frame(ignore=[]): def get_previous_frame(ignore=[]):
s = inspect.currentframe().f_back.f_back frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
while s: while s:
if s.f_globals["__name__"].startswith("synapse"): if s.f_globals["__name__"].startswith("synapse"):

View file

@ -125,7 +125,7 @@ class InFlightGauge(object):
) )
# Counts number of in flight blocks for a given set of label values # Counts number of in flight blocks for a given set of label values
self._registrations = {} self._registrations = {} # type: Dict
# Protects access to _registrations # Protects access to _registrations
self._lock = threading.Lock() self._lock = threading.Lock()
@ -226,7 +226,7 @@ class BucketCollector(object):
# Fetch the data -- this must be synchronous! # Fetch the data -- this must be synchronous!
data = self.data_collector() data = self.data_collector()
buckets = {} buckets = {} # type: Dict[float, int]
res = [] res = []
for x in data.keys(): for x in data.keys():

View file

@ -36,9 +36,9 @@ from twisted.web.resource import Resource
try: try:
from prometheus_client.samples import Sample from prometheus_client.samples import Sample
except ImportError: except ImportError:
Sample = namedtuple( Sample = namedtuple( # type: ignore[no-redef] # noqa
"Sample", ["name", "labels", "value", "timestamp", "exemplar"] "Sample", ["name", "labels", "value", "timestamp", "exemplar"]
) # type: ignore )
CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8") CONTENT_TYPE_LATEST = str("text/plain; version=0.0.4; charset=utf-8")

View file

@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Set from typing import List, Set
from pkg_resources import ( from pkg_resources import (
DistributionNotFound, DistributionNotFound,
@ -73,6 +73,7 @@ REQUIREMENTS = [
"netaddr>=0.7.18", "netaddr>=0.7.18",
"Jinja2>=2.9", "Jinja2>=2.9",
"bleach>=1.4.3", "bleach>=1.4.3",
"typing-extensions>=3.7.4",
] ]
CONDITIONAL_REQUIREMENTS = { CONDITIONAL_REQUIREMENTS = {
@ -144,7 +145,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency) deps_needed.append(dependency)
errors.append( errors.append(
"Needed %s, got %s==%s" "Needed %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version) % (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
) )
except DistributionNotFound: except DistributionNotFound:
deps_needed.append(dependency) deps_needed.append(dependency)
@ -159,7 +164,7 @@ def check_requirements(for_feature=None):
if not for_feature: if not for_feature:
# Check the optional dependencies are up to date. We allow them to not be # Check the optional dependencies are up to date. We allow them to not be
# installed. # installed.
OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) OPTS = sum(CONDITIONAL_REQUIREMENTS.values(), []) # type: List[str]
for dependency in OPTS: for dependency in OPTS:
try: try:
@ -168,7 +173,11 @@ def check_requirements(for_feature=None):
deps_needed.append(dependency) deps_needed.append(dependency)
errors.append( errors.append(
"Needed optional %s, got %s==%s" "Needed optional %s, got %s==%s"
% (dependency, e.dist.project_name, e.dist.version) % (
dependency,
e.dist.project_name, # type: ignore[attr-defined] # noqa
e.dist.version, # type: ignore[attr-defined] # noqa
)
) )
except DistributionNotFound: except DistributionNotFound:
# If it's not found, we don't care # If it's not found, we don't care

View file

@ -318,6 +318,7 @@ class StreamToken(
) )
): ):
_SEPARATOR = "_" _SEPARATOR = "_"
START = None # type: StreamToken
@classmethod @classmethod
def from_string(cls, string): def from_string(cls, string):
@ -402,7 +403,7 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
followed by the "stream_ordering" id of the event it comes after. followed by the "stream_ordering" id of the event it comes after.
""" """
__slots__ = [] __slots__ = [] # type: list
@classmethod @classmethod
def parse(cls, string): def parse(cls, string):

View file

@ -13,9 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import collections import collections
import logging import logging
from contextlib import contextmanager from contextlib import contextmanager
from typing import Dict, Sequence, Set, Union
from six.moves import range from six.moves import range
@ -213,7 +215,9 @@ class Linearizer(object):
# the first element is the number of things executing, and # the first element is the number of things executing, and
# the second element is an OrderedDict, where the keys are deferreds for the # the second element is an OrderedDict, where the keys are deferreds for the
# things blocked from executing. # things blocked from executing.
self.key_to_defer = {} self.key_to_defer = (
{}
) # type: Dict[str, Sequence[Union[int, Dict[defer.Deferred, int]]]]
def queue(self, key): def queue(self, key):
# we avoid doing defer.inlineCallbacks here, so that cancellation works correctly. # we avoid doing defer.inlineCallbacks here, so that cancellation works correctly.
@ -340,10 +344,10 @@ class ReadWriteLock(object):
def __init__(self): def __init__(self):
# Latest readers queued # Latest readers queued
self.key_to_current_readers = {} self.key_to_current_readers = {} # type: Dict[str, Set[defer.Deferred]]
# Latest writer queued # Latest writer queued
self.key_to_current_writer = {} self.key_to_current_writer = {} # type: Dict[str, defer.Deferred]
@defer.inlineCallbacks @defer.inlineCallbacks
def read(self, key): def read(self, key):

View file

@ -16,6 +16,7 @@
import logging import logging
import os import os
from typing import Dict
import six import six
from six.moves import intern from six.moves import intern
@ -37,7 +38,7 @@ def get_cache_factor_for(cache_name):
caches_by_name = {} caches_by_name = {}
collectors_by_name = {} collectors_by_name = {} # type: Dict
cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"]) cache_size = Gauge("synapse_util_caches_cache:size", "", ["name"])
cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"]) cache_hits = Gauge("synapse_util_caches_cache:hits", "", ["name"])

View file

@ -18,10 +18,12 @@ import inspect
import logging import logging
import threading import threading
from collections import namedtuple from collections import namedtuple
from typing import Any, cast
from six import itervalues from six import itervalues
from prometheus_client import Gauge from prometheus_client import Gauge
from typing_extensions import Protocol
from twisted.internet import defer from twisted.internet import defer
@ -37,6 +39,18 @@ from . import register_cache
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class _CachedFunction(Protocol):
invalidate = None # type: Any
invalidate_all = None # type: Any
invalidate_many = None # type: Any
prefill = None # type: Any
cache = None # type: Any
num_args = None # type: Any
def __name__(self):
...
cache_pending_metric = Gauge( cache_pending_metric = Gauge(
"synapse_util_caches_cache_pending", "synapse_util_caches_cache_pending",
"Number of lookups currently pending for this cache", "Number of lookups currently pending for this cache",
@ -245,7 +259,9 @@ class Cache(object):
class _CacheDescriptorBase(object): class _CacheDescriptorBase(object):
def __init__(self, orig, num_args, inlineCallbacks, cache_context=False): def __init__(
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
):
self.orig = orig self.orig = orig
if inlineCallbacks: if inlineCallbacks:
@ -404,7 +420,7 @@ class CacheDescriptor(_CacheDescriptorBase):
return tuple(get_cache_key_gen(args, kwargs)) return tuple(get_cache_key_gen(args, kwargs))
@functools.wraps(self.orig) @functools.wraps(self.orig)
def wrapped(*args, **kwargs): def _wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate() # If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated # whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None) invalidate_callback = kwargs.pop("on_invalidate", None)
@ -440,6 +456,8 @@ class CacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(observer) return make_deferred_yieldable(observer)
wrapped = cast(_CachedFunction, _wrapped)
if self.num_args == 1: if self.num_args == 1:
wrapped.invalidate = lambda key: cache.invalidate(key[0]) wrapped.invalidate = lambda key: cache.invalidate(key[0])
wrapped.prefill = lambda key, val: cache.prefill(key[0], val) wrapped.prefill = lambda key, val: cache.prefill(key[0], val)

View file

@ -1,3 +1,5 @@
from typing import Dict
from six import itervalues from six import itervalues
SENTINEL = object() SENTINEL = object()
@ -12,7 +14,7 @@ class TreeCache(object):
def __init__(self): def __init__(self):
self.size = 0 self.size = 0
self.root = {} self.root = {} # type: Dict
def __setitem__(self, key, value): def __setitem__(self, key, value):
return self.set(key, value) return self.set(key, value)

View file

@ -54,5 +54,5 @@ def load_python_module(location: str):
if spec is None: if spec is None:
raise Exception("Unable to load module at %s" % (location,)) raise Exception("Unable to load module at %s" % (location,))
mod = importlib.util.module_from_spec(spec) mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod) spec.loader.exec_module(mod) # type: ignore
return mod return mod