Have the Filtering API return Deferreds, so we can do the Datastore implementation nicely

This commit is contained in:
Paul "LeoNerd" Evans 2015-01-27 16:17:56 +00:00
parent b1503112ce
commit 059651efa1
3 changed files with 22 additions and 7 deletions

View file

@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from twisted.internet import defer
# TODO(paul) # TODO(paul)
_filters_for_user = {} _filters_for_user = {}
@ -24,18 +26,28 @@ class Filtering(object):
super(Filtering, self).__init__() super(Filtering, self).__init__()
self.hs = hs self.hs = hs
@defer.inlineCallbacks
def get_user_filter(self, user_localpart, filter_id): def get_user_filter(self, user_localpart, filter_id):
filters = _filters_for_user.get(user_localpart, None) filters = _filters_for_user.get(user_localpart, None)
if not filters or filter_id >= len(filters): if not filters or filter_id >= len(filters):
raise KeyError() raise KeyError()
return filters[filter_id] # trivial yield to make it a generator so d.iC works
yield
defer.returnValue(filters[filter_id])
@defer.inlineCallbacks
def add_user_filter(self, user_localpart, definition): def add_user_filter(self, user_localpart, definition):
filters = _filters_for_user.setdefault(user_localpart, []) filters = _filters_for_user.setdefault(user_localpart, [])
filter_id = len(filters) filter_id = len(filters)
filters.append(definition) filters.append(definition)
return filter_id # trivial yield, see above
yield
defer.returnValue(filter_id)
# TODO(paul): surely we should probably add a delete_user_filter or
# replace_user_filter at some point? There's no REST API specified for
# them however

View file

@ -54,10 +54,12 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id") raise SynapseError(400, "Invalid filter_id")
try: try:
defer.returnValue((200, self.filtering.get_user_filter( filter = yield self.filtering.get_user_filter(
user_localpart=target_user.localpart, user_localpart=target_user.localpart,
filter_id=filter_id, filter_id=filter_id,
))) )
defer.returnValue((200, filter))
except KeyError: except KeyError:
raise SynapseError(400, "No such filter") raise SynapseError(400, "No such filter")
@ -89,7 +91,7 @@ class CreateFilterRestServlet(RestServlet):
except: except:
raise SynapseError(400, "Invalid filter definition") raise SynapseError(400, "Invalid filter definition")
filter_id = self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=target_user.localpart, user_localpart=target_user.localpart,
definition=content, definition=content,
) )

View file

@ -53,14 +53,15 @@ class FilteringTestCase(unittest.TestCase):
self.filtering = hs.get_filtering() self.filtering = hs.get_filtering()
@defer.inlineCallbacks
def test_filter(self): def test_filter(self):
filter_id = self.filtering.add_user_filter( filter_id = yield self.filtering.add_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
definition={"type": ["m.*"]}, definition={"type": ["m.*"]},
) )
self.assertEquals(filter_id, 0) self.assertEquals(filter_id, 0)
filter = self.filtering.get_user_filter( filter = yield self.filtering.get_user_filter(
user_localpart=user_localpart, user_localpart=user_localpart,
filter_id=filter_id, filter_id=filter_id,
) )