Add missing types to opentracing. (#13345)

After this change `synapse.logging` is fully typed.
This commit is contained in:
Patrick Cloke 2022-07-21 08:01:52 -04:00 committed by GitHub
parent 190f49d8ab
commit 50122754c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 83 additions and 45 deletions

View file

@ -1 +1 @@
Add type hints to `trace` decorator.
Add missing type hints to open tracing module.

1
changelog.d/13345.misc Normal file
View file

@ -0,0 +1 @@
Add missing type hints to open tracing module.

View file

@ -84,9 +84,6 @@ disallow_untyped_defs = False
[mypy-synapse.http.matrixfederationclient]
disallow_untyped_defs = False
[mypy-synapse.logging.opentracing]
disallow_untyped_defs = False
[mypy-synapse.metrics._reactor_metrics]
disallow_untyped_defs = False
# This module imports select.epoll. That exists on Linux, but doesn't on macOS.

View file

@ -309,7 +309,7 @@ class BaseFederationServlet:
raise
# update the active opentracing span with the authenticated entity
set_tag("authenticated_entity", origin)
set_tag("authenticated_entity", str(origin))
# if the origin is authenticated and whitelisted, use its span context
# as the parent.

View file

@ -118,8 +118,8 @@ class DeviceWorkerHandler:
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
set_tag("device", device)
set_tag("ips", ips)
set_tag("device", str(device))
set_tag("ips", str(ips))
return device
@ -170,7 +170,7 @@ class DeviceWorkerHandler:
"""
set_tag("user_id", user_id)
set_tag("from_token", from_token)
set_tag("from_token", str(from_token))
now_room_key = self.store.get_room_max_token()
room_ids = await self.store.get_rooms_for_user(user_id)
@ -795,7 +795,7 @@ class DeviceListUpdater:
"""
set_tag("origin", origin)
set_tag("edu_content", edu_content)
set_tag("edu_content", str(edu_content))
user_id = edu_content.pop("user_id")
device_id = edu_content.pop("device_id")
stream_id = str(edu_content.pop("stream_id")) # They may come as ints

View file

@ -138,8 +138,8 @@ class E2eKeysHandler:
else:
remote_queries[user_id] = device_ids
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))
# First get local devices.
# A map of destination -> failure response.
@ -343,7 +343,7 @@ class E2eKeysHandler:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
set_tag("reason", str(failure))
return
@ -405,7 +405,7 @@ class E2eKeysHandler:
Returns:
A map from user_id -> device_id -> device details
"""
set_tag("local_query", query)
set_tag("local_query", str(query))
local_query: List[Tuple[str, Optional[str]]] = []
result_dict: Dict[str, Dict[str, dict]] = {}
@ -477,8 +477,8 @@ class E2eKeysHandler:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = one_time_keys
set_tag("local_key_query", local_query)
set_tag("remote_key_query", remote_queries)
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))
results = await self.store.claim_e2e_one_time_keys(local_query)
@ -508,7 +508,7 @@ class E2eKeysHandler:
failure = _exception_to_failure(e)
failures[destination] = failure
set_tag("error", True)
set_tag("reason", failure)
set_tag("reason", str(failure))
await make_deferred_yieldable(
defer.gatherResults(
@ -611,7 +611,7 @@ class E2eKeysHandler:
result = await self.store.count_e2e_one_time_keys(user_id, device_id)
set_tag("one_time_key_counts", result)
set_tag("one_time_key_counts", str(result))
return {"one_time_key_counts": result}
async def _upload_one_time_keys_for_user(

View file

@ -14,7 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Optional
from typing import TYPE_CHECKING, Dict, Optional, cast
from typing_extensions import Literal
@ -97,7 +97,7 @@ class E2eRoomKeysHandler:
user_id, version, room_id, session_id
)
log_kv(results)
log_kv(cast(JsonDict, results))
return results
@trace

View file

@ -182,6 +182,8 @@ from typing import (
Type,
TypeVar,
Union,
cast,
overload,
)
import attr
@ -328,6 +330,7 @@ class _Sentinel(enum.Enum):
P = ParamSpec("P")
R = TypeVar("R")
T = TypeVar("T")
def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
@ -343,22 +346,43 @@ def only_if_tracing(func: Callable[P, R]) -> Callable[P, Optional[R]]:
return _only_if_tracing_inner
def ensure_active_span(message: str, ret=None):
@overload
def ensure_active_span(
message: str,
) -> Callable[[Callable[P, R]], Callable[P, Optional[R]]]:
...
@overload
def ensure_active_span(
message: str, ret: T
) -> Callable[[Callable[P, R]], Callable[P, Union[T, R]]]:
...
def ensure_active_span(
message: str, ret: Optional[T] = None
) -> Callable[[Callable[P, R]], Callable[P, Union[Optional[T], R]]]:
"""Executes the operation only if opentracing is enabled and there is an active span.
If there is no active span it logs message at the error level.
Args:
message: Message which fills in "There was no active span when trying to %s"
in the error log if there is no active span and opentracing is enabled.
ret (object): return value if opentracing is None or there is no active span.
ret: return value if opentracing is None or there is no active span.
Returns (object): The result of the func or ret if opentracing is disabled or there
Returns:
The result of the func, falling back to ret if opentracing is disabled or there
was no active span.
"""
def ensure_active_span_inner_1(func):
def ensure_active_span_inner_1(
func: Callable[P, R]
) -> Callable[P, Union[Optional[T], R]]:
@wraps(func)
def ensure_active_span_inner_2(*args, **kwargs):
def ensure_active_span_inner_2(
*args: P.args, **kwargs: P.kwargs
) -> Union[Optional[T], R]:
if not opentracing:
return ret
@ -464,7 +488,7 @@ def start_active_span(
finish_on_close: bool = True,
*,
tracer: Optional["opentracing.Tracer"] = None,
):
) -> "opentracing.Scope":
"""Starts an active opentracing span.
Records the start time for the span, and sets it as the "active span" in the
@ -502,7 +526,7 @@ def start_active_span_follows_from(
*,
inherit_force_tracing: bool = False,
tracer: Optional["opentracing.Tracer"] = None,
):
) -> "opentracing.Scope":
"""Starts an active opentracing span, with additional references to previous spans
Args:
@ -717,7 +741,9 @@ def inject_response_headers(response_headers: Headers) -> None:
response_headers.addRawHeader("Synapse-Trace-Id", f"{trace_id:x}")
@ensure_active_span("get the active span context as a dict", ret={})
@ensure_active_span(
"get the active span context as a dict", ret=cast(Dict[str, str], {})
)
def get_active_span_text_map(destination: Optional[str] = None) -> Dict[str, str]:
"""
Gets a span context as a dict. This can be used instead of manually
@ -886,7 +912,7 @@ def tag_args(func: Callable[P, R]) -> Callable[P, R]:
for i, arg in enumerate(argspec.args[1:]):
set_tag("ARG_" + arg, args[i]) # type: ignore[index]
set_tag("args", args[len(argspec.args) :]) # type: ignore[index]
set_tag("kwargs", kwargs)
set_tag("kwargs", str(kwargs))
return func(*args, **kwargs)
return _tag_args_inner

View file

@ -235,7 +235,7 @@ def run_as_background_process(
f"bgproc.{desc}", tags={SynapseTags.REQUEST_ID: str(context)}
)
else:
ctx = nullcontext()
ctx = nullcontext() # type: ignore[assignment]
with ctx:
return await func(*args, **kwargs)
except Exception:

View file

@ -208,7 +208,9 @@ class KeyChangesServlet(RestServlet):
# We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before.
set_tag("to", parse_string(request, "to"))
#
# XXX This does not enforce that "to" is passed.
set_tag("to", str(parse_string(request, "to")))
from_token = await StreamToken.from_string(self.store, from_token_string)

View file

@ -436,7 +436,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
(user_id, device_id), None
)
set_tag("last_deleted_stream_id", last_deleted_stream_id)
set_tag("last_deleted_stream_id", str(last_deleted_stream_id))
if last_deleted_stream_id:
has_changed = self._device_inbox_stream_cache.has_entity_changed(

View file

@ -706,8 +706,8 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
else:
results[user_id] = await self.get_cached_devices_for_user(user_id)
set_tag("in_cache", results)
set_tag("not_in_cache", user_ids_not_in_cache)
set_tag("in_cache", str(results))
set_tag("not_in_cache", str(user_ids_not_in_cache))
return user_ids_not_in_cache, results

View file

@ -146,7 +146,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
key data. The key data will be a dict in the same format as the
DeviceKeys type returned by POST /_matrix/client/r0/keys/query.
"""
set_tag("query_list", query_list)
set_tag("query_list", str(query_list))
if not query_list:
return {}
@ -418,7 +418,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("new_keys", new_keys)
set_tag("new_keys", str(new_keys))
# We are protected from race between lookup and insertion due to
# a unique constraint. If there is a race of two calls to
# `add_e2e_one_time_keys` then they'll conflict and we will only
@ -1161,7 +1161,7 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
set_tag("user_id", user_id)
set_tag("device_id", device_id)
set_tag("time_now", time_now)
set_tag("device_keys", device_keys)
set_tag("device_keys", str(device_keys))
old_key_json = self.db_pool.simple_select_one_onecol_txn(
txn,

View file

@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import cast
from twisted.internet import defer
from twisted.test.proto_helpers import MemoryReactorClock
@ -40,6 +42,15 @@ from tests.unittest import TestCase
class LogContextScopeManagerTestCase(TestCase):
"""
Test logging contexts and active opentracing spans.
There's casts throughout this from generic opentracing objects (e.g.
opentracing.Span) to the ones specific to Jaeger since they have additional
properties that these tests depend on. This is safe since the only supported
opentracing backend is Jaeger.
"""
if LogContextScopeManager is None:
skip = "Requires opentracing" # type: ignore[unreachable]
if jaeger_client is None:
@ -69,7 +80,7 @@ class LogContextScopeManagerTestCase(TestCase):
# start_active_span should start and activate a span.
scope = start_active_span("span", tracer=self._tracer)
span = scope.span
span = cast(jaeger_client.Span, scope.span)
self.assertEqual(self._tracer.active_span, span)
self.assertIsNotNone(span.start_time)
@ -91,6 +102,7 @@ class LogContextScopeManagerTestCase(TestCase):
with LoggingContext("root context"):
with start_active_span("root span", tracer=self._tracer) as root_scope:
self.assertEqual(self._tracer.active_span, root_scope.span)
root_context = cast(jaeger_client.SpanContext, root_scope.span.context)
scope1 = start_active_span(
"child1",
@ -99,9 +111,8 @@ class LogContextScopeManagerTestCase(TestCase):
self.assertEqual(
self._tracer.active_span, scope1.span, "child1 was not activated"
)
self.assertEqual(
scope1.span.context.parent_id, root_scope.span.context.span_id
)
context1 = cast(jaeger_client.SpanContext, scope1.span.context)
self.assertEqual(context1.parent_id, root_context.span_id)
scope2 = start_active_span_follows_from(
"child2",
@ -109,17 +120,18 @@ class LogContextScopeManagerTestCase(TestCase):
tracer=self._tracer,
)
self.assertEqual(self._tracer.active_span, scope2.span)
self.assertEqual(
scope2.span.context.parent_id, scope1.span.context.span_id
)
context2 = cast(jaeger_client.SpanContext, scope2.span.context)
self.assertEqual(context2.parent_id, context1.span_id)
with scope1, scope2:
pass
# the root scope should be restored
self.assertEqual(self._tracer.active_span, root_scope.span)
self.assertIsNotNone(scope2.span.end_time)
self.assertIsNotNone(scope1.span.end_time)
span2 = cast(jaeger_client.Span, scope2.span)
span1 = cast(jaeger_client.Span, scope1.span)
self.assertIsNotNone(span2.end_time)
self.assertIsNotNone(span1.end_time)
self.assertIsNone(self._tracer.active_span)