diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index c51b641125..e1f374807e 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -50,7 +50,7 @@ block_db_txn_duration = metrics.register_distribution( class Measure(object): __slots__ = [ "clock", "name", "start_context", "start", "new_context", "ru_utime", - "ru_stime", "db_txn_count", "db_txn_duration" + "ru_stime", "db_txn_count", "db_txn_duration", "created_context" ] def __init__(self, clock, name): @@ -58,14 +58,20 @@ class Measure(object): self.name = name self.start_context = None self.start = None + self.created_context = False def __enter__(self): self.start = self.clock.time_msec() self.start_context = LoggingContext.current_context() - if self.start_context: - self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() - self.db_txn_count = self.start_context.db_txn_count - self.db_txn_duration = self.start_context.db_txn_duration + if not self.start_context: + logger.warn("Entered Measure without log context: %s", self.name) + self.start_context = LoggingContext("Measure") + self.start_context.__enter__() + self.created_context = True + + self.ru_utime, self.ru_stime = self.start_context.get_resource_usage() + self.db_txn_count = self.start_context.db_txn_count + self.db_txn_duration = self.start_context.db_txn_duration def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None or not self.start_context: @@ -91,7 +97,12 @@ class Measure(object): block_ru_utime.inc_by(ru_utime - self.ru_utime, self.name) block_ru_stime.inc_by(ru_stime - self.ru_stime, self.name) - block_db_txn_count.inc_by(context.db_txn_count - self.db_txn_count, self.name) + block_db_txn_count.inc_by( + context.db_txn_count - self.db_txn_count, self.name + ) block_db_txn_duration.inc_by( context.db_txn_duration - self.db_txn_duration, self.name ) + + if self.created_context: + self.start_context.__exit__(exc_type, exc_val, exc_tb) diff --git a/tests/test_state.py b/tests/test_state.py index a1ea7ef672..1a11bbcee0 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -140,13 +140,13 @@ class StateTestCase(unittest.TestCase): "add_event_hashes", ] ) - hs = Mock(spec=[ + hs = Mock(spec_set=[ "get_datastore", "get_auth", "get_state_handler", "get_clock", ]) hs.get_datastore.return_value = self.store hs.get_state_handler.return_value = None - hs.get_auth.return_value = Auth(hs) hs.get_clock.return_value = MockClock() + hs.get_auth.return_value = Auth(hs) self.state = StateHandler(hs) self.event_id = 0