diff --git a/synapse/logging/context.py b/synapse/logging/context.py index e52567afa0..3409ddf0d0 100644 --- a/synapse/logging/context.py +++ b/synapse/logging/context.py @@ -34,7 +34,6 @@ import threading import typing import warnings from collections.abc import Coroutine, Generator -from contextvars import ContextVar from types import TracebackType from typing import ( TYPE_CHECKING, @@ -235,7 +234,14 @@ LoggingContextOrSentinel = Union["LoggingContext", "_Sentinel"] class _Sentinel: """Sentinel to represent the root context""" - __slots__ = ["previous_context", "finished", "request", "scope", "tag"] + __slots__ = [ + "previous_context", + "finished", + "request", + "scope", + "tag", + "metrics_name", + ] def __init__(self) -> None: # Minimal set for compatibility with LoggingContext @@ -244,6 +250,7 @@ class _Sentinel: self.request = None self.scope = None self.tag = None + self.metrics_name = None def __str__(self) -> str: return "sentinel" @@ -296,6 +303,7 @@ class LoggingContext: "request", "tag", "scope", + "metrics_name", ] def __init__( @@ -306,6 +314,8 @@ class LoggingContext: ) -> None: self.previous_context = current_context() + self.metrics_name: Optional[str] = None + # track the resources used by this context so far self._resource_usage = ContextResourceUsage() @@ -339,6 +349,7 @@ class LoggingContext: # if we don't have a `name`, but do have a parent context, use its name. if self.parent_context and name is None: name = str(self.parent_context) + self.metrics_name = self.parent_context.metrics_name if name is None: raise ValueError( "LoggingContext must be given either a name or a parent context" @@ -821,14 +832,14 @@ def run_in_background( d: "defer.Deferred[R]" if isinstance(res, typing.Coroutine): # Wrap the coroutine in a `Deferred`. - d = defer.ensureDeferred(measure_coroutine(current.name, res)) + d = defer.ensureDeferred(measure_coroutine(current.metrics_name, res)) elif isinstance(res, defer.Deferred): d = res elif isinstance(res, Awaitable): # `res` is probably some kind of completed awaitable, such as a `DoneAwaitable` # or `Future` from `make_awaitable`. d = defer.ensureDeferred( - measure_coroutine(current.name, _unwrap_awaitable(res)) + measure_coroutine(current.metrics_name, _unwrap_awaitable(res)) ) else: # `res` is a plain value. Wrap it in a `Deferred`. @@ -1069,6 +1080,10 @@ class _ResourceTracker2(Coroutine[defer.Deferred[Any], Any, _T]): async def measure_coroutine( - name: str, co: Coroutine[defer.Deferred[Any], Any, _T] + name: Optional[str], co: Coroutine[defer.Deferred[Any], Any, _T] ) -> _T: + if not name: + return await co + + current_context().metrics_name = name return await _ResourceTracker2(name, co)