Add missing types to opentracing. (#13345)

After this change `synapse.logging` is fully typed.
This commit is contained in:
Patrick Cloke 2022-07-21 08:01:52 -04:00 committed by GitHub
parent 190f49d8ab
commit 50122754c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 83 additions and 45 deletions

View file

@ -1 +1 @@
Add type hints to `trace` decorator. Add missing type hints to open tracing module.

1
changelog.d/13345.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to open tracing module.

View file

@ -84,9 +84,6 @@ disallow_untyped_defs = False
[mypy-synapse.http.matrixfederationclient] [mypy-synapse.http.matrixfederationclient]
disallow_untyped_defs = False disallow_untyped_defs = False
[mypy-synapse.logging.opentracing]
disallow_untyped_defs = False
[mypy-synapse.metrics._reactor_metrics] [mypy-synapse.metrics._reactor_metrics]
disallow_untyped_defs = False disallow_untyped_defs = False
# This module imports select.epoll. That exists on Linux, but doesn't on macOS. # This module imports select.epoll. That exists on Linux, but doesn't on macOS.

View file

@ -309,7 +309,7 @@ class BaseFederationServlet:
raise raise
# update the active opentracing span with the authenticated entity # update the active opentracing span with the authenticated entity
set_tag("authenticated_entity", origin) set_tag("authenticated_entity", str(origin))
# if the origin is authenticated and whitelisted, use its span context # if the origin is authenticated and whitelisted, use its span context
# as the parent. # as the parent.

View file

@ -118,8 +118,8 @@ class DeviceWorkerHandler:
ips = await self.store.get_last_client_ip_by_device(user_id, device_id) ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips) _update_device_from_client_ips(device, ips)
set_tag("device", device) set_tag("device", str(device))
set_tag("ips", ips) set_tag("ips", str(ips))
return device return device
@ -170,7 +170,7 @@ class DeviceWorkerHandler:
""" """
set_tag("user_id", user_id) set_tag("user_id", user_id)
set_tag("from_token", from_token) set_tag("from_token", str(from_token))
now_room_key = self.store.get_room_max_token() now_room_key = self.store.get_room_max_token()
room_ids = await self.store.get_rooms_for_user(user_id) room_ids = await self.store.get_rooms_for_user(user_id)
@ -795,7 +795,7 @@ class DeviceListUpdater:
""" """
set_tag("origin", origin) set_tag("origin", origin)
set_tag("edu_content", edu_content) set_tag("edu_content", str(edu_content))
user_id = edu_content.pop("user_id") user_id = edu_content.pop("user_id")
device_id = edu_content.pop("device_id") device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints stream_id = str(edu_content.pop("stream_id")) # They may come as ints

View file

@ -138,8 +138,8 @@ class E2eKeysHandler:
else: else:
remote_queries[user_id] = device_ids remote_queries[user_id] = device_ids
set_tag("local_key_query", local_query) set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", remote_queries) set_tag("remote_key_query", str(remote_queries))
# First get local devices. # First get local devices.
# A map of destination -> failure response. # A map of destination -> failure response.
@ -343,7 +343,7 @@ class E2eKeysHandler:
failure = _exception_to_failure(e) failure = _exception_to_failure(e)
failures[destination] = failure failures[destination] = failure
set_tag("error", True) set_tag("error", True)
set_tag("reason", failure) set_tag("reason", str(failure))
return return
@ -405,7 +405,7 @@ class E2eKeysHandler:
Returns: Returns:
A map from user_id -> device_id -> device details A map from user_id -> device_id -> device details
""" """
set_tag("local_query", query) set_tag("local_query", str(query))
local_query: List[Tuple[str, Optional[str]]] = [] local_query: List[Tuple[str, Optional[str]]] = []
result_dict: Dict[str, Dict[str, dict]] = {} result_dict: Dict[str, Dict[str, dict]] = {}
@ -477,8 +477,8 @@ class E2eKeysHandler:
domain = get_domain_from_id(user_id) domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", local_query) set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", remote_queries) set_tag("remote_key_query", str(remote_queries))
results = await self.store.claim_e2e_one_time_keys(local_query) results = await self.store.claim_e2e_one_time_keys(local_query)
@ -508,7 +508,7 @@ class E2eKeysHandler:
failure = _exception_to_failure(e) failure = _exception_to_failure(e)
failures[destination] = failure failures[destination] = failure
set_tag("error", True) set_tag("error", True)
set_tag("reason", failure) set_tag("reason", str(failure))
await make_deferred_yieldable( await make_deferred_yieldable(
defer.gatherResults( defer.gatherResults(
@ -611,7 +611,7 @@ class E2eKeysHandler:
result = await self.store.count_e2e_one_time_keys(user_id, device_id) result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", result) set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result} return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user( async def _upload_one_time_keys_for_user(

View file

@ -14,7 +14,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Dict, Optional from typing import TYPE_CHECKING, Dict, Optional, cast
from typing_extensions import Literal from typing_extensions import Literal
@ -97,7 +97,7 @@ class E2eRoomKeysHandler:
user_id, version, room_id, session_id user_id, version, room_id, session_id
) )
log_kv(results) log_kv(cast(JsonDict, results))
return results return results
@trace @trace

View file

@ -182,6 +182,8 @@ from typing import (
Type, Type,
TypeVar, TypeVar,
Union, Union,
cast,
overload,
) )
import attr import attr
@ -328,6 +330,7 @@ class _Sentinel(enum.Enum):
P = ParamSpec("P") P = ParamSpec("P")
R = TypeVar("R") R = TypeVar("R")
T = TypeVar("T")
def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]: def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@ -343,22 +346,43 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
return _only_if_tracing_inner return _only_if_tracing_inner
def ensure_active_span(message: str, ret=None): @overload
def ensure_active_span(
message: str,
) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
...
@overload
def ensure_active_span(
message: str, ret: T
) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
...
def ensure_active_span(
message: str, ret: Optional[T] = None
) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
"""Executes the operation only if opentracing is enabled and there is an active span. """Executes the operation only if opentracing is enabled and there is an active span.
If there is no active span it logs message at the error level. If there is no active span it logs message at the error level.
Args: Args:
message: Message which fills in "There was no active span when trying to %s" message: Message which fills in "There was no active span when trying to %s"
in the error log if there is no active span and opentracing is enabled. in the error log if there is no active span and opentracing is enabled.
ret (object): return value if opentracing is None or there is no active span. ret: return value if opentracing is None or there is no active span.
Returns (object): The result of the func or ret if opentracing is disabled or there Returns:
The result of the func, falling back to ret if opentracing is disabled or there
was no active span. was no active span.
""" """
def ensure_active_span_inner_1(func): def ensure_active_span_inner_1(
func: Callable[P, R]
) -> Callable[P, Union[Optional[T], R]]:
@wraps(func) @wraps(func)
def ensure_active_span_inner_2(*args, **kwargs): def ensure_active_span_inner_2(
*args: P.args, **kwargs: P.kwargs
) -> Union[Optional[T], R]:
if not opentracing: if not opentracing:
return ret return ret
@ -464,7 +488,7 @@ def start_active_span(
finish_on_close: bool = True, finish_on_close: bool = True,
*, *,
tracer: Optional["opentracing.Tracer"] = None, tracer: Optional["opentracing.Tracer"] = None,
): ) -> "opentracing.Scope":
"""Starts an active opentracing span. """Starts an active opentracing span.
Records the start time for the span, and sets it as the "active span" in the Records the start time for the span, and sets it as the "active span" in the
@ -502,7 +526,7 @@ def start_active_span_follows_from(
*, *,
inherit_force_tracing: bool = False, inherit_force_tracing: bool = False,
tracer: Optional["opentracing.Tracer"] = None, tracer: Optional["opentracing.Tracer"] = None,
): ) -> "opentracing.Scope":
"""Starts an active opentracing span, with additional references to previous spans """Starts an active opentracing span, with additional references to previous spans
Args: Args:
@ -717,7 +741,9 @@ def inject_response_headers(response_headers: Headers) -> None:
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}") response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
@ensure_active_span("get the active span context as a dict", ret={}) @ensure_active_span(
"get the active span context as a dict", ret=cast(Dict[str, str], {})
)
def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]: def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
""" """
Gets a span context as a dict. This can be used instead of manually Gets a span context as a dict. This can be used instead of manually
@ -886,7 +912,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
for i, arg in enumerate(argspec.args[1:]): for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i]) # type: ignore[index] set_tag("ARG_" + arg, args[i]) # type: ignore[index]
set_tag("args", args[len(argspec.args) :]) # type: ignore[index] set_tag("args", args[len(argspec.args) :]) # type: ignore[index]
set_tag("kwargs", kwargs) set_tag("kwargs", str(kwargs))
return func(*args, **kwargs) return func(*args, **kwargs)
return _tag_args_inner return _tag_args_inner

View file

@ -235,7 +235,7 @@ def run_as_background_process(
f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)} f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
) )
else: else:
ctx = nullcontext() ctx = nullcontext() # type: ignore[assignment]
with ctx: with ctx:
return await func(*args, **kwargs) return await func(*args, **kwargs)
except Exception: except Exception:

View file

@ -208,7 +208,9 @@ class KeyChangesServlet(RestServlet):
# We want to enforce they do pass us one, but we ignore it and return # We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before. # changes after the "to" as well as before.
set_tag("to", parse_string(request, "to")) #
# XXX This does not enforce that "to" is passed.
set_tag("to", str(parse_string(request, "to")))
from_token = await StreamToken.from_string(self.store, from_token_string) from_token = await StreamToken.from_string(self.store, from_token_string)

View file

@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
(user_id, device_id), None (user_id, device_id), None
) )
set_tag("last_deleted_stream_id", last_deleted_stream_id) set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
if last_deleted_stream_id: if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed( has_changed = self._device_inbox_stream_cache.has_entity_changed(

View file

@ -706,8 +706,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
else: else:
results[user_id] = await self.get_cached_devices_for_user(user_id) results[user_id] = await self.get_cached_devices_for_user(user_id)
set_tag("in_cache", results) set_tag("in_cache", str(results))
set_tag("not_in_cache", user_ids_not_in_cache) set_tag("not_in_cache", str(user_ids_not_in_cache))
return user_ids_not_in_cache, results return user_ids_not_in_cache, results

View file

@ -146,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
key data. The key data will be a dict in the same format as the key data. The key data will be a dict in the same format as the
DeviceKeys type returned by POST /_matrix/client/r0/keys/query. DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
""" """
set_tag("query_list", query_list) set_tag("query_list", str(query_list))
if not query_list: if not query_list:
return {} return {}
@ -418,7 +418,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None: def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id) set_tag("user_id", user_id)
set_tag("device_id", device_id) set_tag("device_id", device_id)
set_tag("new_keys", new_keys) set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to # We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to # a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only # `add_e2e_one_time_keys` then they'll conflict and we will only
@ -1161,7 +1161,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("user_id", user_id) set_tag("user_id", user_id)
set_tag("device_id", device_id) set_tag("device_id", device_id)
set_tag("time_now", time_now) set_tag("time_now", time_now)
set_tag("device_keys", device_keys) set_tag("device_keys", str(device_keys))
old_key_json = self.db_pool.simple_select_one_onecol_txn( old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn, txn,

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import cast
from twisted.internet import defer from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactorClock from twisted.test.proto_helpers import MemoryReactorClock
@ -40,6 +42,15 @@ from tests.unittest import TestCase
class LogContextScopeManagerTestCase(TestCase): class LogContextScopeManagerTestCase(TestCase):
"""
Test logging contexts and active opentracing spans.
There's casts throughout this from generic opentracing objects (e.g.
opentracing.Span) to the ones specific to Jaeger since they have additional
properties that these tests depend on. This is safe since the only supported
opentracing backend is Jaeger.
"""
if LogContextScopeManager is None: if LogContextScopeManager is None:
skip = "Requires opentracing" # type: ignore[unreachable] skip = "Requires opentracing" # type: ignore[unreachable]
if jaeger_client is None: if jaeger_client is None:
@ -69,7 +80,7 @@ class LogContextScopeManagerTestCase(TestCase):
# start_active_span should start and activate a span. # start_active_span should start and activate a span.
scope = start_active_span("span", tracer=self._tracer) scope = start_active_span("span", tracer=self._tracer)
span = scope.span span = cast(jaeger_client.Span, scope.span)
self.assertEqual(self._tracer.active_span, span) self.assertEqual(self._tracer.active_span, span)
self.assertIsNotNone(span.start_time) self.assertIsNotNone(span.start_time)
@ -91,6 +102,7 @@ class LogContextScopeManagerTestCase(TestCase):
with LoggingContext("root context"): with LoggingContext("root context"):
with start_active_span("root span", tracer=self._tracer) as root_scope: with start_active_span("root span", tracer=self._tracer) as root_scope:
self.assertEqual(self._tracer.active_span, root_scope.span) self.assertEqual(self._tracer.active_span, root_scope.span)
root_context = cast(jaeger_client.SpanContext, root_scope.span.context)
scope1 = start_active_span( scope1 = start_active_span(
"child1", "child1",
@ -99,9 +111,8 @@ class LogContextScopeManagerTestCase(TestCase):
self.assertEqual( self.assertEqual(
self._tracer.active_span, scope1.span, "child1 was not activated" self._tracer.active_span, scope1.span, "child1 was not activated"
) )
self.assertEqual( context1 = cast(jaeger_client.SpanContext, scope1.span.context)
scope1.span.context.parent_id, root_scope.span.context.span_id self.assertEqual(context1.parent_id, root_context.span_id)
)
scope2 = start_active_span_follows_from( scope2 = start_active_span_follows_from(
"child2", "child2",
@ -109,17 +120,18 @@ class LogContextScopeManagerTestCase(TestCase):
tracer=self._tracer, tracer=self._tracer,
) )
self.assertEqual(self._tracer.active_span, scope2.span) self.assertEqual(self._tracer.active_span, scope2.span)
self.assertEqual( context2 = cast(jaeger_client.SpanContext, scope2.span.context)
scope2.span.context.parent_id, scope1.span.context.span_id self.assertEqual(context2.parent_id, context1.span_id)
)
with scope1, scope2: with scope1, scope2:
pass pass
# the root scope should be restored # the root scope should be restored
self.assertEqual(self._tracer.active_span, root_scope.span) self.assertEqual(self._tracer.active_span, root_scope.span)
self.assertIsNotNone(scope2.span.end_time) span2 = cast(jaeger_client.Span, scope2.span)
self.assertIsNotNone(scope1.span.end_time) span1 = cast(jaeger_client.Span, scope1.span)
self.assertIsNotNone(span2.end_time)
self.assertIsNotNone(span1.end_time)
self.assertIsNone(self._tracer.active_span) self.assertIsNone(self._tracer.active_span)