Move from _base to events

This commit is contained in:
Erik Johnston 2015-05-14 14:45:22 +01:00
parent 7d6a1dae31
commit 386b7330d2
2 changed files with 247 additions and 232 deletions

View file

@ -15,9 +15,7 @@
import logging import logging
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util import unwrap_deferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn, LoggingContext from synapse.util.logcontext import preserve_context_over_fn, LoggingContext
from synapse.util.lrucache import LruCache from synapse.util.lrucache import LruCache
@ -29,7 +27,6 @@ from twisted.internet import defer
from collections import namedtuple, OrderedDict from collections import namedtuple, OrderedDict
import functools import functools
import simplejson as json
import sys import sys
import time import time
import threading import threading
@ -868,234 +865,6 @@ class SQLBaseStore(object):
return self.runInteraction("_simple_max_id", func) return self.runInteraction("_simple_max_id", func)
@defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False, txn=None):
if not event_ids:
defer.returnValue([])
event_map = self._get_events_from_cache(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
missing_events_ids = [e for e in event_ids if e not in event_map]
if not missing_events_ids:
defer.returnValue([
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
])
def get_missing(txn):
missing_events = unwrap_deferred(self._fetch_events(
txn,
missing_events_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
))
event_map.update(missing_events)
return [
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
]
if not txn:
res = yield self.runInteraction("_get_events", get_missing)
defer.returnValue(res)
else:
defer.returnValue(get_missing(txn))
def _get_events_txn(self, txn, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
return unwrap_deferred(self._get_events(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
txn=txn,
))
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted,
get_prev_content)
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
events = self._get_events_txn(
txn, [event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
return events[0] if events else None
def _get_events_from_cache(self, events, check_redacted, get_prev_content,
allow_rejected):
event_map = {}
for event_id in events:
try:
ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content
)
if allow_rejected or not ret.rejected_reason:
event_map[event_id] = ret
else:
event_map[event_id] = None
except KeyError:
pass
return event_map
@defer.inlineCallbacks
def _fetch_events(self, txn, events, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not events:
defer.returnValue({})
rows = []
N = 200
for i in range(1 + len(events) / N):
evs = events[i*N:(i + 1)*N]
if not evs:
break
sql = (
"SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"]*len(evs)),)
if txn:
txn.execute(sql, evs)
rows.extend(txn.fetchall())
else:
res = yield self._execute("_fetch_events", None, sql, *evs)
rows.extend(res)
res = yield defer.gatherResults(
[
defer.maybeDeferred(
self._get_event_from_row,
txn,
row[0], row[1], row[2],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=row[3],
)
for row in rows
],
consumeErrors=True,
)
for e in res:
self._get_event_cache.prefill(
e.event_id, check_redacted, get_prev_content, e
)
defer.returnValue({
e.event_id: e
for e in res if e
})
@defer.inlineCallbacks
def _get_event_from_row(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
def select(txn, *args, **kwargs):
if txn:
return self._simple_select_one_onecol_txn(txn, *args, **kwargs)
else:
return self._simple_select_one_onecol(
*args,
desc="_get_event_from_row", **kwargs
)
def get_event(txn, *args, **kwargs):
if txn:
return self._get_event_txn(txn, *args, **kwargs)
else:
return self.get_event(*args, **kwargs)
if rejected_reason:
rejected_reason = yield select(
txn,
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
)
ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
if check_redacted and redacted:
ev = prune_event(ev)
redaction_id = yield select(
txn,
table="redactions",
keyvalues={"redacts": ev.event_id},
retcol="event_id",
)
ev.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield get_event(
txn,
redaction_id,
check_redacted=False
)
if because:
ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned:
prev = yield get_event(
txn,
ev.unsigned["replaces_state"],
get_prev_content=False,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
defer.returnValue(ev)
def _parse_events(self, rows):
return self.runInteraction(
"_parse_events", self._parse_events_txn, rows
)
def _parse_events_txn(self, txn, rows):
event_ids = [r["event_id"] for r in rows]
return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
txn.execute(sql, (event.event_id,))
result = txn.fetchone()
return result[0] if result else None
def get_next_stream_id(self): def get_next_stream_id(self):
with self._next_stream_id_lock: with self._next_stream_id_lock:
i = self._next_stream_id i = self._next_stream_id

View file

@ -17,6 +17,10 @@ from _base import SQLBaseStore, _RollbackButIsFineException
from twisted.internet import defer from twisted.internet import defer
from synapse.events import FrozenEvent
from synapse.events.utils import prune_event
from synapse.util import unwrap_deferred
from synapse.util.logutils import log_function from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes from synapse.api.constants import EventTypes
from synapse.crypto.event_signing import compute_event_reference_hash from synapse.crypto.event_signing import compute_event_reference_hash
@ -26,6 +30,7 @@ from syutil.jsonutil import encode_canonical_json
from contextlib import contextmanager from contextlib import contextmanager
import logging import logging
import simplejson as json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -393,3 +398,244 @@ class EventsStore(SQLBaseStore):
return self.runInteraction( return self.runInteraction(
"have_events", f, "have_events", f,
) )
@defer.inlineCallbacks
def _get_events(self, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False, txn=None):
if not event_ids:
defer.returnValue([])
event_map = self._get_events_from_cache(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
missing_events_ids = [e for e in event_ids if e not in event_map]
if not missing_events_ids:
defer.returnValue([
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
])
if not txn:
missing_events = yield self.runInteraction(
"_get_events",
self._fetch_events_txn,
missing_events_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
else:
missing_events = yield self._fetch_events(
txn,
missing_events_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
event_map.update(missing_events)
defer.returnValue([
event_map[e_id] for e_id in event_ids
if e_id in event_map and event_map[e_id]
])
def _get_events_txn(self, txn, event_ids, check_redacted=True,
get_prev_content=False, allow_rejected=False):
return unwrap_deferred(self._get_events(
event_ids,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
txn=txn,
))
def _invalidate_get_event_cache(self, event_id):
for check_redacted in (False, True):
for get_prev_content in (False, True):
self._get_event_cache.invalidate(event_id, check_redacted,
get_prev_content)
def _get_event_txn(self, txn, event_id, check_redacted=True,
get_prev_content=False, allow_rejected=False):
events = self._get_events_txn(
txn, [event_id],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
return events[0] if events else None
def _get_events_from_cache(self, events, check_redacted, get_prev_content,
allow_rejected):
event_map = {}
for event_id in events:
try:
ret = self._get_event_cache.get(
event_id, check_redacted, get_prev_content
)
if allow_rejected or not ret.rejected_reason:
event_map[event_id] = ret
else:
event_map[event_id] = None
except KeyError:
pass
return event_map
def _fetch_events_txn(self, txn, events, check_redacted=True,
get_prev_content=False, allow_rejected=False):
return unwrap_deferred(self._fetch_events(
txn, events,
check_redacted=check_redacted,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
))
@defer.inlineCallbacks
def _fetch_events(self, txn, events, check_redacted=True,
get_prev_content=False, allow_rejected=False):
if not events:
defer.returnValue({})
rows = []
N = 200
for i in range(1 + len(events) / N):
evs = events[i*N:(i + 1)*N]
if not evs:
break
sql = (
"SELECT e.internal_metadata, e.json, r.redacts, rej.event_id "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"]*len(evs)),)
if txn:
txn.execute(sql, evs)
rows.extend(txn.fetchall())
else:
res = yield self._execute("_fetch_events", None, sql, *evs)
rows.extend(res)
res = yield defer.gatherResults(
[
defer.maybeDeferred(
self._get_event_from_row,
txn,
row[0], row[1], row[2],
check_redacted=check_redacted,
get_prev_content=get_prev_content,
rejected_reason=row[3],
)
for row in rows
],
consumeErrors=True,
)
for e in res:
self._get_event_cache.prefill(
e.event_id, check_redacted, get_prev_content, e
)
defer.returnValue({
e.event_id: e
for e in res if e
})
@defer.inlineCallbacks
def _get_event_from_row(self, txn, internal_metadata, js, redacted,
check_redacted=True, get_prev_content=False,
rejected_reason=None):
d = json.loads(js)
internal_metadata = json.loads(internal_metadata)
def select(txn, *args, **kwargs):
if txn:
return self._simple_select_one_onecol_txn(txn, *args, **kwargs)
else:
return self._simple_select_one_onecol(
*args,
desc="_get_event_from_row", **kwargs
)
def get_event(txn, *args, **kwargs):
if txn:
return self._get_event_txn(txn, *args, **kwargs)
else:
return self.get_event(*args, **kwargs)
if rejected_reason:
rejected_reason = yield select(
txn,
table="rejections",
keyvalues={"event_id": rejected_reason},
retcol="reason",
)
ev = FrozenEvent(
d,
internal_metadata_dict=internal_metadata,
rejected_reason=rejected_reason,
)
if check_redacted and redacted:
ev = prune_event(ev)
redaction_id = yield select(
txn,
table="redactions",
keyvalues={"redacts": ev.event_id},
retcol="event_id",
)
ev.unsigned["redacted_by"] = redaction_id
# Get the redaction event.
because = yield get_event(
txn,
redaction_id,
check_redacted=False
)
if because:
ev.unsigned["redacted_because"] = because
if get_prev_content and "replaces_state" in ev.unsigned:
prev = yield get_event(
txn,
ev.unsigned["replaces_state"],
get_prev_content=False,
)
if prev:
ev.unsigned["prev_content"] = prev.get_dict()["content"]
defer.returnValue(ev)
def _parse_events(self, rows):
return self.runInteraction(
"_parse_events", self._parse_events_txn, rows
)
def _parse_events_txn(self, txn, rows):
event_ids = [r["event_id"] for r in rows]
return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
txn.execute(sql, (event.event_id,))
result = txn.fetchone()
return result[0] if result else None