Merge pull request #288 from matrix-org/markjh/unused_definitions

Remove some of the unused definitions from synapse
This commit is contained in:
Mark Haines 2015-09-28 14:22:44 +01:00
commit 301141515a
18 changed files with 162 additions and 588 deletions

142
scripts-dev/definitions.py Executable file
View file

@ -0,0 +1,142 @@
#! /usr/bin/python
import ast
import yaml
class DefinitionVisitor(ast.NodeVisitor):
def __init__(self):
super(DefinitionVisitor, self).__init__()
self.functions = {}
self.classes = {}
self.names = {}
self.attrs = set()
self.definitions = {
'def': self.functions,
'class': self.classes,
'names': self.names,
'attrs': self.attrs,
}
def visit_Name(self, node):
self.names.setdefault(type(node.ctx).__name__, set()).add(node.id)
def visit_Attribute(self, node):
self.attrs.add(node.attr)
for child in ast.iter_child_nodes(node):
self.visit(child)
def visit_ClassDef(self, node):
visitor = DefinitionVisitor()
self.classes[node.name] = visitor.definitions
for child in ast.iter_child_nodes(node):
visitor.visit(child)
def visit_FunctionDef(self, node):
visitor = DefinitionVisitor()
self.functions[node.name] = visitor.definitions
for child in ast.iter_child_nodes(node):
visitor.visit(child)
def non_empty(defs):
functions = {name: non_empty(f) for name, f in defs['def'].items()}
classes = {name: non_empty(f) for name, f in defs['class'].items()}
result = {}
if functions: result['def'] = functions
if classes: result['class'] = classes
names = defs['names']
uses = []
for name in names.get('Load', ()):
if name not in names.get('Param', ()) and name not in names.get('Store', ()):
uses.append(name)
uses.extend(defs['attrs'])
if uses: result['uses'] = uses
result['names'] = names
result['attrs'] = defs['attrs']
return result
def definitions_in_code(input_code):
input_ast = ast.parse(input_code)
visitor = DefinitionVisitor()
visitor.visit(input_ast)
definitions = non_empty(visitor.definitions)
return definitions
def definitions_in_file(filepath):
with open(filepath) as f:
return definitions_in_code(f.read())
def defined_names(prefix, defs, names):
for name, funcs in defs.get('def', {}).items():
names.setdefault(name, {'defined': []})['defined'].append(prefix + name)
defined_names(prefix + name + ".", funcs, names)
for name, funcs in defs.get('class', {}).items():
names.setdefault(name, {'defined': []})['defined'].append(prefix + name)
defined_names(prefix + name + ".", funcs, names)
def used_names(prefix, defs, names):
for name, funcs in defs.get('def', {}).items():
used_names(prefix + name + ".", funcs, names)
for name, funcs in defs.get('class', {}).items():
used_names(prefix + name + ".", funcs, names)
for used in defs.get('uses', ()):
if used in names:
names[used].setdefault('used', []).append(prefix.rstrip('.'))
if __name__ == '__main__':
import sys, os, argparse, re
parser = argparse.ArgumentParser(description='Find definitions.')
parser.add_argument(
"--unused", action="store_true", help="Only list unused definitions"
)
parser.add_argument(
"--ignore", action="append", metavar="REGEXP", help="Ignore a pattern"
)
parser.add_argument(
"--pattern", action="append", metavar="REGEXP",
help="Search for a pattern"
)
parser.add_argument(
"directories", nargs='+', metavar="DIR",
help="Directories to search for definitions"
)
args = parser.parse_args()
definitions = {}
for directory in args.directories:
for root, dirs, files in os.walk(directory):
for filename in files:
if filename.endswith(".py"):
filepath = os.path.join(root, filename)
definitions[filepath] = definitions_in_file(filepath)
names = {}
for filepath, defs in definitions.items():
defined_names(filepath + ":", defs, names)
for filepath, defs in definitions.items():
used_names(filepath + ":", defs, names)
patterns = [re.compile(pattern) for pattern in args.pattern or ()]
ignore = [re.compile(pattern) for pattern in args.ignore or ()]
result = {}
for name, definition in names.items():
if patterns and not any(pattern.match(name) for pattern in patterns):
continue
if ignore and any(pattern.match(name) for pattern in ignore):
continue
if args.unused and definition.get('used'):
continue
result[name] = definition
yaml.dump(result, sys.stdout, default_flow_style=False)

View file

@ -95,8 +95,6 @@ class Store(object):
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
_execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"]
def runInteraction(self, desc, func, *args, **kwargs): def runInteraction(self, desc, func, *args, **kwargs):
def r(conn): def r(conn):
try: try:

View file

@ -77,11 +77,6 @@ class SynapseError(CodeMessageException):
) )
class RoomError(SynapseError):
"""An error raised when a room event fails."""
pass
class RegistrationError(SynapseError): class RegistrationError(SynapseError):
"""An error raised when a registration event fails.""" """An error raised when a registration event fails."""
pass pass

View file

@ -85,12 +85,6 @@ import time
logger = logging.getLogger("synapse.app.homeserver") logger = logging.getLogger("synapse.app.homeserver")
class GzipFile(File):
def getChild(self, path, request):
child = File.getChild(self, path, request)
return EncodingResourceWrapper(child, [GzipEncoderFactory()])
def gz_wrap(r): def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()]) return EncodingResourceWrapper(r, [GzipEncoderFactory()])
@ -134,6 +128,7 @@ class SynapseHomeServer(HomeServer):
# (It can stay enabled for the API resources: they call # (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight # write() with the whole body and then finish() straight
# after and so do not trigger the bug. # after and so do not trigger the bug.
# GzipFile was removed in commit 184ba09
# return GzipFile(webclient_path) # TODO configurable? # return GzipFile(webclient_path) # TODO configurable?
return File(webclient_path) # TODO configurable? return File(webclient_path) # TODO configurable?

View file

@ -1456,52 +1456,3 @@ class FederationHandler(BaseHandler):
}, },
"missing": [e.event_id for e in missing_locals], "missing": [e.event_id for e in missing_locals],
}) })
@defer.inlineCallbacks
def _handle_auth_events(self, origin, auth_events):
auth_ids_to_deferred = {}
def process_auth_ev(ev):
auth_ids = [e_id for e_id, _ in ev.auth_events]
prev_ds = [
auth_ids_to_deferred[i]
for i in auth_ids
if i in auth_ids_to_deferred
]
d = defer.Deferred()
auth_ids_to_deferred[ev.event_id] = d
@defer.inlineCallbacks
def f(*_):
ev.internal_metadata.outlier = True
try:
auth = {
(e.type, e.state_key): e for e in auth_events
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, ev, auth_events=auth
)
except:
logger.exception(
"Failed to handle auth event %s",
ev.event_id,
)
d.callback(None)
if prev_ds:
dx = defer.DeferredList(prev_ds)
dx.addBoth(f)
else:
f()
for e in auth_events:
process_auth_ev(e)
yield defer.DeferredList(auth_ids_to_deferred.values())

View file

@ -492,32 +492,6 @@ class RoomMemberHandler(BaseHandler):
"user_joined_room", user=user, room_id=room_id "user_joined_room", user=user, room_id=room_id
) )
@defer.inlineCallbacks
def _should_invite_join(self, room_id, prev_state, do_auth):
logger.debug("_should_invite_join: room_id: %s", room_id)
# XXX: We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
# Only do an invite join dance if a) we were invited,
# b) the person inviting was from a differnt HS and c) we are
# not currently in the room
room_host = None
if prev_state and prev_state.membership == Membership.INVITE:
room = yield self.store.get_room(room_id)
inviter = UserID.from_string(
prev_state.sender
)
is_remote_invite_join = not self.hs.is_mine(inviter) and not room
room_host = inviter.domain
else:
is_remote_invite_join = False
defer.returnValue((is_remote_invite_join, room_host))
@defer.inlineCallbacks @defer.inlineCallbacks
def get_joined_rooms_for_user(self, user): def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given """Returns a list of roomids that the user has any of the given

View file

@ -31,10 +31,6 @@ import hashlib
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _get_state_key_from_event(event):
return event.state_key
KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))

View file

@ -25,8 +25,6 @@ from util.id_generators import IdGenerator, StreamIdGenerator
from twisted.internet import defer from twisted.internet import defer
from collections import namedtuple
import sys import sys
import time import time
import threading import threading
@ -376,9 +374,6 @@ class SQLBaseStore(object):
return self.runInteraction(desc, interaction) return self.runInteraction(desc, interaction)
def _execute_and_decode(self, desc, query, *args):
return self._execute(desc, self.cursor_to_dict, query, *args)
# "Simple" SQL API methods that operate on a single table with no JOINs, # "Simple" SQL API methods that operate on a single table with no JOINs,
# no complex WHERE clauses, just a dict of values for columns. # no complex WHERE clauses, just a dict of values for columns.
@ -691,37 +686,6 @@ class SQLBaseStore(object):
return dict(zip(retcols, row)) return dict(zip(retcols, row))
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
retcols=None, allow_none=False,
desc="_simple_selectupdate_one"):
""" Combined SELECT then UPDATE."""
def func(txn):
ret = None
if retcols:
ret = self._simple_select_one_txn(
txn,
table=table,
keyvalues=keyvalues,
retcols=retcols,
allow_none=allow_none,
)
if updatevalues:
self._simple_update_one_txn(
txn,
table=table,
keyvalues=keyvalues,
updatevalues=updatevalues,
)
# if txn.rowcount == 0:
# raise StoreError(404, "No row found")
if txn.rowcount > 1:
raise StoreError(500, "More than one row matched")
return ret
return self.runInteraction(desc, func)
def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"): def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
"""Executes a DELETE query on the named table, expecting to delete a """Executes a DELETE query on the named table, expecting to delete a
single row. single row.
@ -743,16 +707,6 @@ class SQLBaseStore(object):
raise StoreError(500, "more than one row matched") raise StoreError(500, "more than one row matched")
return self.runInteraction(desc, func) return self.runInteraction(desc, func)
def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
"""Executes a DELETE query on the named table.
Args:
table : string giving the table name
keyvalues : dict of column names and values to select the row with
"""
return self.runInteraction(desc, self._simple_delete_txn)
def _simple_delete_txn(self, txn, table, keyvalues): def _simple_delete_txn(self, txn, table, keyvalues):
sql = "DELETE FROM %s WHERE %s" % ( sql = "DELETE FROM %s WHERE %s" % (
table, table,
@ -761,24 +715,6 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values()) return txn.execute(sql, keyvalues.values())
def _simple_max_id(self, table):
"""Executes a SELECT query on the named table, expecting to return the
max value for the column "id".
Args:
table : string giving the table name
"""
sql = "SELECT MAX(id) AS id FROM %s" % table
def func(txn):
txn.execute(sql)
max_id = self.cursor_to_dict(txn)[0]["id"]
if max_id is None:
return 0
return max_id
return self.runInteraction("_simple_max_id", func)
def get_next_stream_id(self): def get_next_stream_id(self):
with self._next_stream_id_lock: with self._next_stream_id_lock:
i = self._next_stream_id i = self._next_stream_id
@ -791,129 +727,3 @@ class _RollbackButIsFineException(Exception):
something went wrong. something went wrong.
""" """
pass pass
class Table(object):
""" A base class used to store information about a particular table.
"""
table_name = None
""" str: The name of the table """
fields = None
""" list: The field names """
EntryType = None
""" Type: A tuple type used to decode the results """
_select_where_clause = "SELECT %s FROM %s WHERE %s"
_select_clause = "SELECT %s FROM %s"
_insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
@classmethod
def select_statement(cls, where_clause=None):
"""
Args:
where_clause (str): The WHERE clause to use.
Returns:
str: An SQL statement to select rows from the table with the given
WHERE clause.
"""
if where_clause:
return cls._select_where_clause % (
", ".join(cls.fields),
cls.table_name,
where_clause
)
else:
return cls._select_clause % (
", ".join(cls.fields),
cls.table_name,
)
@classmethod
def insert_statement(cls):
return cls._insert_clause % (
cls.table_name,
", ".join(cls.fields),
", ".join(["?"] * len(cls.fields)),
)
@classmethod
def decode_single_result(cls, results):
""" Given an iterable of tuples, return a single instance of
`EntryType` or None if the iterable is empty
Args:
results (list): The results list to convert to `EntryType`
Returns:
EntryType: An instance of `EntryType`
"""
results = list(results)
if results:
return cls.EntryType(*results[0])
else:
return None
@classmethod
def decode_results(cls, results):
""" Given an iterable of tuples, return a list of `EntryType`
Args:
results (list): The results list to convert to `EntryType`
Returns:
list: A list of `EntryType`
"""
return [cls.EntryType(*row) for row in results]
@classmethod
def get_fields_string(cls, prefix=None):
if prefix:
to_join = ("%s.%s" % (prefix, f) for f in cls.fields)
else:
to_join = cls.fields
return ", ".join(to_join)
class JoinHelper(object):
""" Used to help do joins on tables by looking at the tables' fields and
creating a list of unique fields to use with SELECTs and a namedtuple
to dump the results into.
Attributes:
tables (list): List of `Table` classes
EntryType (type)
"""
def __init__(self, *tables):
self.tables = tables
res = []
for table in self.tables:
res += [f for f in table.fields if f not in res]
self.EntryType = namedtuple("JoinHelperEntry", res)
def get_fields(self, **prefixes):
"""Get a string representing a list of fields for use in SELECT
statements with the given prefixes applied to each.
For example::
JoinHelper(PdusTable, StateTable).get_fields(
PdusTable="pdus",
StateTable="state"
)
"""
res = []
for field in self.EntryType._fields:
for table in self.tables:
if field in table.fields:
res.append("%s.%s" % (prefixes[table.__name__], field))
break
return ", ".join(res)
def decode_results(self, rows):
return [self.EntryType(*row) for row in rows]

View file

@ -154,98 +154,6 @@ class EventFederationStore(SQLBaseStore):
return results return results
def _get_latest_state_in_room(self, txn, room_id, type, state_key):
event_ids = self._simple_select_onecol_txn(
txn,
table="state_forward_extremities",
keyvalues={
"room_id": room_id,
"type": type,
"state_key": state_key,
},
retcol="event_id",
)
results = []
for event_id in event_ids:
hashes = self._get_event_reference_hashes_txn(txn, event_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((event_id, prev_hashes))
return results
def _get_prev_events(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=0,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_state(self, txn, event_id):
results = self._get_prev_events_and_state(
txn,
event_id,
is_state=True,
)
return [(e_id, h, ) for e_id, h, _ in results]
def _get_prev_events_and_state(self, txn, event_id, is_state=None):
keyvalues = {
"event_id": event_id,
}
if is_state is not None:
keyvalues["is_state"] = bool(is_state)
res = self._simple_select_list_txn(
txn,
table="event_edges",
keyvalues=keyvalues,
retcols=["prev_event_id", "is_state"],
)
hashes = self._get_prev_event_hashes_txn(txn, event_id)
results = []
for d in res:
edge_hash = self._get_event_reference_hashes_txn(txn, d["prev_event_id"])
edge_hash.update(hashes.get(d["prev_event_id"], {}))
prev_hashes = {
k: encode_base64(v)
for k, v in edge_hash.items()
if k == "sha256"
}
results.append((d["prev_event_id"], prev_hashes, d["is_state"]))
return results
def _get_auth_events(self, txn, event_id):
auth_ids = self._simple_select_onecol_txn(
txn,
table="event_auth",
keyvalues={
"event_id": event_id,
},
retcol="auth_id",
)
results = []
for auth_id in auth_ids:
hashes = self._get_event_reference_hashes_txn(txn, auth_id)
prev_hashes = {
k: encode_base64(v) for k, v in hashes.items()
if k == "sha256"
}
results.append((auth_id, prev_hashes))
return results
def get_min_depth(self, room_id): def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it. """ For hte given room, get the minimum depth we have seen for it.
""" """

View file

@ -890,22 +890,11 @@ class EventsStore(SQLBaseStore):
return ev return ev
def _parse_events(self, rows):
return self.runInteraction(
"_parse_events", self._parse_events_txn, rows
)
def _parse_events_txn(self, txn, rows): def _parse_events_txn(self, txn, rows):
event_ids = [r["event_id"] for r in rows] event_ids = [r["event_id"] for r in rows]
return self._get_events_txn(txn, event_ids) return self._get_events_txn(txn, event_ids)
def _has_been_redacted_txn(self, txn, event):
sql = "SELECT event_id FROM redactions WHERE redacts = ?"
txn.execute(sql, (event.event_id,))
result = txn.fetchone()
return result[0] if result else None
@defer.inlineCallbacks @defer.inlineCallbacks
def count_daily_messages(self): def count_daily_messages(self):
""" """

View file

@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from ._base import SQLBaseStore, Table from ._base import SQLBaseStore
from twisted.internet import defer from twisted.internet import defer
from synapse.api.errors import StoreError from synapse.api.errors import StoreError
@ -149,5 +149,5 @@ class PusherStore(SQLBaseStore):
) )
class PushersTable(Table): class PushersTable(object):
table_name = "pushers" table_name = "pushers"

View file

@ -178,12 +178,6 @@ class RoomMemberStore(SQLBaseStore):
return joined_domains return joined_domains
def _get_members_query(self, where_clause, where_values):
return self.runInteraction(
"get_members_query", self._get_members_events_txn,
where_clause, where_values
).addCallbacks(self._get_events)
def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None): def _get_members_events_txn(self, txn, room_id, membership=None, user_id=None):
rows = self._get_members_rows_txn( rows = self._get_members_rows_txn(
txn, txn,

View file

@ -24,41 +24,6 @@ from synapse.crypto.event_signing import compute_event_reference_hash
class SignatureStore(SQLBaseStore): class SignatureStore(SQLBaseStore):
"""Persistence for event signatures and hashes""" """Persistence for event signatures and hashes"""
def _get_event_content_hashes_txn(self, txn, event_id):
"""Get all the hashes for a given Event.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"
" FROM event_content_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
return dict(txn.fetchall())
def _store_event_content_hash_txn(self, txn, event_id, algorithm,
hash_bytes):
"""Store a hash for a Event
Args:
txn (cursor):
event_id (str): Id for the Event.
algorithm (str): Hashing algorithm.
hash_bytes (bytes): Hash function output bytes.
"""
self._simple_insert_txn(
txn,
"event_content_hashes",
{
"event_id": event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
)
def get_event_reference_hashes(self, event_ids): def get_event_reference_hashes(self, event_ids):
def f(txn): def f(txn):
return [ return [
@ -123,80 +88,3 @@ class SignatureStore(SQLBaseStore):
table="event_reference_hashes", table="event_reference_hashes",
values=vals, values=vals,
) )
def _get_event_signatures_txn(self, txn, event_id):
"""Get all the signatures for a given PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
A dict of sig name -> dict(key_id -> signature_bytes)
"""
query = (
"SELECT signature_name, key_id, signature"
" FROM event_signatures"
" WHERE event_id = ? "
)
txn.execute(query, (event_id, ))
rows = txn.fetchall()
res = {}
for name, key, sig in rows:
res.setdefault(name, {})[key] = sig
return res
def _store_event_signature_txn(self, txn, event_id, signature_name, key_id,
signature_bytes):
"""Store a signature from the origin server for a PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
origin (str): origin of the Event.
key_id (str): Id for the signing key.
signature (bytes): The signature.
"""
self._simple_insert_txn(
txn,
"event_signatures",
{
"event_id": event_id,
"signature_name": signature_name,
"key_id": key_id,
"signature": buffer(signature_bytes),
},
)
def _get_prev_event_hashes_txn(self, txn, event_id):
"""Get all the hashes for previous PDUs of a PDU
Args:
txn (cursor):
event_id (str): Id for the Event.
Returns:
dict of (pdu_id, origin) -> dict of algorithm -> hash_bytes.
"""
query = (
"SELECT prev_event_id, algorithm, hash"
" FROM event_edge_hashes"
" WHERE event_id = ?"
)
txn.execute(query, (event_id, ))
results = {}
for prev_event_id, algorithm, hash_bytes in txn.fetchall():
hashes = results.setdefault(prev_event_id, {})
hashes[algorithm] = hash_bytes
return results
def _store_prev_event_hash_txn(self, txn, event_id, prev_event_id,
algorithm, hash_bytes):
self._simple_insert_txn(
txn,
"event_edge_hashes",
{
"event_id": event_id,
"prev_event_id": prev_event_id,
"algorithm": algorithm,
"hash": buffer(hash_bytes),
},
)

View file

@ -20,8 +20,6 @@ from synapse.util.caches.descriptors import (
from twisted.internet import defer from twisted.internet import defer
from synapse.util.stringutils import random_string
import logging import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -428,7 +426,3 @@ class StateStore(SQLBaseStore):
} }
defer.returnValue(results) defer.returnValue(results)
def _make_group_id(clock):
return str(int(clock.time_msec())) + random_string(5)

View file

@ -23,22 +23,6 @@ from synapse.handlers.typing import TypingNotificationEventSource
from synapse.handlers.receipts import ReceiptEventSource from synapse.handlers.receipts import ReceiptEventSource
class NullSource(object):
"""This event source never yields any events and its token remains at
zero. It may be useful for unit-testing."""
def __init__(self, hs):
pass
def get_new_events_for_user(self, user, from_key, limit):
return defer.succeed(([], from_key))
def get_current_key(self, direction='f'):
return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key):
return defer.succeed(([], pagination_config.from_key))
class EventSources(object): class EventSources(object):
SOURCE_TYPES = { SOURCE_TYPES = {
"room": RoomEventSource, "room": RoomEventSource,
@ -70,15 +54,3 @@ class EventSources(object):
), ),
) )
defer.returnValue(token) defer.returnValue(token)
class StreamSource(object):
def get_new_events_for_user(self, user, from_key, limit):
"""from_key is the key within this event source."""
raise NotImplementedError("get_new_events_for_user")
def get_current_key(self):
raise NotImplementedError("get_current_key")
def get_pagination_rows(self, user, pagination_config, key):
raise NotImplementedError("get_rows")

View file

@ -29,34 +29,6 @@ def unwrapFirstError(failure):
return failure.value.subFailure return failure.value.subFailure
def unwrap_deferred(d):
"""Given a deferred that we know has completed, return its value or raise
the failure as an exception
"""
if not d.called:
raise RuntimeError("deferred has not finished")
res = []
def f(r):
res.append(r)
return r
d.addCallback(f)
if res:
return res[0]
def f(r):
res.append(r)
return r
d.addErrback(f)
if res:
res[0].raiseException()
else:
raise RuntimeError("deferred did not call callbacks")
class Clock(object): class Clock(object):
"""A small utility that obtains current time-of-day so that time may be """A small utility that obtains current time-of-day so that time may be
mocked during unit-tests. mocked during unit-tests.

View file

@ -41,6 +41,22 @@ myid = "@apple:test"
PATH_PREFIX = "/_matrix/client/api/v1" PATH_PREFIX = "/_matrix/client/api/v1"
class NullSource(object):
"""This event source never yields any events and its token remains at
zero. It may be useful for unit-testing."""
def __init__(self, hs):
pass
def get_new_events_for_user(self, user, from_key, limit):
return defer.succeed(([], from_key))
def get_current_key(self, direction='f'):
return defer.succeed(0)
def get_pagination_rows(self, user, pagination_config, key):
return defer.succeed(([], pagination_config.from_key))
class JustPresenceHandlers(object): class JustPresenceHandlers(object):
def __init__(self, hs): def __init__(self, hs):
self.presence_handler = PresenceHandler(hs) self.presence_handler = PresenceHandler(hs)
@ -243,7 +259,7 @@ class PresenceEventStreamTestCase(unittest.TestCase):
# HIDEOUS HACKERY # HIDEOUS HACKERY
# TODO(paul): This should be injected in via the HomeServer DI system # TODO(paul): This should be injected in via the HomeServer DI system
from synapse.streams.events import ( from synapse.streams.events import (
PresenceEventSource, NullSource, EventSources PresenceEventSource, EventSources
) )
old_SOURCE_TYPES = EventSources.SOURCE_TYPES old_SOURCE_TYPES = EventSources.SOURCE_TYPES

View file

@ -185,26 +185,6 @@ class SQLBaseStoreTestCase(unittest.TestCase):
[3, 4, 1, 2] [3, 4, 1, 2]
) )
@defer.inlineCallbacks
def test_update_one_with_return(self):
self.mock_txn.rowcount = 1
self.mock_txn.fetchone.return_value = ("Old Value",)
ret = yield self.datastore._simple_selectupdate_one(
table="tablename",
keyvalues={"keycol": "TheKey"},
updatevalues={"columname": "New Value"},
retcols=["columname"]
)
self.assertEquals({"columname": "Old Value"}, ret)
self.mock_txn.execute.assert_has_calls([
call('SELECT columname FROM tablename WHERE keycol = ?',
['TheKey']),
call("UPDATE tablename SET columname = ? WHERE keycol = ?",
["New Value", "TheKey"])
])
@defer.inlineCallbacks @defer.inlineCallbacks
def test_delete_one(self): def test_delete_one(self):
self.mock_txn.rowcount = 1 self.mock_txn.rowcount = 1