diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index 6592a4a6b7..4830c45432 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -22,7 +22,13 @@ import mypy.types from mypy.erasetype import remove_instance_last_known_values from mypy.errorcodes import ErrorCode from mypy.nodes import ARG_NAMED_OPT, TempNode, Var -from mypy.plugin import FunctionSigContext, MethodSigContext, Plugin +from mypy.plugin import ( + FunctionSigContext, + MethodSigContext, + Plugin, + AttributeContext, + CheckerPluginInterface, +) from mypy.typeops import bind_self from mypy.types import ( AnyType, @@ -56,6 +62,15 @@ class SynapsePlugin(Plugin): return None + def get_attribute_hook( + self, fullname: str + ) -> Optional[Callable[[AttributeContext], mypy.types.Type]]: + # Anything in synapse could be wrapped with the cached decorator, but + # we know that anything else is *not*. + if fullname.startswith("synapse."): + return cached_function_method_attribute + return None + def _get_true_return_type(signature: CallableType) -> mypy.types.Type: """ @@ -79,6 +94,26 @@ def _get_true_return_type(signature: CallableType) -> mypy.types.Type: def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: + return _unwrap_cached_decoratored_function(ctx.default_signature, ctx.api) + + +def cached_function_method_attribute(ctx: AttributeContext) -> mypy.types.Type: + if isinstance(ctx.default_attr_type, Instance): + if ( + ctx.default_attr_type.type.fullname + == "synapse.util.caches.descriptors.CachedFunction" + ): + # Unwrap the wrapped function. + return _unwrap_cached_decoratored_function( + ctx.default_attr_type.args[0], ctx.api + ) + + return ctx.default_attr_type + + +def _unwrap_cached_decoratored_function( + wrapped_signature: CallableType, api: CheckerPluginInterface +) -> CallableType: """Fixes the `CachedFunction.__call__` signature to be correct. It already has *almost* the correct signature, except: @@ -90,12 +125,12 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: """ # 1. Mark this as a bound function signature. - signature: CallableType = bind_self(ctx.default_signature) + signature: CallableType = bind_self(wrapped_signature) # 2. Remove any "cache_context" args. # # Note: We should be only doing this if `cache_context=True` is set, but if - # it isn't then the code will raise an exception when its called anyway, so + # it isn't then the code will raise an exception when it's called anyway, so # it's not the end of the world. context_arg_index = None for idx, name in enumerate(signature.arg_names): @@ -125,7 +160,7 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: arg_kinds=[], arg_names=[], ret_type=NoneType(), - fallback=ctx.api.named_generic_type("builtins.function", []), + fallback=api.named_generic_type("builtins.function", []), ), ] ) @@ -137,12 +172,12 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: # 4. Ensure the return type is a Deferred. ret_arg = _get_true_return_type(signature) - # This should be able to use ctx.api.named_generic_type, but that doesn't seem + # This should be able to use api.named_generic_type, but that doesn't seem # to find the correct symbol for anything more than 1 module deep. # # modules is not part of CheckerPluginInterface. The following is a combination # of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo. - sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined] + sym = api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined] ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)]) signature = signature.copy_modified( diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 8a55e4e41d..9360daff2b 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -395,9 +395,9 @@ def gather_results( def gather_results( # type: ignore[misc] - deferredList: Tuple["defer.Deferred[T1]", ...], + deferredList: Tuple["defer.Deferred[Any]", ...], consumeErrors: bool = False, -) -> "defer.Deferred[Tuple[T1, ...]]": +) -> "defer.Deferred[Tuple[Any, ...]]": """Combines a tuple of `Deferred`s into a single `Deferred`. Wraps `defer.gatherResults` to provide type annotations that support heterogenous