Draw: readd --extras, arbitrary resolutions

This commit is contained in:
David Robertson 2022-08-04 20:12:00 +01:00
parent df8c0c44cf
commit 7c9388f19e
No known key found for this signature in database
GPG key ID: 903ECE108A39DEDD

View file

@ -3,7 +3,7 @@ import argparse
import logging
import sys
from pprint import pformat
from typing import Awaitable, Callable, Collection, Optional, Tuple, cast
from typing import Awaitable, Callable, Collection, Dict, List, Optional, Tuple, cast
from unittest.mock import MagicMock, patch
import dictdiffer
@ -73,13 +73,15 @@ def node(
event: EventBase, suffix: Optional[str] = None, **kwargs: object
) -> pydot.Node:
if "label" not in kwargs:
label = f"{event.event_id}\n{event.sender}: {(event.type,event.state_key)}"
label = (
f"{event.event_id}\n{event.sender}: {(event.type,event.get_state_key())}"
)
if event.type == "m.room.member":
label += f" ({event.membership.upper()})"
if suffix:
label += f"\n{suffix}"
kwargs["label"] = label
type_to_shape = {} # {"m.room.member": "oval"}
type_to_shape: Dict[str, str] = {} # {"m.room.member": "oval"}
if event.type in type_to_shape:
kwargs.setdefault("shape", type_to_shape[event.type])
@ -97,9 +99,10 @@ def edge(source: EventBase, target: EventBase, **kwargs: object) -> pydot.Edge:
async def dump_mainlines(
hs: MockHomeserver,
starting_event: EventBase,
resolve_point: Optional[EventBase],
events: Collection[EventBase],
extras: Collection[str],
watch_func: Optional[Callable[[EventBase], Awaitable[str]]] = None,
extras: Collection[EventBase] = (),
) -> None:
"""Visualise the auth DAG above a given `starting_event`.
@ -123,21 +126,29 @@ async def dump_mainlines(
suffix = await watch_func(event) if watch_func else None
return node(event, suffix, **kwargs)
graph.add_node(await new_node(starting_event, fillcolor="#6699cc"))
seen = {starting_event.event_id}
seen = set()
todo: List[EventBase] = []
todo = []
for extra in extras:
graph.add_node(await new_node(extra, fillcolor="#cc9966"))
seen.add(extra.event_id)
todo.append(extra)
if resolve_point:
graph.add_node(await new_node(resolve_point, fillcolor="#6699cc"))
seen.add(resolve_point.event_id)
for pid in starting_event.prev_event_ids():
parent = await hs.get_datastores().main.get_event(pid)
for parent in events:
graph.add_node(await new_node(parent, fillcolor="#6699cc"))
seen.add(pid)
graph.add_edge(edge(starting_event, parent, style="dashed"))
seen.add(parent.event_id)
todo.append(parent)
if resolve_point:
graph.add_edge(edge(resolve_point, parent, style="dashed"))
if extras:
logger.debug(extras)
extra_events = await hs.get_datastores().main.get_events(extras)
logger.debug(extra_events)
for extra_event in extra_events.values():
if extra_event.event_id in seen:
continue
graph.add_node(await new_node(extra_event, fillcolor="#6699ee"))
todo.append(extra_event)
async def fetch_auth_events(event: EventBase) -> StateMap[EventBase]:
return {
@ -155,6 +166,8 @@ async def dump_mainlines(
(("m.room.power_levels", ""), "solid"),
(("m.room.join_rules", ""), "solid"),
(("m.room.member", event.sender), "dotted"),
# TODO: handle that state_key might be missing
# (("m.room.member", event.state_key), "solid"),
]:
auth_event = auth_events.get(key)
if auth_event:
@ -189,13 +202,30 @@ parser.add_argument(
"config_file", help="Synapse config file", type=argparse.FileType("r")
)
parser.add_argument("--verbose", "-v", help="Log verbosely", action="store_true")
parser.add_argument("-d", "--draw", help="Render auth DAG", action="store_true")
parser.add_argument(
"--debug", "-d", help="Enter debugger after state is resolved", action="store_true"
"event_ids",
help="""\
The event ID(s) to be resolved.\
If a single event is given, resolve across all of its parents to compute the state
before the given event. If multiple events are given, resolve across them directly.
""",
nargs="+",
)
parser.add_argument(
"-e",
"--extra",
dest="extras",
help=(
"An extra event to include in the auth DAG when using the `--draw` flag. "
"Can be provided multiple times."
),
action="append",
)
parser.add_argument("event_id", help="The event ID to be resolved")
parser.add_argument(
"--watch",
help="Track a piece of state in the auth DAG",
help="Track a piece of state in the auth DAG when using the `--draw` flag.",
default=None,
nargs=2,
metavar=("TYPE", "STATE_KEY"),
@ -213,19 +243,22 @@ async def debug_specific_stateres(
- the recomputed and stored state, written to stdout, and
- their difference, written to stdout.
"""
# Fetch the event in question.
event = await hs.get_datastores().main.get_event(args.event_id)
assert event is not None
logger.info(
"event %s has %d parents, %s",
event.event_id,
len(event.prev_event_ids()),
event.prev_event_ids(),
)
DEBUG_AT_EVENT = len(args.event_ids) == 1
if DEBUG_AT_EVENT:
resolve_point = await hs.get_datastores().main.get_event(args.event_ids[0])
prev_event_ids = resolve_point.prev_event_ids()
else:
resolve_point = None
prev_event_ids = args.event_ids
parent_events = (await hs.get_datastores().main.get_events(prev_event_ids)).values()
sample_event = next(iter(parent_events))
logger.info("Resolving across %d parents, %s", len(prev_event_ids), prev_event_ids)
state_after_parents = [
await hs.get_storage_controllers().state.get_state_ids_for_event(prev_event_id)
for prev_event_id in event.prev_event_ids()
for prev_event_id in prev_event_ids
]
if args.watch is not None:
@ -236,8 +269,10 @@ async def debug_specific_stateres(
async def watch_func(event: EventBase) -> str:
try:
result = await hs.get_storage_controllers().state.get_state_ids_for_event(
event.event_id, filter
result = (
await hs.get_storage_controllers().state.get_state_ids_for_event(
event.event_id, filter
)
)
except RuntimeError:
return f"\n{key_pair}: <Event unavailable :(>"
@ -247,37 +282,31 @@ async def debug_specific_stateres(
else:
watch_func = None
await dump_mainlines(hs, event, watch_func)
if args.draw:
await dump_mainlines(hs, resolve_point, parent_events, args.extras, watch_func)
result = await hs.get_state_resolution_handler().resolve_events_with_store(
event.room_id,
event.room_version.identifier,
sample_event.room_id,
sample_event.room_version.identifier,
state_after_parents,
event_map=None,
state_res_store=StateResolutionStore(hs.get_datastores().main),
)
logger.info("State resolved at %s:", event.event_id)
logger.info("State resolved:")
logger.info(pformat(result))
logger.info("Stored state at %s:", event.event_id)
stored_state = await hs.get_storage_controllers().state.get_state_ids_for_event(
event.event_id
)
logger.info(pformat(stored_state))
# TODO make this a like-for-like comparison.
logger.info("Diff from stored (after event) to resolved (before event):")
for change in dictdiffer.diff(stored_state, result):
logger.info(pformat(change))
if args.debug:
print(
f"see `state_after_parents[i]` for 0 <= i < {len(state_after_parents)}"
" and `result`",
file=sys.stderr,
if DEBUG_AT_EVENT:
logger.info("Stored state at %s:", sample_event.event_id)
stored_state = await hs.get_storage_controllers().state.get_state_ids_for_event(
sample_event.event_id
)
breakpoint()
logger.info(pformat(stored_state))
# TODO make this a like-for-like comparison.
logger.info("Diff from stored (after event) to resolved (before event):")
for change in dictdiffer.diff(stored_state, result):
logger.info(pformat(change))
# Entrypoint.
@ -288,7 +317,7 @@ if __name__ == "__main__":
level=logging.DEBUG if args.verbose else logging.INFO,
stream=sys.stdout,
)
# Suppress logs weren't not interested in.
# Suppress logs we aren't interested in.
logging.getLogger("synapse.util").setLevel(logging.ERROR)
logging.getLogger("synapse.storage").setLevel(logging.ERROR)