diff --git a/README.rst b/README.rst index 768da3df64..282a53873f 100644 --- a/README.rst +++ b/README.rst @@ -95,27 +95,30 @@ Installing prerequisites on Ubuntu or Debian:: $ sudo apt-get install build-essential python2.7-dev libffi-dev \ python-pip python-setuptools sqlite3 \ - libssl-dev + libssl-dev python-virtualenv libjpeg-dev Installing prerequisites on Mac OS X:: $ xcode-select --install + $ sudo pip install virtualenv To install the synapse homeserver run:: - $ pip install --user --process-dependency-links https://github.com/matrix-org/synapse/tarball/master + $ virtualenv ~/.synapse + $ source ~/.synapse/bin/activate + $ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master -This installs synapse, along with the libraries it uses, into -``$HOME/.local/lib/`` on Linux or ``$HOME/Library/Python/2.7/lib/`` on OSX. +This installs synapse, along with the libraries it uses, into a virtual +environment under ``~/.synapse``. -Your python may not give priority to locally installed libraries over system -libraries, in which case you must add your local packages to your python path:: +To set up your homeserver, run (in your virtualenv, as before):: - $ # on Linux: - $ export PYTHONPATH=$HOME/.local/lib/python2.7/site-packages:$PYTHONPATH + $ python -m synapse.app.homeserver \ + --server-name machine.my.domain.name \ + --config-path homeserver.yaml \ + --generate-config - $ # on OSX: - $ export PYTHONPATH=$HOME/Library/Python/2.7/lib/python/site-packages:$PYTHONPATH +Substituting your host and domain name as appropriate. For reliable VoIP calls to be routed via this homeserver, you MUST configure a TURN server. See docs/turn-howto.rst for details. @@ -128,19 +131,19 @@ you get errors about ``error: no such option: --process-dependency-links`` you may need to manually upgrade it:: $ sudo pip install --upgrade pip - + If pip crashes mid-installation for reason (e.g. lost terminal), pip may refuse to run until you remove the temporary installation directory it created. To reset the installation:: $ rm -rf /tmp/pip_install_matrix - + pip seems to leak *lots* of memory during installation. For instance, a Linux host with 512MB of RAM may run out of memory whilst installing Twisted. If this happens, you will have to individually install the dependencies which are failing, e.g.:: - $ pip install --user twisted + $ pip install twisted On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you will need to export CFLAGS=-Qunused-arguments. @@ -155,7 +158,7 @@ Synapse can be installed on Cygwin. It requires the following Cygwin packages: - openssl (and openssl-devel, python-openssl) - python - python-setuptools - + The content repository requires additional packages and will be unable to process uploads without them: - libjpeg8 @@ -182,23 +185,13 @@ Running Your Homeserver To actually run your new homeserver, pick a working directory for Synapse to run (e.g. ``~/.synapse``), and:: - $ mkdir ~/.synapse $ cd ~/.synapse - - $ # on Linux - $ ~/.local/bin/synctl start - - $ # on OSX - $ ~/Library/Python/2.7/bin/synctl start + $ source ./bin/activate + $ synctl start Troubleshooting Running ----------------------- -If ``synctl`` fails with ``pkg_resources.DistributionNotFound`` errors you may -need a newer version of setuptools than that provided by your OS.:: - - $ sudo pip install setuptools --upgrade - If synapse fails with ``missing "sodium.h"`` crypto errors, you may need to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for encryption and digital signatures. @@ -225,13 +218,15 @@ directory of your choice:: $ cd synapse The homeserver has a number of external dependencies, that are easiest -to install by making setup.py do so, in --user mode:: +to install using pip and a virtualenv:: - $ python setup.py develop --user + $ virtualenv env + $ source env/bin/activate + $ python synapse/dependencies | xargs -i pip install + $ pip install setuptools_trial mock -This will run a process of downloading and installing into your -user's .local/lib directory all of the required dependencies that are -missing. +This will run a process of downloading and installing all the needed +dependencies into a virtual env. Once this is done, you may wish to run the homeserver's unit tests, to check that everything is installed as it should be:: @@ -252,7 +247,7 @@ IMPORTANT: Before upgrading an existing homeserver to a new version, please refer to UPGRADE.rst for any additional instructions. Otherwise, simply re-install the new codebase over the current one - e.g. -by ``pip install --user --process-dependency-links +by ``pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master`` if using pip, or by ``git pull`` if running off a git working copy. @@ -279,9 +274,9 @@ For the first form, simply pass the required hostname (of the machine) as the $ python -m synapse.app.homeserver \ --server-name machine.my.domain.name \ - --config-path homeserver.config \ + --config-path homeserver.yaml \ --generate-config - $ python -m synapse.app.homeserver --config-path homeserver.config + $ python -m synapse.app.homeserver --config-path homeserver.yaml Alternatively, you can run ``synctl start`` to guide you through the process. @@ -301,9 +296,9 @@ SRV record, as that is the name other machines will expect it to have:: $ python -m synapse.app.homeserver \ --server-name YOURDOMAIN \ --bind-port 8448 \ - --config-path homeserver.config \ + --config-path homeserver.yaml \ --generate-config - $ python -m synapse.app.homeserver --config-path homeserver.config + $ python -m synapse.app.homeserver --config-path homeserver.yaml You may additionally want to pass one or more "-v" options, in order to diff --git a/contrib/jitsimeetbridge/jitsimeetbridge.py b/contrib/jitsimeetbridge/jitsimeetbridge.py index dbc6f6ffa5..15f8e1c48b 100644 --- a/contrib/jitsimeetbridge/jitsimeetbridge.py +++ b/contrib/jitsimeetbridge/jitsimeetbridge.py @@ -39,43 +39,43 @@ ROOMDOMAIN="meet.jit.si" #ROOMDOMAIN="conference.jitsi.vuc.me" class TrivialMatrixClient: - def __init__(self, access_token): - self.token = None - self.access_token = access_token + def __init__(self, access_token): + self.token = None + self.access_token = access_token - def getEvent(self): - while True: - url = MATRIXBASE+'events?access_token='+self.access_token+"&timeout=60000" - if self.token: - url += "&from="+self.token - req = grequests.get(url) - resps = grequests.map([req]) - obj = json.loads(resps[0].content) - print "incoming from matrix",obj - if 'end' not in obj: - continue - self.token = obj['end'] - if len(obj['chunk']): - return obj['chunk'][0] + def getEvent(self): + while True: + url = MATRIXBASE+'events?access_token='+self.access_token+"&timeout=60000" + if self.token: + url += "&from="+self.token + req = grequests.get(url) + resps = grequests.map([req]) + obj = json.loads(resps[0].content) + print "incoming from matrix",obj + if 'end' not in obj: + continue + self.token = obj['end'] + if len(obj['chunk']): + return obj['chunk'][0] - def joinRoom(self, roomId): - url = MATRIXBASE+'rooms/'+roomId+'/join?access_token='+self.access_token - print url - headers={ 'Content-Type': 'application/json' } - req = grequests.post(url, headers=headers, data='{}') - resps = grequests.map([req]) - obj = json.loads(resps[0].content) - print "response: ",obj + def joinRoom(self, roomId): + url = MATRIXBASE+'rooms/'+roomId+'/join?access_token='+self.access_token + print url + headers={ 'Content-Type': 'application/json' } + req = grequests.post(url, headers=headers, data='{}') + resps = grequests.map([req]) + obj = json.loads(resps[0].content) + print "response: ",obj - def sendEvent(self, roomId, evType, event): - url = MATRIXBASE+'rooms/'+roomId+'/send/'+evType+'?access_token='+self.access_token - print url - print json.dumps(event) - headers={ 'Content-Type': 'application/json' } - req = grequests.post(url, headers=headers, data=json.dumps(event)) - resps = grequests.map([req]) - obj = json.loads(resps[0].content) - print "response: ",obj + def sendEvent(self, roomId, evType, event): + url = MATRIXBASE+'rooms/'+roomId+'/send/'+evType+'?access_token='+self.access_token + print url + print json.dumps(event) + headers={ 'Content-Type': 'application/json' } + req = grequests.post(url, headers=headers, data=json.dumps(event)) + resps = grequests.map([req]) + obj = json.loads(resps[0].content) + print "response: ",obj @@ -83,178 +83,178 @@ xmppClients = {} def matrixLoop(): - while True: - ev = matrixCli.getEvent() - print ev - if ev['type'] == 'm.room.member': - print 'membership event' - if ev['membership'] == 'invite' and ev['state_key'] == MYUSERNAME: - roomId = ev['room_id'] - print "joining room %s" % (roomId) - matrixCli.joinRoom(roomId) - elif ev['type'] == 'm.room.message': - if ev['room_id'] in xmppClients: - print "already have a bridge for that user, ignoring" - continue - print "got message, connecting" - xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) - gevent.spawn(xmppClients[ev['room_id']].xmppLoop) - elif ev['type'] == 'm.call.invite': - print "Incoming call" - #sdp = ev['content']['offer']['sdp'] - #print "sdp: %s" % (sdp) - #xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) - #gevent.spawn(xmppClients[ev['room_id']].xmppLoop) - elif ev['type'] == 'm.call.answer': - print "Call answered" - sdp = ev['content']['answer']['sdp'] - if ev['room_id'] not in xmppClients: - print "We didn't have a call for that room" - continue - # should probably check call ID too - xmppCli = xmppClients[ev['room_id']] - xmppCli.sendAnswer(sdp) - elif ev['type'] == 'm.call.hangup': - if ev['room_id'] in xmppClients: - xmppClients[ev['room_id']].stop() - del xmppClients[ev['room_id']] - + while True: + ev = matrixCli.getEvent() + print ev + if ev['type'] == 'm.room.member': + print 'membership event' + if ev['membership'] == 'invite' and ev['state_key'] == MYUSERNAME: + roomId = ev['room_id'] + print "joining room %s" % (roomId) + matrixCli.joinRoom(roomId) + elif ev['type'] == 'm.room.message': + if ev['room_id'] in xmppClients: + print "already have a bridge for that user, ignoring" + continue + print "got message, connecting" + xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) + gevent.spawn(xmppClients[ev['room_id']].xmppLoop) + elif ev['type'] == 'm.call.invite': + print "Incoming call" + #sdp = ev['content']['offer']['sdp'] + #print "sdp: %s" % (sdp) + #xmppClients[ev['room_id']] = TrivialXmppClient(ev['room_id'], ev['user_id']) + #gevent.spawn(xmppClients[ev['room_id']].xmppLoop) + elif ev['type'] == 'm.call.answer': + print "Call answered" + sdp = ev['content']['answer']['sdp'] + if ev['room_id'] not in xmppClients: + print "We didn't have a call for that room" + continue + # should probably check call ID too + xmppCli = xmppClients[ev['room_id']] + xmppCli.sendAnswer(sdp) + elif ev['type'] == 'm.call.hangup': + if ev['room_id'] in xmppClients: + xmppClients[ev['room_id']].stop() + del xmppClients[ev['room_id']] + class TrivialXmppClient: - def __init__(self, matrixRoom, userId): - self.rid = 0 - self.matrixRoom = matrixRoom - self.userId = userId - self.running = True + def __init__(self, matrixRoom, userId): + self.rid = 0 + self.matrixRoom = matrixRoom + self.userId = userId + self.running = True - def stop(self): - self.running = False + def stop(self): + self.running = False - def nextRid(self): - self.rid += 1 - return '%d' % (self.rid) + def nextRid(self): + self.rid += 1 + return '%d' % (self.rid) - def sendIq(self, xml): - fullXml = "%s" % (self.nextRid(), self.sid, xml) - #print "\t>>>%s" % (fullXml) - return self.xmppPoke(fullXml) - - def xmppPoke(self, xml): - headers = {'Content-Type': 'application/xml'} - req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml) - resps = grequests.map([req]) - obj = BeautifulSoup(resps[0].content) - return obj + def sendIq(self, xml): + fullXml = "%s" % (self.nextRid(), self.sid, xml) + #print "\t>>>%s" % (fullXml) + return self.xmppPoke(fullXml) - def sendAnswer(self, answer): - print "sdp from matrix client",answer - p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--sdp'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) - jingle, out_err = p.communicate(answer) - jingle = jingle % { - 'tojid': self.callfrom, - 'action': 'session-accept', - 'initiator': self.callfrom, - 'responder': self.jid, - 'sid': self.callsid - } - print "answer jingle from sdp",jingle - res = self.sendIq(jingle) - print "reply from answer: ",res - - self.ssrcs = {} - jingleSoup = BeautifulSoup(jingle) - for cont in jingleSoup.iq.jingle.findAll('content'): - if cont.description: - self.ssrcs[cont['name']] = cont.description['ssrc'] - print "my ssrcs:",self.ssrcs + def xmppPoke(self, xml): + headers = {'Content-Type': 'application/xml'} + req = grequests.post(HTTPBIND, verify=False, headers=headers, data=xml) + resps = grequests.map([req]) + obj = BeautifulSoup(resps[0].content) + return obj - gevent.joinall([ - gevent.spawn(self.advertiseSsrcs) - ]) - - def advertiseSsrcs(self): + def sendAnswer(self, answer): + print "sdp from matrix client",answer + p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--sdp'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) + jingle, out_err = p.communicate(answer) + jingle = jingle % { + 'tojid': self.callfrom, + 'action': 'session-accept', + 'initiator': self.callfrom, + 'responder': self.jid, + 'sid': self.callsid + } + print "answer jingle from sdp",jingle + res = self.sendIq(jingle) + print "reply from answer: ",res + + self.ssrcs = {} + jingleSoup = BeautifulSoup(jingle) + for cont in jingleSoup.iq.jingle.findAll('content'): + if cont.description: + self.ssrcs[cont['name']] = cont.description['ssrc'] + print "my ssrcs:",self.ssrcs + + gevent.joinall([ + gevent.spawn(self.advertiseSsrcs) + ]) + + def advertiseSsrcs(self): time.sleep(7) - print "SSRC spammer started" - while self.running: - ssrcMsg = "%(nick)s" % { 'tojid': "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), 'nick': self.userId, 'assrc': self.ssrcs['audio'], 'vssrc': self.ssrcs['video'] } - res = self.sendIq(ssrcMsg) - print "reply from ssrc announce: ",res - time.sleep(10) - - + print "SSRC spammer started" + while self.running: + ssrcMsg = "%(nick)s" % { 'tojid': "%s@%s/%s" % (ROOMNAME, ROOMDOMAIN, self.shortJid), 'nick': self.userId, 'assrc': self.ssrcs['audio'], 'vssrc': self.ssrcs['video'] } + res = self.sendIq(ssrcMsg) + print "reply from ssrc announce: ",res + time.sleep(10) - def xmppLoop(self): - self.matrixCallId = time.time() - res = self.xmppPoke("" % (self.nextRid(), HOST)) - print res - self.sid = res.body['sid'] - print "sid %s" % (self.sid) - res = self.sendIq("") + def xmppLoop(self): + self.matrixCallId = time.time() + res = self.xmppPoke("" % (self.nextRid(), HOST)) - res = self.xmppPoke("" % (self.nextRid(), self.sid, HOST)) - - res = self.sendIq("") - print res + print res + self.sid = res.body['sid'] + print "sid %s" % (self.sid) - self.jid = res.body.iq.bind.jid.string - print "jid: %s" % (self.jid) - self.shortJid = self.jid.split('-')[0] + res = self.sendIq("") - res = self.sendIq("") + res = self.xmppPoke("" % (self.nextRid(), self.sid, HOST)) - #randomthing = res.body.iq['to'] - #whatsitpart = randomthing.split('-')[0] + res = self.sendIq("") + print res - #print "other random bind thing: %s" % (randomthing) + self.jid = res.body.iq.bind.jid.string + print "jid: %s" % (self.jid) + self.shortJid = self.jid.split('-')[0] - # advertise preence to the jitsi room, with our nick - res = self.sendIq("%s" % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId)) - self.muc = {'users': []} - for p in res.body.findAll('presence'): - u = {} - u['shortJid'] = p['from'].split('/')[1] - if p.c and p.c.nick: - u['nick'] = p.c.nick.string - self.muc['users'].append(u) - print "muc: ",self.muc + res = self.sendIq("") - # wait for stuff - while True: - print "waiting..." - res = self.sendIq("") - print "got from stream: ",res - if res.body.iq: - jingles = res.body.iq.findAll('jingle') - if len(jingles): - self.callfrom = res.body.iq['from'] - self.handleInvite(jingles[0]) - elif 'type' in res.body and res.body['type'] == 'terminate': - self.running = False - del xmppClients[self.matrixRoom] - return + #randomthing = res.body.iq['to'] + #whatsitpart = randomthing.split('-')[0] + + #print "other random bind thing: %s" % (randomthing) + + # advertise preence to the jitsi room, with our nick + res = self.sendIq("%s" % (HOST, TURNSERVER, ROOMNAME, ROOMDOMAIN, self.userId)) + self.muc = {'users': []} + for p in res.body.findAll('presence'): + u = {} + u['shortJid'] = p['from'].split('/')[1] + if p.c and p.c.nick: + u['nick'] = p.c.nick.string + self.muc['users'].append(u) + print "muc: ",self.muc + + # wait for stuff + while True: + print "waiting..." + res = self.sendIq("") + print "got from stream: ",res + if res.body.iq: + jingles = res.body.iq.findAll('jingle') + if len(jingles): + self.callfrom = res.body.iq['from'] + self.handleInvite(jingles[0]) + elif 'type' in res.body and res.body['type'] == 'terminate': + self.running = False + del xmppClients[self.matrixRoom] + return + + def handleInvite(self, jingle): + self.initiator = jingle['initiator'] + self.callsid = jingle['sid'] + p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--jingle'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) + print "raw jingle invite",str(jingle) + sdp, out_err = p.communicate(str(jingle)) + print "transformed remote offer sdp",sdp + inviteEvent = { + 'offer': { + 'type': 'offer', + 'sdp': sdp + }, + 'call_id': self.matrixCallId, + 'version': 0, + 'lifetime': 30000 + } + matrixCli.sendEvent(self.matrixRoom, 'm.call.invite', inviteEvent) - def handleInvite(self, jingle): - self.initiator = jingle['initiator'] - self.callsid = jingle['sid'] - p = subprocess.Popen(['node', 'unjingle/unjingle.js', '--jingle'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) - print "raw jingle invite",str(jingle) - sdp, out_err = p.communicate(str(jingle)) - print "transformed remote offer sdp",sdp - inviteEvent = { - 'offer': { - 'type': 'offer', - 'sdp': sdp - }, - 'call_id': self.matrixCallId, - 'version': 0, - 'lifetime': 30000 - } - matrixCli.sendEvent(self.matrixRoom, 'm.call.invite', inviteEvent) - matrixCli = TrivialMatrixClient(ACCESS_TOKEN) gevent.joinall([ - gevent.spawn(matrixLoop) + gevent.spawn(matrixLoop) ]) diff --git a/setup.py b/setup.py index 043cd044a7..eb1d17c25f 100755 --- a/setup.py +++ b/setup.py @@ -32,8 +32,8 @@ setup( description="Reference Synapse Home Server", install_requires=[ "syutil==0.0.2", - "matrix_angular_sdk==0.6.0", - "Twisted>=14.0.0", + "matrix_angular_sdk>=0.6.0", + "Twisted==14.0.2", "service_identity>=1.0.0", "pyopenssl>=0.14", "pyyaml", @@ -50,6 +50,7 @@ setup( "https://github.com/matrix-org/matrix-angular-sdk/tarball/v0.6.0/#egg=matrix_angular_sdk-0.6.0", ], setup_requires=[ + "Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0 "setuptools_trial", "setuptools>=1.0.0", # Needs setuptools that supports git+ssh. # TODO: Do we need this now? we don't use git+ssh. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index a342a0e0da..37e31d2b6f 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership, JoinRules from synapse.api.errors import AuthError, StoreError, Codes, SynapseError from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor -from synapse.types import UserID +from synapse.types import UserID, ClientInfo import logging @@ -102,6 +102,8 @@ class Auth(object): def check_host_in_room(self, room_id, host): curr_state = yield self.state.get_current_state(room_id) + logger.debug("Got curr_state %s", curr_state) + for event in curr_state: if event.type == EventTypes.Member: try: @@ -290,7 +292,9 @@ class Auth(object): Args: request - An HTTP request with an access_token query parameter. Returns: - UserID : User ID object of the user making the request + tuple : of UserID and device string: + User ID object of the user making the request + Client ID object of the client instance the user is using Raises: AuthError if no user by that token exists or the token is invalid. """ @@ -299,6 +303,8 @@ class Auth(object): access_token = request.args["access_token"][0] user_info = yield self.get_user_by_token(access_token) user = user_info["user"] + device_id = user_info["device_id"] + token_id = user_info["token_id"] ip_addr = self.hs.get_ip_from_request(request) user_agent = request.requestHeaders.getRawHeaders( @@ -314,7 +320,7 @@ class Auth(object): user_agent=user_agent ) - defer.returnValue(user) + defer.returnValue((user, ClientInfo(device_id, token_id))) except KeyError: raise AuthError(403, "Missing access token.") @@ -339,6 +345,7 @@ class Auth(object): "admin": bool(ret.get("admin", False)), "device_id": ret.get("device_id"), "user": UserID.from_string(ret.get("name")), + "token_id": ret.get("token_id", None), } defer.returnValue(user_info) @@ -353,9 +360,23 @@ class Auth(object): def add_auth_events(self, builder, context): yield run_on_reactor() - if builder.type == EventTypes.Create: - builder.auth_events = [] - return + auth_ids = self.compute_auth_events(builder, context) + + auth_events_entries = yield self.store.add_event_hashes( + auth_ids + ) + + builder.auth_events = auth_events_entries + + context.auth_events = { + k: v + for k, v in context.current_state.items() + if v.event_id in auth_ids + } + + def compute_auth_events(self, event, context): + if event.type == EventTypes.Create: + return [] auth_ids = [] @@ -368,7 +389,7 @@ class Auth(object): key = (EventTypes.JoinRules, "", ) join_rule_event = context.current_state.get(key) - key = (EventTypes.Member, builder.user_id, ) + key = (EventTypes.Member, event.user_id, ) member_event = context.current_state.get(key) key = (EventTypes.Create, "", ) @@ -382,8 +403,8 @@ class Auth(object): else: is_public = False - if builder.type == EventTypes.Member: - e_type = builder.content["membership"] + if event.type == EventTypes.Member: + e_type = event.content["membership"] if e_type in [Membership.JOIN, Membership.INVITE]: if join_rule_event: auth_ids.append(join_rule_event.event_id) @@ -398,17 +419,7 @@ class Auth(object): if member_event.content["membership"] == Membership.JOIN: auth_ids.append(member_event.event_id) - auth_events_entries = yield self.store.add_event_hashes( - auth_ids - ) - - builder.auth_events = auth_events_entries - - context.auth_events = { - k: v - for k, v in context.current_state.items() - if v.event_id in auth_ids - } + return auth_ids @log_function def _can_send_event(self, event, auth_events): diff --git a/synapse/api/constants.py b/synapse/api/constants.py index 7ee6dcc46e..0d3fc629af 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -74,3 +74,9 @@ class EventTypes(object): Message = "m.room.message" Topic = "m.room.topic" Name = "m.room.name" + + +class RejectedReason(object): + AUTH_ERROR = "auth_error" + REPLACED = "replaced" + NOT_ANCESTOR = "not_ancestor" diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 2b049debf3..ad478aa6b7 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -21,6 +21,7 @@ logger = logging.getLogger(__name__) class Codes(object): + UNRECOGNIZED = "M_UNRECOGNIZED" UNAUTHORIZED = "M_UNAUTHORIZED" FORBIDDEN = "M_FORBIDDEN" BAD_JSON = "M_BAD_JSON" @@ -34,6 +35,7 @@ class Codes(object): LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED" CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED" CAPTCHA_INVALID = "M_CAPTCHA_INVALID" + MISSING_PARAM = "M_MISSING_PARAM", TOO_LARGE = "M_TOO_LARGE" @@ -81,6 +83,35 @@ class RegistrationError(SynapseError): pass +class UnrecognizedRequestError(SynapseError): + """An error indicating we don't understand the request you're trying to make""" + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.UNRECOGNIZED + message = None + if len(args) == 0: + message = "Unrecognized request" + else: + message = args[0] + super(UnrecognizedRequestError, self).__init__( + 400, + message, + **kwargs + ) + + +class NotFoundError(SynapseError): + """An error indicating we can't find the thing you asked for""" + def __init__(self, *args, **kwargs): + if "errcode" not in kwargs: + kwargs["errcode"] = Codes.NOT_FOUND + super(NotFoundError, self).__init__( + 404, + "Not found", + **kwargs + ) + + class AuthError(SynapseError): """An error raised when there was a problem authorising an event.""" diff --git a/synapse/api/filtering.py b/synapse/api/filtering.py new file mode 100644 index 0000000000..4d570b74f8 --- /dev/null +++ b/synapse/api/filtering.py @@ -0,0 +1,229 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from synapse.api.errors import SynapseError +from synapse.types import UserID, RoomID + + +class Filtering(object): + + def __init__(self, hs): + super(Filtering, self).__init__() + self.store = hs.get_datastore() + + def get_user_filter(self, user_localpart, filter_id): + result = self.store.get_user_filter(user_localpart, filter_id) + result.addCallback(Filter) + return result + + def add_user_filter(self, user_localpart, user_filter): + self._check_valid_filter(user_filter) + return self.store.add_user_filter(user_localpart, user_filter) + + # TODO(paul): surely we should probably add a delete_user_filter or + # replace_user_filter at some point? There's no REST API specified for + # them however + + def _check_valid_filter(self, user_filter_json): + """Check if the provided filter is valid. + + This inspects all definitions contained within the filter. + + Args: + user_filter_json(dict): The filter + Raises: + SynapseError: If the filter is not valid. + """ + # NB: Filters are the complete json blobs. "Definitions" are an + # individual top-level key e.g. public_user_data. Filters are made of + # many definitions. + + top_level_definitions = [ + "public_user_data", "private_user_data", "server_data" + ] + + room_level_definitions = [ + "state", "events", "ephemeral" + ] + + for key in top_level_definitions: + if key in user_filter_json: + self._check_definition(user_filter_json[key]) + + if "room" in user_filter_json: + for key in room_level_definitions: + if key in user_filter_json["room"]: + self._check_definition(user_filter_json["room"][key]) + + def _check_definition(self, definition): + """Check if the provided definition is valid. + + This inspects not only the types but also the values to make sure they + make sense. + + Args: + definition(dict): The filter definition + Raises: + SynapseError: If there was a problem with this definition. + """ + # NB: Filters are the complete json blobs. "Definitions" are an + # individual top-level key e.g. public_user_data. Filters are made of + # many definitions. + if type(definition) != dict: + raise SynapseError( + 400, "Expected JSON object, not %s" % (definition,) + ) + + # check rooms are valid room IDs + room_id_keys = ["rooms", "not_rooms"] + for key in room_id_keys: + if key in definition: + if type(definition[key]) != list: + raise SynapseError(400, "Expected %s to be a list." % key) + for room_id in definition[key]: + RoomID.from_string(room_id) + + # check senders are valid user IDs + user_id_keys = ["senders", "not_senders"] + for key in user_id_keys: + if key in definition: + if type(definition[key]) != list: + raise SynapseError(400, "Expected %s to be a list." % key) + for user_id in definition[key]: + UserID.from_string(user_id) + + # TODO: We don't limit event type values but we probably should... + # check types are valid event types + event_keys = ["types", "not_types"] + for key in event_keys: + if key in definition: + if type(definition[key]) != list: + raise SynapseError(400, "Expected %s to be a list." % key) + for event_type in definition[key]: + if not isinstance(event_type, basestring): + raise SynapseError(400, "Event type should be a string") + + if "format" in definition: + event_format = definition["format"] + if event_format not in ["federation", "events"]: + raise SynapseError(400, "Invalid format: %s" % (event_format,)) + + if "select" in definition: + event_select_list = definition["select"] + for select_key in event_select_list: + if select_key not in ["event_id", "origin_server_ts", + "thread_id", "content", "content.body"]: + raise SynapseError(400, "Bad select: %s" % (select_key,)) + + if ("bundle_updates" in definition and + type(definition["bundle_updates"]) != bool): + raise SynapseError(400, "Bad bundle_updates: expected bool.") + + +class Filter(object): + def __init__(self, filter_json): + self.filter_json = filter_json + + def filter_public_user_data(self, events): + return self._filter_on_key(events, ["public_user_data"]) + + def filter_private_user_data(self, events): + return self._filter_on_key(events, ["private_user_data"]) + + def filter_room_state(self, events): + return self._filter_on_key(events, ["room", "state"]) + + def filter_room_events(self, events): + return self._filter_on_key(events, ["room", "events"]) + + def filter_room_ephemeral(self, events): + return self._filter_on_key(events, ["room", "ephemeral"]) + + def _filter_on_key(self, events, keys): + filter_json = self.filter_json + if not filter_json: + return events + + try: + # extract the right definition from the filter + definition = filter_json + for key in keys: + definition = definition[key] + return self._filter_with_definition(events, definition) + except KeyError: + # return all events if definition isn't specified. + return events + + def _filter_with_definition(self, events, definition): + return [e for e in events if self._passes_definition(definition, e)] + + def _passes_definition(self, definition, event): + """Check if the event passes through the given definition. + + Args: + definition(dict): The definition to check against. + event(Event): The event to check. + Returns: + True if the event passes through the filter. + """ + # Algorithm notes: + # For each key in the definition, check the event meets the criteria: + # * For types: Literal match or prefix match (if ends with wildcard) + # * For senders/rooms: Literal match only + # * "not_" checks take presedence (e.g. if "m.*" is in both 'types' + # and 'not_types' then it is treated as only being in 'not_types') + + # room checks + if hasattr(event, "room_id"): + room_id = event.room_id + allow_rooms = definition.get("rooms", None) + reject_rooms = definition.get("not_rooms", None) + if reject_rooms and room_id in reject_rooms: + return False + if allow_rooms and room_id not in allow_rooms: + return False + + # sender checks + if hasattr(event, "sender"): + # Should we be including event.state_key for some event types? + sender = event.sender + allow_senders = definition.get("senders", None) + reject_senders = definition.get("not_senders", None) + if reject_senders and sender in reject_senders: + return False + if allow_senders and sender not in allow_senders: + return False + + # type checks + if "not_types" in definition: + for def_type in definition["not_types"]: + if self._event_matches_type(event, def_type): + return False + if "types" in definition: + included = False + for def_type in definition["types"]: + if self._event_matches_type(event, def_type): + included = True + break + if not included: + return False + + return True + + def _event_matches_type(self, event, def_type): + if def_type.endswith("*"): + type_prefix = def_type[:-1] + return event.type.startswith(type_prefix) + else: + return event.type == def_type diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index f7e24d0cc6..a9397de5b2 100755 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -277,6 +277,8 @@ def setup(): bind_port = None hs.start_listening(bind_port, config.unsecure_port) + hs.get_pusherpool().start() + if config.daemonize: print config.pid_file daemon = Daemonize( diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py index 9c910fa3fc..cdb6279764 100644 --- a/synapse/crypto/keyclient.py +++ b/synapse/crypto/keyclient.py @@ -61,9 +61,11 @@ class SynapseKeyClientProtocol(HTTPClient): def __init__(self): self.remote_key = defer.Deferred() + self.host = None def connectionMade(self): - logger.debug("Connected to %s", self.transport.getHost()) + self.host = self.transport.getHost() + logger.debug("Connected to %s", self.host) self.sendCommand(b"GET", b"/_matrix/key/v1/") self.endHeaders() self.timer = reactor.callLater( @@ -92,8 +94,7 @@ class SynapseKeyClientProtocol(HTTPClient): self.timer.cancel() def on_timeout(self): - logger.debug("Timeout waiting for response from %s", - self.transport.getHost()) + logger.debug("Timeout waiting for response from %s", self.host) self.remote_key.errback(IOError("Timeout waiting for response")) self.transport.abortConnection() diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py index 4252e5ab5c..bf07951027 100644 --- a/synapse/events/__init__.py +++ b/synapse/events/__init__.py @@ -18,7 +18,7 @@ from synapse.util.frozenutils import freeze, unfreeze class _EventInternalMetadata(object): def __init__(self, internal_metadata_dict): - self.__dict__ = internal_metadata_dict + self.__dict__ = dict(internal_metadata_dict) def get_dict(self): return dict(self.__dict__) diff --git a/synapse/events/builder.py b/synapse/events/builder.py index a9b1b99a10..9d45bdb892 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -23,14 +23,15 @@ import copy class EventBuilder(EventBase): - def __init__(self, key_values={}): + def __init__(self, key_values={}, internal_metadata_dict={}): signatures = copy.deepcopy(key_values.pop("signatures", {})) unsigned = copy.deepcopy(key_values.pop("unsigned", {})) super(EventBuilder, self).__init__( key_values, signatures=signatures, - unsigned=unsigned + unsigned=unsigned, + internal_metadata_dict=internal_metadata_dict, ) def build(self): diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py index 6bbba8d6ba..7e98bdef28 100644 --- a/synapse/events/snapshot.py +++ b/synapse/events/snapshot.py @@ -20,3 +20,4 @@ class EventContext(object): self.current_state = current_state self.auth_events = auth_events self.state_group = None + self.rejected = False diff --git a/synapse/events/utils.py b/synapse/events/utils.py index e391aca4cc..1aa952150e 100644 --- a/synapse/events/utils.py +++ b/synapse/events/utils.py @@ -45,12 +45,14 @@ def prune_event(event): "membership", ] + event_dict = event.get_dict() + new_content = {} def add_fields(*fields): for field in fields: if field in event.content: - new_content[field] = event.content[field] + new_content[field] = event_dict["content"][field] if event_type == EventTypes.Member: add_fields("membership") @@ -75,7 +77,7 @@ def prune_event(event): allowed_fields = { k: v - for k, v in event.get_dict().items() + for k, v in event_dict.items() if k in allowed_keys } @@ -86,10 +88,53 @@ def prune_event(event): if "age_ts" in event.unsigned: allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"] - return type(event)(allowed_fields) + return type(event)( + allowed_fields, + internal_metadata_dict=event.internal_metadata.get_dict() + ) -def serialize_event(e, time_now_ms, client_event=True): +def format_event_raw(d): + return d + + +def format_event_for_client_v1(d): + d["user_id"] = d.pop("sender", None) + + move_keys = ("age", "redacted_because", "replaces_state", "prev_content") + for key in move_keys: + if key in d["unsigned"]: + d[key] = d["unsigned"][key] + + drop_keys = ( + "auth_events", "prev_events", "hashes", "signatures", "depth", + "unsigned", "origin", "prev_state" + ) + for key in drop_keys: + d.pop(key, None) + return d + + +def format_event_for_client_v2(d): + drop_keys = ( + "auth_events", "prev_events", "hashes", "signatures", "depth", + "origin", "prev_state", + ) + for key in drop_keys: + d.pop(key, None) + return d + + +def format_event_for_client_v2_without_event_id(d): + d = format_event_for_client_v2(d) + d.pop("room_id", None) + d.pop("event_id", None) + return d + + +def serialize_event(e, time_now_ms, as_client_event=True, + event_format=format_event_for_client_v1, + token_id=None): # FIXME(erikj): To handle the case of presence events and the like if not isinstance(e, EventBase): return e @@ -99,43 +144,22 @@ def serialize_event(e, time_now_ms, client_event=True): # Should this strip out None's? d = {k: v for k, v in e.get_dict().items()} - if not client_event: - # set the age and keep all other keys - if "age_ts" in d["unsigned"]: - d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"] - return d - if "age_ts" in d["unsigned"]: - d["age"] = time_now_ms - d["unsigned"]["age_ts"] + d["unsigned"]["age"] = time_now_ms - d["unsigned"]["age_ts"] del d["unsigned"]["age_ts"] - d["user_id"] = d.pop("sender", None) - if "redacted_because" in e.unsigned: - d["redacted_because"] = serialize_event( + d["unsigned"]["redacted_because"] = serialize_event( e.unsigned["redacted_because"], time_now_ms ) - del d["unsigned"]["redacted_because"] + if token_id is not None: + if token_id == getattr(e.internal_metadata, "token_id", None): + txn_id = getattr(e.internal_metadata, "txn_id", None) + if txn_id is not None: + d["unsigned"]["transaction_id"] = txn_id - if "redacted_by" in e.unsigned: - d["redacted_by"] = e.unsigned["redacted_by"] - del d["unsigned"]["redacted_by"] - - if "replaces_state" in e.unsigned: - d["replaces_state"] = e.unsigned["replaces_state"] - del d["unsigned"]["replaces_state"] - - if "prev_content" in e.unsigned: - d["prev_content"] = e.unsigned["prev_content"] - del d["unsigned"]["prev_content"] - - del d["auth_events"] - del d["prev_events"] - del d["hashes"] - del d["signatures"] - d.pop("depth", None) - d.pop("unsigned", None) - d.pop("origin", None) - - return d + if as_client_event: + return event_format(d) + else: + return d diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py new file mode 100644 index 0000000000..e1539bd0e0 --- /dev/null +++ b/synapse/federation/federation_client.py @@ -0,0 +1,409 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from .units import Edu + +from synapse.util.logutils import log_function +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from syutil.jsonutil import encode_canonical_json + +from synapse.crypto.event_signing import check_event_content_hash + +from synapse.api.errors import SynapseError + +import logging + + +logger = logging.getLogger(__name__) + + +class FederationClient(object): + @log_function + def send_pdu(self, pdu, destinations): + """Informs the replication layer about a new PDU generated within the + home server that should be transmitted to others. + + TODO: Figure out when we should actually resolve the deferred. + + Args: + pdu (Pdu): The new Pdu. + + Returns: + Deferred: Completes when we have successfully processed the PDU + and replicated it to any interested remote home servers. + """ + order = self._order + self._order += 1 + + logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) + + # TODO, add errback, etc. + self._transaction_queue.enqueue_pdu(pdu, destinations, order) + + logger.debug( + "[%s] transaction_layer.enqueue_pdu... done", + pdu.event_id + ) + + @log_function + def send_edu(self, destination, edu_type, content): + edu = Edu( + origin=self.server_name, + destination=destination, + edu_type=edu_type, + content=content, + ) + + # TODO, add errback, etc. + self._transaction_queue.enqueue_edu(edu) + return defer.succeed(None) + + @log_function + def send_failure(self, failure, destination): + self._transaction_queue.enqueue_failure(failure, destination) + return defer.succeed(None) + + @log_function + def make_query(self, destination, query_type, args, + retry_on_dns_fail=True): + """Sends a federation Query to a remote homeserver of the given type + and arguments. + + Args: + destination (str): Domain name of the remote homeserver + query_type (str): Category of the query type; should match the + handler name used in register_query_handler(). + args (dict): Mapping of strings to strings containing the details + of the query request. + + Returns: + a Deferred which will eventually yield a JSON object from the + response + """ + return self.transport_layer.make_query( + destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail + ) + + @defer.inlineCallbacks + @log_function + def backfill(self, dest, context, limit, extremities): + """Requests some more historic PDUs for the given context from the + given destination server. + + Args: + dest (str): The remote home server to ask. + context (str): The context to backfill. + limit (int): The maximum number of PDUs to return. + extremities (list): List of PDU id and origins of the first pdus + we have seen from the context + + Returns: + Deferred: Results in the received PDUs. + """ + logger.debug("backfill extrem=%s", extremities) + + # If there are no extremeties then we've (probably) reached the start. + if not extremities: + return + + transaction_data = yield self.transport_layer.backfill( + dest, context, extremities, limit) + + logger.debug("backfill transaction_data=%s", repr(transaction_data)) + + pdus = [ + self.event_from_pdu_json(p, outlier=False) + for p in transaction_data["pdus"] + ] + + for i, pdu in enumerate(pdus): + pdus[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + defer.returnValue(pdus) + + @defer.inlineCallbacks + @log_function + def get_pdu(self, destinations, event_id, outlier=False): + """Requests the PDU with given origin and ID from the remote home + servers. + + Will attempt to get the PDU from each destination in the list until + one succeeds. + + This will persist the PDU locally upon receipt. + + Args: + destinations (list): Which home servers to query + pdu_origin (str): The home server that originally sent the pdu. + event_id (str) + outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if + it's from an arbitary point in the context as opposed to part + of the current block of PDUs. Defaults to `False` + + Returns: + Deferred: Results in the requested PDU. + """ + + # TODO: Rate limit the number of times we try and get the same event. + + pdu = None + for destination in destinations: + try: + transaction_data = yield self.transport_layer.get_event( + destination, event_id + ) + + logger.debug("transaction_data %r", transaction_data) + + pdu_list = [ + self.event_from_pdu_json(p, outlier=outlier) + for p in transaction_data["pdus"] + ] + + if pdu_list: + pdu = pdu_list[0] + + # Check signatures are correct. + pdu = yield self._check_sigs_and_hash(pdu) + + break + + except Exception as e: + logger.info( + "Failed to get PDU %s from %s because %s", + event_id, destination, e, + ) + continue + + defer.returnValue(pdu) + + @defer.inlineCallbacks + @log_function + def get_state_for_room(self, destination, room_id, event_id): + """Requests all of the `current` state PDUs for a given room from + a remote home server. + + Args: + destination (str): The remote homeserver to query for the state. + room_id (str): The id of the room we're interested in. + event_id (str): The id of the event we want the state at. + + Returns: + Deferred: Results in a list of PDUs. + """ + + result = yield self.transport_layer.get_room_state( + destination, room_id, event_id=event_id, + ) + + pdus = [ + self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] + ] + + auth_chain = [ + self.event_from_pdu_json(p, outlier=True) + for p in result.get("auth_chain", []) + ] + + for i, pdu in enumerate(pdus): + pdus[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + for i, pdu in enumerate(auth_chain): + auth_chain[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + defer.returnValue((pdus, auth_chain)) + + @defer.inlineCallbacks + @log_function + def get_event_auth(self, destination, room_id, event_id): + res = yield self.transport_layer.get_event_auth( + destination, room_id, event_id, + ) + + auth_chain = [ + self.event_from_pdu_json(p, outlier=True) + for p in res["auth_chain"] + ] + + for i, pdu in enumerate(auth_chain): + auth_chain[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + auth_chain.sort(key=lambda e: e.depth) + + defer.returnValue(auth_chain) + + @defer.inlineCallbacks + def make_join(self, destination, room_id, user_id): + ret = yield self.transport_layer.make_join( + destination, room_id, user_id + ) + + pdu_dict = ret["event"] + + logger.debug("Got response to make_join: %s", pdu_dict) + + defer.returnValue(self.event_from_pdu_json(pdu_dict)) + + @defer.inlineCallbacks + def send_join(self, destination, pdu): + time_now = self._clock.time_msec() + _, content = yield self.transport_layer.send_join( + destination=destination, + room_id=pdu.room_id, + event_id=pdu.event_id, + content=pdu.get_pdu_json(time_now), + ) + + logger.debug("Got content: %s", content) + + state = [ + self.event_from_pdu_json(p, outlier=True) + for p in content.get("state", []) + ] + + auth_chain = [ + self.event_from_pdu_json(p, outlier=True) + for p in content.get("auth_chain", []) + ] + + for i, pdu in enumerate(state): + state[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + for i, pdu in enumerate(auth_chain): + auth_chain[i] = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + auth_chain.sort(key=lambda e: e.depth) + + defer.returnValue({ + "state": state, + "auth_chain": auth_chain, + }) + + @defer.inlineCallbacks + def send_invite(self, destination, room_id, event_id, pdu): + time_now = self._clock.time_msec() + code, content = yield self.transport_layer.send_invite( + destination=destination, + room_id=room_id, + event_id=event_id, + content=pdu.get_pdu_json(time_now), + ) + + pdu_dict = content["event"] + + logger.debug("Got response to send_invite: %s", pdu_dict) + + pdu = self.event_from_pdu_json(pdu_dict) + + # Check signatures are correct. + pdu = yield self._check_sigs_and_hash(pdu) + + # FIXME: We should handle signature failures more gracefully. + + defer.returnValue(pdu) + + @defer.inlineCallbacks + def query_auth(self, destination, room_id, event_id, local_auth): + """ + Params: + destination (str) + event_it (str) + local_auth (list) + """ + time_now = self._clock.time_msec() + + send_content = { + "auth_chain": [e.get_pdu_json(time_now) for e in local_auth], + } + + code, content = yield self.transport_layer.send_query_auth( + destination=destination, + room_id=room_id, + event_id=event_id, + content=send_content, + ) + + auth_chain = [ + (yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) + for e in content["auth_chain"] + ] + + ret = { + "auth_chain": auth_chain, + "rejects": content.get("rejects", []), + "missing": content.get("missing", []), + } + + defer.returnValue(ret) + + def event_from_pdu_json(self, pdu_json, outlier=False): + event = FrozenEvent( + pdu_json + ) + + event.internal_metadata.outlier = outlier + + return event + + @defer.inlineCallbacks + def _check_sigs_and_hash(self, pdu): + """Throws a SynapseError if the PDU does not have the correct + signatures. + + Returns: + FrozenEvent: Either the given event or it redacted if it failed the + content hash check. + """ + # Check signatures are correct. + redacted_event = prune_event(pdu) + redacted_pdu_json = redacted_event.get_pdu_json() + + try: + yield self.keyring.verify_json_for_server( + pdu.origin, redacted_pdu_json + ) + except SynapseError: + logger.warn( + "Signature check failed for %s redacted to %s", + encode_canonical_json(pdu.get_pdu_json()), + encode_canonical_json(redacted_pdu_json), + ) + raise + + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s, %s", + pdu.event_id, encode_canonical_json(pdu.get_dict()) + ) + defer.returnValue(redacted_event) + + defer.returnValue(pdu) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py new file mode 100644 index 0000000000..5fbd8b19de --- /dev/null +++ b/synapse/federation/federation_server.py @@ -0,0 +1,462 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from .units import Transaction, Edu + +from synapse.util.logutils import log_function +from synapse.util.logcontext import PreserveLoggingContext +from synapse.events import FrozenEvent +from synapse.events.utils import prune_event + +from syutil.jsonutil import encode_canonical_json + +from synapse.crypto.event_signing import check_event_content_hash + +from synapse.api.errors import FederationError, SynapseError + +import logging + + +logger = logging.getLogger(__name__) + + +class FederationServer(object): + def set_handler(self, handler): + """Sets the handler that the replication layer will use to communicate + receipt of new PDUs from other home servers. The required methods are + documented on :py:class:`.ReplicationHandler`. + """ + self.handler = handler + + def register_edu_handler(self, edu_type, handler): + if edu_type in self.edu_handlers: + raise KeyError("Already have an EDU handler for %s" % (edu_type,)) + + self.edu_handlers[edu_type] = handler + + def register_query_handler(self, query_type, handler): + """Sets the handler callable that will be used to handle an incoming + federation Query of the given type. + + Args: + query_type (str): Category name of the query, which should match + the string used by make_query. + handler (callable): Invoked to handle incoming queries of this type + + handler is invoked as: + result = handler(args) + + where 'args' is a dict mapping strings to strings of the query + arguments. It should return a Deferred that will eventually yield an + object to encode as JSON. + """ + if query_type in self.query_handlers: + raise KeyError( + "Already have a Query handler for %s" % (query_type,) + ) + + self.query_handlers[query_type] = handler + + @defer.inlineCallbacks + @log_function + def on_backfill_request(self, origin, room_id, versions, limit): + pdus = yield self.handler.on_backfill_request( + origin, room_id, versions, limit + ) + + defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) + + @defer.inlineCallbacks + @log_function + def on_incoming_transaction(self, transaction_data): + transaction = Transaction(**transaction_data) + + for p in transaction.pdus: + if "unsigned" in p: + unsigned = p["unsigned"] + if "age" in unsigned: + p["age"] = unsigned["age"] + if "age" in p: + p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) + del p["age"] + + pdu_list = [ + self.event_from_pdu_json(p) for p in transaction.pdus + ] + + logger.debug("[%s] Got transaction", transaction.transaction_id) + + response = yield self.transaction_actions.have_responded(transaction) + + if response: + logger.debug( + "[%s] We've already responed to this request", + transaction.transaction_id + ) + defer.returnValue(response) + return + + logger.debug("[%s] Transaction is new", transaction.transaction_id) + + with PreserveLoggingContext(): + dl = [] + for pdu in pdu_list: + dl.append(self._handle_new_pdu(transaction.origin, pdu)) + + if hasattr(transaction, "edus"): + for edu in [Edu(**x) for x in transaction.edus]: + self.received_edu( + transaction.origin, + edu.edu_type, + edu.content + ) + + results = yield defer.DeferredList(dl) + + ret = [] + for r in results: + if r[0]: + ret.append({}) + else: + logger.exception(r[1]) + ret.append({"error": str(r[1])}) + + logger.debug("Returning: %s", str(ret)) + + yield self.transaction_actions.set_response( + transaction, + 200, response + ) + defer.returnValue((200, response)) + + def received_edu(self, origin, edu_type, content): + if edu_type in self.edu_handlers: + self.edu_handlers[edu_type](origin, content) + else: + logger.warn("Received EDU of type %s with no handler", edu_type) + + @defer.inlineCallbacks + @log_function + def on_context_state_request(self, origin, room_id, event_id): + if event_id: + pdus = yield self.handler.get_state_for_pdu( + origin, room_id, event_id, + ) + auth_chain = yield self.store.get_auth_chain( + [pdu.event_id for pdu in pdus] + ) + else: + raise NotImplementedError("Specify an event") + + defer.returnValue((200, { + "pdus": [pdu.get_pdu_json() for pdu in pdus], + "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], + })) + + @defer.inlineCallbacks + @log_function + def on_pdu_request(self, origin, event_id): + pdu = yield self._get_persisted_pdu(origin, event_id) + + if pdu: + defer.returnValue( + (200, self._transaction_from_pdus([pdu]).get_dict()) + ) + else: + defer.returnValue((404, "")) + + @defer.inlineCallbacks + @log_function + def on_pull_request(self, origin, versions): + raise NotImplementedError("Pull transactions not implemented") + + @defer.inlineCallbacks + def on_query_request(self, query_type, args): + if query_type in self.query_handlers: + response = yield self.query_handlers[query_type](args) + defer.returnValue((200, response)) + else: + defer.returnValue( + (404, "No handler for Query type '%s'" % (query_type,)) + ) + + @defer.inlineCallbacks + def on_make_join_request(self, room_id, user_id): + pdu = yield self.handler.on_make_join_request(room_id, user_id) + time_now = self._clock.time_msec() + defer.returnValue({"event": pdu.get_pdu_json(time_now)}) + + @defer.inlineCallbacks + def on_invite_request(self, origin, content): + pdu = self.event_from_pdu_json(content) + ret_pdu = yield self.handler.on_invite_request(origin, pdu) + time_now = self._clock.time_msec() + defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) + + @defer.inlineCallbacks + def on_send_join_request(self, origin, content): + logger.debug("on_send_join_request: content: %s", content) + pdu = self.event_from_pdu_json(content) + logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) + res_pdus = yield self.handler.on_send_join_request(origin, pdu) + time_now = self._clock.time_msec() + defer.returnValue((200, { + "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], + "auth_chain": [ + p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] + ], + })) + + @defer.inlineCallbacks + def on_event_auth(self, origin, room_id, event_id): + time_now = self._clock.time_msec() + auth_pdus = yield self.handler.on_event_auth(event_id) + defer.returnValue((200, { + "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], + })) + + @defer.inlineCallbacks + def on_query_auth_request(self, origin, content, event_id): + """ + Content is a dict with keys:: + auth_chain (list): A list of events that give the auth chain. + missing (list): A list of event_ids indicating what the other + side (`origin`) think we're missing. + rejects (dict): A mapping from event_id to a 2-tuple of reason + string and a proof (or None) of why the event was rejected. + The keys of this dict give the list of events the `origin` has + rejected. + + Args: + origin (str) + content (dict) + event_id (str) + + Returns: + Deferred: Results in `dict` with the same format as `content` + """ + auth_chain = [ + (yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) + for e in content["auth_chain"] + ] + + missing = [ + (yield self._check_sigs_and_hash(self.event_from_pdu_json(e))) + for e in content.get("missing", []) + ] + + ret = yield self.handler.on_query_auth( + origin, event_id, auth_chain, content.get("rejects", []), missing + ) + + time_now = self._clock.time_msec() + send_content = { + "auth_chain": [ + e.get_pdu_json(time_now) + for e in ret["auth_chain"] + ], + "rejects": ret.get("rejects", []), + "missing": ret.get("missing", []), + } + + defer.returnValue( + (200, send_content) + ) + + @log_function + def _get_persisted_pdu(self, origin, event_id, do_auth=True): + """ Get a PDU from the database with given origin and id. + + Returns: + Deferred: Results in a `Pdu`. + """ + return self.handler.get_persisted_pdu( + origin, event_id, do_auth=do_auth + ) + + def _transaction_from_pdus(self, pdu_list): + """Returns a new Transaction containing the given PDUs suitable for + transmission. + """ + time_now = self._clock.time_msec() + pdus = [p.get_pdu_json(time_now) for p in pdu_list] + return Transaction( + origin=self.server_name, + pdus=pdus, + origin_server_ts=int(time_now), + destination=None, + ) + + @defer.inlineCallbacks + @log_function + def _handle_new_pdu(self, origin, pdu, max_recursion=10): + # We reprocess pdus when we have seen them only as outliers + existing = yield self._get_persisted_pdu( + origin, pdu.event_id, do_auth=False + ) + + # FIXME: Currently we fetch an event again when we already have it + # if it has been marked as an outlier. + + already_seen = ( + existing and ( + not existing.internal_metadata.is_outlier() + or pdu.internal_metadata.is_outlier() + ) + ) + if already_seen: + logger.debug("Already seen pdu %s", pdu.event_id) + defer.returnValue({}) + return + + # Check signature. + try: + pdu = yield self._check_sigs_and_hash(pdu) + except SynapseError as e: + raise FederationError( + "ERROR", + e.code, + e.msg, + affected=pdu.event_id, + ) + + state = None + + auth_chain = [] + + have_seen = yield self.store.have_events( + [ev for ev, _ in pdu.prev_events] + ) + + fetch_state = False + + # Get missing pdus if necessary. + if not pdu.internal_metadata.is_outlier(): + # We only backfill backwards to the min depth. + min_depth = yield self.handler.get_min_depth_for_context( + pdu.room_id + ) + + logger.debug( + "_handle_new_pdu min_depth for %s: %d", + pdu.room_id, min_depth + ) + + if min_depth and pdu.depth > min_depth and max_recursion > 0: + for event_id, hashes in pdu.prev_events: + if event_id not in have_seen: + logger.debug( + "_handle_new_pdu requesting pdu %s", + event_id + ) + + try: + new_pdu = yield self.federation_client.get_pdu( + [origin, pdu.origin], + event_id=event_id, + ) + + if new_pdu: + yield self._handle_new_pdu( + origin, + new_pdu, + max_recursion=max_recursion-1 + ) + + logger.debug("Processed pdu %s", event_id) + else: + logger.warn("Failed to get PDU %s", event_id) + fetch_state = True + except: + # TODO(erikj): Do some more intelligent retries. + logger.exception("Failed to get PDU") + fetch_state = True + else: + prevs = {e_id for e_id, _ in pdu.prev_events} + seen = set(have_seen.keys()) + if prevs - seen: + fetch_state = True + else: + fetch_state = True + + if fetch_state: + # We need to get the state at this event, since we haven't + # processed all the prev events. + logger.debug( + "_handle_new_pdu getting state for %s", + pdu.room_id + ) + state, auth_chain = yield self.get_state_for_room( + origin, pdu.room_id, pdu.event_id, + ) + + ret = yield self.handler.on_receive_pdu( + origin, + pdu, + backfilled=False, + state=state, + auth_chain=auth_chain, + ) + + defer.returnValue(ret) + + def __str__(self): + return "" % self.server_name + + def event_from_pdu_json(self, pdu_json, outlier=False): + event = FrozenEvent( + pdu_json + ) + + event.internal_metadata.outlier = outlier + + return event + + @defer.inlineCallbacks + def _check_sigs_and_hash(self, pdu): + """Throws a SynapseError if the PDU does not have the correct + signatures. + + Returns: + FrozenEvent: Either the given event or it redacted if it failed the + content hash check. + """ + # Check signatures are correct. + redacted_event = prune_event(pdu) + redacted_pdu_json = redacted_event.get_pdu_json() + + try: + yield self.keyring.verify_json_for_server( + pdu.origin, redacted_pdu_json + ) + except SynapseError: + logger.warn( + "Signature check failed for %s redacted to %s", + encode_canonical_json(pdu.get_pdu_json()), + encode_canonical_json(redacted_pdu_json), + ) + raise + + if not check_event_content_hash(pdu): + logger.warn( + "Event content has been tampered, redacting %s, %s", + pdu.event_id, encode_canonical_json(pdu.get_dict()) + ) + defer.returnValue(redacted_event) + + defer.returnValue(pdu) diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py index 6620532a60..e442c6c5d5 100644 --- a/synapse/federation/replication.py +++ b/synapse/federation/replication.py @@ -17,23 +17,20 @@ a given transport. """ -from twisted.internet import defer +from .federation_client import FederationClient +from .federation_server import FederationServer -from .units import Transaction, Edu +from .transaction_queue import TransactionQueue from .persistence import TransactionActions -from synapse.util.logutils import log_function -from synapse.util.logcontext import PreserveLoggingContext -from synapse.events import FrozenEvent - import logging logger = logging.getLogger(__name__) -class ReplicationLayer(object): +class ReplicationLayer(FederationClient, FederationServer): """This layer is responsible for replicating with remote home servers over the given transport. I.e., does the sending and receiving of PDUs to remote home servers. @@ -54,898 +51,26 @@ class ReplicationLayer(object): def __init__(self, hs, transport_layer): self.server_name = hs.hostname + self.keyring = hs.get_keyring() + self.transport_layer = transport_layer self.transport_layer.register_received_handler(self) self.transport_layer.register_request_handler(self) - self.store = hs.get_datastore() - # self.pdu_actions = PduActions(self.store) - self.transaction_actions = TransactionActions(self.store) + self.federation_client = self - self._transaction_queue = _TransactionQueue( - hs, self.transaction_actions, transport_layer - ) + self.store = hs.get_datastore() self.handler = None self.edu_handlers = {} self.query_handlers = {} - self._order = 0 - self._clock = hs.get_clock() - self.event_builder_factory = hs.get_event_builder_factory() + self.transaction_actions = TransactionActions(self.store) + self._transaction_queue = TransactionQueue(hs, transport_layer) - def set_handler(self, handler): - """Sets the handler that the replication layer will use to communicate - receipt of new PDUs from other home servers. The required methods are - documented on :py:class:`.ReplicationHandler`. - """ - self.handler = handler - - def register_edu_handler(self, edu_type, handler): - if edu_type in self.edu_handlers: - raise KeyError("Already have an EDU handler for %s" % (edu_type,)) - - self.edu_handlers[edu_type] = handler - - def register_query_handler(self, query_type, handler): - """Sets the handler callable that will be used to handle an incoming - federation Query of the given type. - - Args: - query_type (str): Category name of the query, which should match - the string used by make_query. - handler (callable): Invoked to handle incoming queries of this type - - handler is invoked as: - result = handler(args) - - where 'args' is a dict mapping strings to strings of the query - arguments. It should return a Deferred that will eventually yield an - object to encode as JSON. - """ - if query_type in self.query_handlers: - raise KeyError( - "Already have a Query handler for %s" % (query_type,) - ) - - self.query_handlers[query_type] = handler - - @log_function - def send_pdu(self, pdu, destinations): - """Informs the replication layer about a new PDU generated within the - home server that should be transmitted to others. - - TODO: Figure out when we should actually resolve the deferred. - - Args: - pdu (Pdu): The new Pdu. - - Returns: - Deferred: Completes when we have successfully processed the PDU - and replicated it to any interested remote home servers. - """ - order = self._order - self._order += 1 - - logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_pdu(pdu, destinations, order) - - logger.debug( - "[%s] transaction_layer.enqueue_pdu... done", - pdu.event_id - ) - - @log_function - def send_edu(self, destination, edu_type, content): - edu = Edu( - origin=self.server_name, - destination=destination, - edu_type=edu_type, - content=content, - ) - - # TODO, add errback, etc. - self._transaction_queue.enqueue_edu(edu) - return defer.succeed(None) - - @log_function - def send_failure(self, failure, destination): - self._transaction_queue.enqueue_failure(failure, destination) - return defer.succeed(None) - - @log_function - def make_query(self, destination, query_type, args, - retry_on_dns_fail=True): - """Sends a federation Query to a remote homeserver of the given type - and arguments. - - Args: - destination (str): Domain name of the remote homeserver - query_type (str): Category of the query type; should match the - handler name used in register_query_handler(). - args (dict): Mapping of strings to strings containing the details - of the query request. - - Returns: - a Deferred which will eventually yield a JSON object from the - response - """ - return self.transport_layer.make_query( - destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail - ) - - @defer.inlineCallbacks - @log_function - def backfill(self, dest, context, limit, extremities): - """Requests some more historic PDUs for the given context from the - given destination server. - - Args: - dest (str): The remote home server to ask. - context (str): The context to backfill. - limit (int): The maximum number of PDUs to return. - extremities (list): List of PDU id and origins of the first pdus - we have seen from the context - - Returns: - Deferred: Results in the received PDUs. - """ - logger.debug("backfill extrem=%s", extremities) - - # If there are no extremeties then we've (probably) reached the start. - if not extremities: - return - - transaction_data = yield self.transport_layer.backfill( - dest, context, extremities, limit) - - logger.debug("backfill transaction_data=%s", repr(transaction_data)) - - transaction = Transaction(**transaction_data) - - pdus = [ - self.event_from_pdu_json(p, outlier=False) - for p in transaction.pdus - ] - for pdu in pdus: - yield self._handle_new_pdu(dest, pdu, backfilled=True) - - defer.returnValue(pdus) - - @defer.inlineCallbacks - @log_function - def get_pdu(self, destination, event_id, outlier=False): - """Requests the PDU with given origin and ID from the remote home - server. - - This will persist the PDU locally upon receipt. - - Args: - destination (str): Which home server to query - pdu_origin (str): The home server that originally sent the pdu. - event_id (str) - outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if - it's from an arbitary point in the context as opposed to part - of the current block of PDUs. Defaults to `False` - - Returns: - Deferred: Results in the requested PDU. - """ - - transaction_data = yield self.transport_layer.get_event( - destination, event_id - ) - - transaction = Transaction(**transaction_data) - - pdu_list = [ - self.event_from_pdu_json(p, outlier=outlier) - for p in transaction.pdus - ] - - pdu = None - if pdu_list: - pdu = pdu_list[0] - yield self._handle_new_pdu(destination, pdu) - - defer.returnValue(pdu) - - @defer.inlineCallbacks - @log_function - def get_state_for_room(self, destination, room_id, event_id): - """Requests all of the `current` state PDUs for a given room from - a remote home server. - - Args: - destination (str): The remote homeserver to query for the state. - room_id (str): The id of the room we're interested in. - event_id (str): The id of the event we want the state at. - - Returns: - Deferred: Results in a list of PDUs. - """ - - result = yield self.transport_layer.get_room_state( - destination, room_id, event_id=event_id, - ) - - pdus = [ - self.event_from_pdu_json(p, outlier=True) for p in result["pdus"] - ] - - auth_chain = [ - self.event_from_pdu_json(p, outlier=True) - for p in result.get("auth_chain", []) - ] - - defer.returnValue((pdus, auth_chain)) - - @defer.inlineCallbacks - @log_function - def get_event_auth(self, destination, room_id, event_id): - res = yield self.transport_layer.get_event_auth( - destination, room_id, event_id, - ) - - auth_chain = [ - self.event_from_pdu_json(p, outlier=True) - for p in res["auth_chain"] - ] - - auth_chain.sort(key=lambda e: e.depth) - - defer.returnValue(auth_chain) - - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, room_id, versions, limit): - pdus = yield self.handler.on_backfill_request( - origin, room_id, versions, limit - ) - - defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) - - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, transaction_data): - transaction = Transaction(**transaction_data) - - for p in transaction.pdus: - if "unsigned" in p: - unsigned = p["unsigned"] - if "age" in unsigned: - p["age"] = unsigned["age"] - if "age" in p: - p["age_ts"] = int(self._clock.time_msec()) - int(p["age"]) - del p["age"] - - pdu_list = [ - self.event_from_pdu_json(p) for p in transaction.pdus - ] - - logger.debug("[%s] Got transaction", transaction.transaction_id) - - response = yield self.transaction_actions.have_responded(transaction) - - if response: - logger.debug("[%s] We've already responed to this request", - transaction.transaction_id) - defer.returnValue(response) - return - - logger.debug("[%s] Transaction is new", transaction.transaction_id) - - with PreserveLoggingContext(): - dl = [] - for pdu in pdu_list: - dl.append(self._handle_new_pdu(transaction.origin, pdu)) - - if hasattr(transaction, "edus"): - for edu in [Edu(**x) for x in transaction.edus]: - self.received_edu( - transaction.origin, - edu.edu_type, - edu.content - ) - - results = yield defer.DeferredList(dl) - - ret = [] - for r in results: - if r[0]: - ret.append({}) - else: - logger.exception(r[1]) - ret.append({"error": str(r[1])}) - - logger.debug("Returning: %s", str(ret)) - - yield self.transaction_actions.set_response( - transaction, - 200, response - ) - defer.returnValue((200, response)) - - def received_edu(self, origin, edu_type, content): - if edu_type in self.edu_handlers: - self.edu_handlers[edu_type](origin, content) - else: - logger.warn("Received EDU of type %s with no handler", edu_type) - - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, room_id, event_id): - if event_id: - pdus = yield self.handler.get_state_for_pdu( - origin, room_id, event_id, - ) - auth_chain = yield self.store.get_auth_chain( - [pdu.event_id for pdu in pdus] - ) - else: - raise NotImplementedError("Specify an event") - - defer.returnValue((200, { - "pdus": [pdu.get_pdu_json() for pdu in pdus], - "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], - })) - - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self._get_persisted_pdu(origin, event_id) - - if pdu: - defer.returnValue( - (200, self._transaction_from_pdus([pdu]).get_dict()) - ) - else: - defer.returnValue((404, "")) - - @defer.inlineCallbacks - @log_function - def on_pull_request(self, origin, versions): - raise NotImplementedError("Pull transactions not implemented") - - @defer.inlineCallbacks - def on_query_request(self, query_type, args): - if query_type in self.query_handlers: - response = yield self.query_handlers[query_type](args) - defer.returnValue((200, response)) - else: - defer.returnValue( - (404, "No handler for Query type '%s'" % (query_type,)) - ) - - @defer.inlineCallbacks - def on_make_join_request(self, room_id, user_id): - pdu = yield self.handler.on_make_join_request(room_id, user_id) - time_now = self._clock.time_msec() - defer.returnValue({"event": pdu.get_pdu_json(time_now)}) - - @defer.inlineCallbacks - def on_invite_request(self, origin, content): - pdu = self.event_from_pdu_json(content) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) - time_now = self._clock.time_msec() - defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)})) - - @defer.inlineCallbacks - def on_send_join_request(self, origin, content): - logger.debug("on_send_join_request: content: %s", content) - pdu = self.event_from_pdu_json(content) - logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) - time_now = self._clock.time_msec() - defer.returnValue((200, { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [ - p.get_pdu_json(time_now) for p in res_pdus["auth_chain"] - ], - })) - - @defer.inlineCallbacks - def on_event_auth(self, origin, room_id, event_id): - time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) - defer.returnValue((200, { - "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], - })) - - @defer.inlineCallbacks - def make_join(self, destination, room_id, user_id): - ret = yield self.transport_layer.make_join( - destination, room_id, user_id - ) - - pdu_dict = ret["event"] - - logger.debug("Got response to make_join: %s", pdu_dict) - - defer.returnValue(self.event_from_pdu_json(pdu_dict)) - - @defer.inlineCallbacks - def send_join(self, destination, pdu): - time_now = self._clock.time_msec() - _, content = yield self.transport_layer.send_join( - destination=destination, - room_id=pdu.room_id, - event_id=pdu.event_id, - content=pdu.get_pdu_json(time_now), - ) - - logger.debug("Got content: %s", content) - - state = [ - self.event_from_pdu_json(p, outlier=True) - for p in content.get("state", []) - ] - - auth_chain = [ - self.event_from_pdu_json(p, outlier=True) - for p in content.get("auth_chain", []) - ] - - auth_chain.sort(key=lambda e: e.depth) - - defer.returnValue({ - "state": state, - "auth_chain": auth_chain, - }) - - @defer.inlineCallbacks - def send_invite(self, destination, room_id, event_id, pdu): - time_now = self._clock.time_msec() - code, content = yield self.transport_layer.send_invite( - destination=destination, - room_id=room_id, - event_id=event_id, - content=pdu.get_pdu_json(time_now), - ) - - pdu_dict = content["event"] - - logger.debug("Got response to send_invite: %s", pdu_dict) - - defer.returnValue(self.event_from_pdu_json(pdu_dict)) - - @log_function - def _get_persisted_pdu(self, origin, event_id, do_auth=True): - """ Get a PDU from the database with given origin and id. - - Returns: - Deferred: Results in a `Pdu`. - """ - return self.handler.get_persisted_pdu( - origin, event_id, do_auth=do_auth - ) - - def _transaction_from_pdus(self, pdu_list): - """Returns a new Transaction containing the given PDUs suitable for - transmission. - """ - time_now = self._clock.time_msec() - pdus = [p.get_pdu_json(time_now) for p in pdu_list] - return Transaction( - origin=self.server_name, - pdus=pdus, - origin_server_ts=int(time_now), - destination=None, - ) - - @defer.inlineCallbacks - @log_function - def _handle_new_pdu(self, origin, pdu, backfilled=False): - # We reprocess pdus when we have seen them only as outliers - existing = yield self._get_persisted_pdu( - origin, pdu.event_id, do_auth=False - ) - - already_seen = ( - existing and ( - not existing.internal_metadata.is_outlier() - or pdu.internal_metadata.is_outlier() - ) - ) - if already_seen: - logger.debug("Already seen pdu %s", pdu.event_id) - defer.returnValue({}) - return - - state = None - - auth_chain = [] - - # We need to make sure we have all the auth events. - # for e_id, _ in pdu.auth_events: - # exists = yield self._get_persisted_pdu( - # origin, - # e_id, - # do_auth=False - # ) - # - # if not exists: - # try: - # logger.debug( - # "_handle_new_pdu fetch missing auth event %s from %s", - # e_id, - # origin, - # ) - # - # yield self.get_pdu( - # origin, - # event_id=e_id, - # outlier=True, - # ) - # - # logger.debug("Processed pdu %s", e_id) - # except: - # logger.warn( - # "Failed to get auth event %s from %s", - # e_id, - # origin - # ) - - # Get missing pdus if necessary. - if not pdu.internal_metadata.is_outlier(): - # We only backfill backwards to the min depth. - min_depth = yield self.handler.get_min_depth_for_context( - pdu.room_id - ) - - logger.debug( - "_handle_new_pdu min_depth for %s: %d", - pdu.room_id, min_depth - ) - - if min_depth and pdu.depth > min_depth: - for event_id, hashes in pdu.prev_events: - exists = yield self._get_persisted_pdu( - origin, - event_id, - do_auth=False - ) - - if not exists: - logger.debug( - "_handle_new_pdu requesting pdu %s", - event_id - ) - - try: - yield self.get_pdu( - origin, - event_id=event_id, - ) - logger.debug("Processed pdu %s", event_id) - except: - # TODO(erikj): Do some more intelligent retries. - logger.exception("Failed to get PDU") - else: - # We need to get the state at this event, since we have reached - # a backward extremity edge. - logger.debug( - "_handle_new_pdu getting state for %s", - pdu.room_id - ) - state, auth_chain = yield self.get_state_for_room( - origin, pdu.room_id, pdu.event_id, - ) - - if not backfilled: - ret = yield self.handler.on_receive_pdu( - origin, - pdu, - backfilled=backfilled, - state=state, - auth_chain=auth_chain, - ) - else: - ret = None - - # yield self.pdu_actions.mark_as_processed(pdu) - - defer.returnValue(ret) + self._order = 0 def __str__(self): return "" % self.server_name - - def event_from_pdu_json(self, pdu_json, outlier=False): - event = FrozenEvent( - pdu_json - ) - - event.internal_metadata.outlier = outlier - - return event - - -class _TransactionQueue(object): - """This class makes sure we only have one transaction in flight at - a time for a given destination. - - It batches pending PDUs into single transactions. - """ - - def __init__(self, hs, transaction_actions, transport_layer): - self.server_name = hs.hostname - self.transaction_actions = transaction_actions - self.transport_layer = transport_layer - - self._clock = hs.get_clock() - self.store = hs.get_datastore() - - # Is a mapping from destinations -> deferreds. Used to keep track - # of which destinations have transactions in flight and when they are - # done - self.pending_transactions = {} - - # Is a mapping from destination -> list of - # tuple(pending pdus, deferred, order) - self.pending_pdus_by_dest = {} - # destination -> list of tuple(edu, deferred) - self.pending_edus_by_dest = {} - - # destination -> list of tuple(failure, deferred) - self.pending_failures_by_dest = {} - - # HACK to get unique tx id - self._next_txn_id = int(self._clock.time_msec()) - - @defer.inlineCallbacks - @log_function - def enqueue_pdu(self, pdu, destinations, order): - # We loop through all destinations to see whether we already have - # a transaction in progress. If we do, stick it in the pending_pdus - # table and we'll get back to it later. - - destinations = set(destinations) - destinations.discard(self.server_name) - destinations.discard("localhost") - - logger.debug("Sending to: %s", str(destinations)) - - if not destinations: - return - - deferreds = [] - - for destination in destinations: - deferred = defer.Deferred() - self.pending_pdus_by_dest.setdefault(destination, []).append( - (pdu, deferred, order) - ) - - def eb(failure): - if not deferred.called: - deferred.errback(failure) - else: - logger.warn("Failed to send pdu", failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(eb) - - deferreds.append(deferred) - - yield defer.DeferredList(deferreds) - - # NO inlineCallbacks - def enqueue_edu(self, edu): - destination = edu.destination - - if destination == self.server_name: - return - - deferred = defer.Deferred() - self.pending_edus_by_dest.setdefault(destination, []).append( - (edu, deferred) - ) - - def eb(failure): - if not deferred.called: - deferred.errback(failure) - else: - logger.warn("Failed to send edu", failure) - - with PreserveLoggingContext(): - self._attempt_new_transaction(destination).addErrback(eb) - - return deferred - - @defer.inlineCallbacks - def enqueue_failure(self, failure, destination): - deferred = defer.Deferred() - - self.pending_failures_by_dest.setdefault( - destination, [] - ).append( - (failure, deferred) - ) - - yield deferred - - @defer.inlineCallbacks - @log_function - def _attempt_new_transaction(self, destination): - - (retry_last_ts, retry_interval) = (0, 0) - retry_timings = yield self.store.get_destination_retry_timings( - destination - ) - if retry_timings: - (retry_last_ts, retry_interval) = ( - retry_timings.retry_last_ts, retry_timings.retry_interval - ) - if retry_last_ts + retry_interval > int(self._clock.time_msec()): - logger.info( - "TX [%s] not ready for retry yet - " - "dropping transaction for now", - destination, - ) - return - else: - logger.info("TX [%s] is ready for retry", destination) - - logger.info("TX [%s] _attempt_new_transaction", destination) - - if destination in self.pending_transactions: - # XXX: pending_transactions can get stuck on by a never-ending - # request at which point pending_pdus_by_dest just keeps growing. - # we need application-layer timeouts of some flavour of these - # requests - return - - # list of (pending_pdu, deferred, order) - pending_pdus = self.pending_pdus_by_dest.pop(destination, []) - pending_edus = self.pending_edus_by_dest.pop(destination, []) - pending_failures = self.pending_failures_by_dest.pop(destination, []) - - if pending_pdus: - logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d", - destination, len(pending_pdus)) - - if not pending_pdus and not pending_edus and not pending_failures: - return - - logger.debug( - "TX [%s] Attempting new transaction" - " (pdus: %d, edus: %d, failures: %d)", - destination, - len(pending_pdus), - len(pending_edus), - len(pending_failures) - ) - - # Sort based on the order field - pending_pdus.sort(key=lambda t: t[2]) - - pdus = [x[0] for x in pending_pdus] - edus = [x[0] for x in pending_edus] - failures = [x[0].get_dict() for x in pending_failures] - deferreds = [ - x[1] - for x in pending_pdus + pending_edus + pending_failures - ] - - try: - self.pending_transactions[destination] = 1 - - logger.debug("TX [%s] Persisting transaction...", destination) - - transaction = Transaction.create_new( - origin_server_ts=int(self._clock.time_msec()), - transaction_id=str(self._next_txn_id), - origin=self.server_name, - destination=destination, - pdus=pdus, - edus=edus, - pdu_failures=failures, - ) - - self._next_txn_id += 1 - - yield self.transaction_actions.prepare_to_send(transaction) - - logger.debug("TX [%s] Persisted transaction", destination) - logger.info( - "TX [%s] Sending transaction [%s]", - destination, - transaction.transaction_id, - ) - - # Actually send the transaction - - # FIXME (erikj): This is a bit of a hack to make the Pdu age - # keys work - def json_data_cb(): - data = transaction.get_dict() - now = int(self._clock.time_msec()) - if "pdus" in data: - for p in data["pdus"]: - if "age_ts" in p: - unsigned = p.setdefault("unsigned", {}) - unsigned["age"] = now - int(p["age_ts"]) - del p["age_ts"] - return data - - code, response = yield self.transport_layer.send_transaction( - transaction, json_data_cb - ) - - logger.info("TX [%s] got %d response", destination, code) - - logger.debug("TX [%s] Sent transaction", destination) - logger.debug("TX [%s] Marking as delivered...", destination) - - yield self.transaction_actions.delivered( - transaction, code, response - ) - - logger.debug("TX [%s] Marked as delivered", destination) - logger.debug("TX [%s] Yielding to callbacks...", destination) - - for deferred in deferreds: - if code == 200: - if retry_last_ts: - # this host is alive! reset retry schedule - yield self.store.set_destination_retry_timings( - destination, 0, 0 - ) - deferred.callback(None) - else: - self.set_retrying(destination, retry_interval) - deferred.errback(RuntimeError("Got status %d" % code)) - - # Ensures we don't continue until all callbacks on that - # deferred have fired - try: - yield deferred - except: - pass - - logger.debug("TX [%s] Yielded to callbacks", destination) - - except Exception as e: - # We capture this here as there as nothing actually listens - # for this finishing functions deferred. - logger.warn( - "TX [%s] Problem in _attempt_transaction: %s", - destination, - e, - ) - - self.set_retrying(destination, retry_interval) - - for deferred in deferreds: - if not deferred.called: - deferred.errback(e) - - finally: - # We want to be *very* sure we delete this after we stop processing - self.pending_transactions.pop(destination, None) - - # Check to see if there is anything else to send. - self._attempt_new_transaction(destination) - - @defer.inlineCallbacks - def set_retrying(self, destination, retry_interval): - # track that this destination is having problems and we should - # give it a chance to recover before trying it again - - if retry_interval: - retry_interval *= 2 - # plateau at hourly retries for now - if retry_interval >= 60 * 60 * 1000: - retry_interval = 60 * 60 * 1000 - else: - retry_interval = 2000 # try again at first after 2 seconds - - yield self.store.set_destination_retry_timings( - destination, - int(self._clock.time_msec()), - retry_interval - ) diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py new file mode 100644 index 0000000000..9d4f2c09a2 --- /dev/null +++ b/synapse/federation/transaction_queue.py @@ -0,0 +1,317 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from twisted.internet import defer + +from .persistence import TransactionActions +from .units import Transaction + +from synapse.util.logutils import log_function +from synapse.util.logcontext import PreserveLoggingContext + +import logging + + +logger = logging.getLogger(__name__) + + +class TransactionQueue(object): + """This class makes sure we only have one transaction in flight at + a time for a given destination. + + It batches pending PDUs into single transactions. + """ + + def __init__(self, hs, transport_layer): + self.server_name = hs.hostname + + self.store = hs.get_datastore() + self.transaction_actions = TransactionActions(self.store) + + self.transport_layer = transport_layer + + self._clock = hs.get_clock() + + # Is a mapping from destinations -> deferreds. Used to keep track + # of which destinations have transactions in flight and when they are + # done + self.pending_transactions = {} + + # Is a mapping from destination -> list of + # tuple(pending pdus, deferred, order) + self.pending_pdus_by_dest = {} + # destination -> list of tuple(edu, deferred) + self.pending_edus_by_dest = {} + + # destination -> list of tuple(failure, deferred) + self.pending_failures_by_dest = {} + + # HACK to get unique tx id + self._next_txn_id = int(self._clock.time_msec()) + + @defer.inlineCallbacks + @log_function + def enqueue_pdu(self, pdu, destinations, order): + # We loop through all destinations to see whether we already have + # a transaction in progress. If we do, stick it in the pending_pdus + # table and we'll get back to it later. + + destinations = set(destinations) + destinations.discard(self.server_name) + destinations.discard("localhost") + + logger.debug("Sending to: %s", str(destinations)) + + if not destinations: + return + + deferreds = [] + + for destination in destinations: + deferred = defer.Deferred() + self.pending_pdus_by_dest.setdefault(destination, []).append( + (pdu, deferred, order) + ) + + def eb(failure): + if not deferred.called: + deferred.errback(failure) + else: + logger.warn("Failed to send pdu", failure) + + with PreserveLoggingContext(): + self._attempt_new_transaction(destination).addErrback(eb) + + deferreds.append(deferred) + + yield defer.DeferredList(deferreds) + + # NO inlineCallbacks + def enqueue_edu(self, edu): + destination = edu.destination + + if destination == self.server_name: + return + + deferred = defer.Deferred() + self.pending_edus_by_dest.setdefault(destination, []).append( + (edu, deferred) + ) + + def eb(failure): + if not deferred.called: + deferred.errback(failure) + else: + logger.warn("Failed to send edu", failure) + + with PreserveLoggingContext(): + self._attempt_new_transaction(destination).addErrback(eb) + + return deferred + + @defer.inlineCallbacks + def enqueue_failure(self, failure, destination): + deferred = defer.Deferred() + + self.pending_failures_by_dest.setdefault( + destination, [] + ).append( + (failure, deferred) + ) + + yield deferred + + @defer.inlineCallbacks + @log_function + def _attempt_new_transaction(self, destination): + + (retry_last_ts, retry_interval) = (0, 0) + retry_timings = yield self.store.get_destination_retry_timings( + destination + ) + if retry_timings: + (retry_last_ts, retry_interval) = ( + retry_timings.retry_last_ts, retry_timings.retry_interval + ) + if retry_last_ts + retry_interval > int(self._clock.time_msec()): + logger.info( + "TX [%s] not ready for retry yet - " + "dropping transaction for now", + destination, + ) + return + else: + logger.info("TX [%s] is ready for retry", destination) + + logger.info("TX [%s] _attempt_new_transaction", destination) + + if destination in self.pending_transactions: + # XXX: pending_transactions can get stuck on by a never-ending + # request at which point pending_pdus_by_dest just keeps growing. + # we need application-layer timeouts of some flavour of these + # requests + return + + # list of (pending_pdu, deferred, order) + pending_pdus = self.pending_pdus_by_dest.pop(destination, []) + pending_edus = self.pending_edus_by_dest.pop(destination, []) + pending_failures = self.pending_failures_by_dest.pop(destination, []) + + if pending_pdus: + logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d", + destination, len(pending_pdus)) + + if not pending_pdus and not pending_edus and not pending_failures: + return + + logger.debug( + "TX [%s] Attempting new transaction" + " (pdus: %d, edus: %d, failures: %d)", + destination, + len(pending_pdus), + len(pending_edus), + len(pending_failures) + ) + + # Sort based on the order field + pending_pdus.sort(key=lambda t: t[2]) + + pdus = [x[0] for x in pending_pdus] + edus = [x[0] for x in pending_edus] + failures = [x[0].get_dict() for x in pending_failures] + deferreds = [ + x[1] + for x in pending_pdus + pending_edus + pending_failures + ] + + try: + self.pending_transactions[destination] = 1 + + logger.debug("TX [%s] Persisting transaction...", destination) + + transaction = Transaction.create_new( + origin_server_ts=int(self._clock.time_msec()), + transaction_id=str(self._next_txn_id), + origin=self.server_name, + destination=destination, + pdus=pdus, + edus=edus, + pdu_failures=failures, + ) + + self._next_txn_id += 1 + + yield self.transaction_actions.prepare_to_send(transaction) + + logger.debug("TX [%s] Persisted transaction", destination) + logger.info( + "TX [%s] Sending transaction [%s]", + destination, + transaction.transaction_id, + ) + + # Actually send the transaction + + # FIXME (erikj): This is a bit of a hack to make the Pdu age + # keys work + def json_data_cb(): + data = transaction.get_dict() + now = int(self._clock.time_msec()) + if "pdus" in data: + for p in data["pdus"]: + if "age_ts" in p: + unsigned = p.setdefault("unsigned", {}) + unsigned["age"] = now - int(p["age_ts"]) + del p["age_ts"] + return data + + code, response = yield self.transport_layer.send_transaction( + transaction, json_data_cb + ) + + logger.info("TX [%s] got %d response", destination, code) + + logger.debug("TX [%s] Sent transaction", destination) + logger.debug("TX [%s] Marking as delivered...", destination) + + yield self.transaction_actions.delivered( + transaction, code, response + ) + + logger.debug("TX [%s] Marked as delivered", destination) + logger.debug("TX [%s] Yielding to callbacks...", destination) + + for deferred in deferreds: + if code == 200: + if retry_last_ts: + # this host is alive! reset retry schedule + yield self.store.set_destination_retry_timings( + destination, 0, 0 + ) + deferred.callback(None) + else: + self.set_retrying(destination, retry_interval) + deferred.errback(RuntimeError("Got status %d" % code)) + + # Ensures we don't continue until all callbacks on that + # deferred have fired + try: + yield deferred + except: + pass + + logger.debug("TX [%s] Yielded to callbacks", destination) + + except Exception as e: + # We capture this here as there as nothing actually listens + # for this finishing functions deferred. + logger.warn( + "TX [%s] Problem in _attempt_transaction: %s", + destination, + e, + ) + + self.set_retrying(destination, retry_interval) + + for deferred in deferreds: + if not deferred.called: + deferred.errback(e) + + finally: + # We want to be *very* sure we delete this after we stop processing + self.pending_transactions.pop(destination, None) + + # Check to see if there is anything else to send. + self._attempt_new_transaction(destination) + + @defer.inlineCallbacks + def set_retrying(self, destination, retry_interval): + # track that this destination is having problems and we should + # give it a chance to recover before trying it again + + if retry_interval: + retry_interval *= 2 + # plateau at hourly retries for now + if retry_interval >= 60 * 60 * 1000: + retry_interval = 60 * 60 * 1000 + else: + retry_interval = 2000 # try again at first after 2 seconds + + yield self.store.set_destination_retry_timings( + destination, + int(self._clock.time_msec()), + retry_interval + ) diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index e634a3a213..4cb1dea2de 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -213,3 +213,19 @@ class TransportLayerClient(object): ) defer.returnValue(response) + + @defer.inlineCallbacks + @log_function + def send_query_auth(self, destination, room_id, event_id, content): + path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id) + + code, content = yield self.client.post_json( + destination=destination, + path=path, + data=content, + ) + + if not 200 <= code < 300: + raise RuntimeError("Got %d from send_invite", code) + + defer.returnValue(json.loads(content)) diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index a380a6910b..9c9f8d525b 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -42,7 +42,7 @@ class TransportLayerServer(object): content = None origin = None - if request.method == "PUT": + if request.method in ["PUT", "POST"]: # TODO: Handle other method types? other content types? try: content_bytes = request.content.read() @@ -234,6 +234,16 @@ class TransportLayerServer(object): ) ) ) + self.server.register_path( + "POST", + re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"), + self._with_authentication( + lambda origin, content, query, context, event_id: + self._on_query_auth_request( + origin, content, event_id, + ) + ) + ) @defer.inlineCallbacks @log_function @@ -325,3 +335,12 @@ class TransportLayerServer(object): ) defer.returnValue((200, content)) + + @defer.inlineCallbacks + @log_function + def _on_query_auth_request(self, origin, content, event_id): + new_content = yield self.request_handler.on_query_auth_request( + origin, content, event_id + ) + + defer.returnValue((200, new_content)) diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py index 96a9b143ca..b31518bf62 100644 --- a/synapse/handlers/__init__.py +++ b/synapse/handlers/__init__.py @@ -27,6 +27,7 @@ from .directory import DirectoryHandler from .typing import TypingNotificationHandler from .admin import AdminHandler from .appservice import ApplicationServicesHandler +from .sync import SyncHandler class Handlers(object): @@ -53,3 +54,4 @@ class Handlers(object): self.typing_notification_handler = TypingNotificationHandler(hs) self.admin_handler = AdminHandler(hs) self.appservice_handler = ApplicationServicesHandler(hs) + self.sync_handler = SyncHandler(hs) diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py index d997917cd6..025e7e7e62 100644 --- a/synapse/handlers/events.py +++ b/synapse/handlers/events.py @@ -49,24 +49,25 @@ class EventStreamHandler(BaseHandler): @defer.inlineCallbacks @log_function def get_stream(self, auth_user_id, pagin_config, timeout=0, - as_client_event=True): + as_client_event=True, affect_presence=True): auth_user = UserID.from_string(auth_user_id) try: - if auth_user not in self._streams_per_user: - self._streams_per_user[auth_user] = 0 - if auth_user in self._stop_timer_per_user: - try: - self.clock.cancel_call_later( - self._stop_timer_per_user.pop(auth_user) + if affect_presence: + if auth_user not in self._streams_per_user: + self._streams_per_user[auth_user] = 0 + if auth_user in self._stop_timer_per_user: + try: + self.clock.cancel_call_later( + self._stop_timer_per_user.pop(auth_user) + ) + except: + logger.exception("Failed to cancel event timer") + else: + yield self.distributor.fire( + "started_user_eventstream", auth_user ) - except: - logger.exception("Failed to cancel event timer") - else: - yield self.distributor.fire( - "started_user_eventstream", auth_user - ) - self._streams_per_user[auth_user] += 1 + self._streams_per_user[auth_user] += 1 if pagin_config.from_token is None: pagin_config.from_token = None @@ -94,28 +95,29 @@ class EventStreamHandler(BaseHandler): defer.returnValue(chunk) finally: - self._streams_per_user[auth_user] -= 1 - if not self._streams_per_user[auth_user]: - del self._streams_per_user[auth_user] + if affect_presence: + self._streams_per_user[auth_user] -= 1 + if not self._streams_per_user[auth_user]: + del self._streams_per_user[auth_user] - # 10 seconds of grace to allow the client to reconnect again - # before we think they're gone - def _later(): - logger.debug( - "_later stopped_user_eventstream %s", auth_user + # 10 seconds of grace to allow the client to reconnect again + # before we think they're gone + def _later(): + logger.debug( + "_later stopped_user_eventstream %s", auth_user + ) + + self._stop_timer_per_user.pop(auth_user, None) + + return self.distributor.fire( + "stopped_user_eventstream", auth_user + ) + + logger.debug("Scheduling _later: for %s", auth_user) + self._stop_timer_per_user[auth_user] = ( + self.clock.call_later(30, _later) ) - self._stop_timer_per_user.pop(auth_user, None) - - yield self.distributor.fire( - "stopped_user_eventstream", auth_user - ) - - logger.debug("Scheduling _later: for %s", auth_user) - self._stop_timer_per_user[auth_user] = ( - self.clock.call_later(30, _later) - ) - class EventHandler(BaseHandler): diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index bcdcc90a18..8bf5a4cc11 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -17,19 +17,16 @@ from ._base import BaseHandler -from synapse.events.utils import prune_event from synapse.api.errors import ( - AuthError, FederationError, SynapseError, StoreError, + AuthError, FederationError, StoreError, ) -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventTypes, Membership, RejectedReason from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.crypto.event_signing import ( - compute_event_signature, check_event_content_hash, - add_hashes_and_signatures, + compute_event_signature, add_hashes_and_signatures, ) from synapse.types import UserID -from syutil.jsonutil import encode_canonical_json from twisted.internet import defer @@ -113,33 +110,6 @@ class FederationHandler(BaseHandler): logger.debug("Processing event: %s", event.event_id) - redacted_event = prune_event(event) - - redacted_pdu_json = redacted_event.get_pdu_json() - try: - yield self.keyring.verify_json_for_server( - event.origin, redacted_pdu_json - ) - except SynapseError as e: - logger.warn( - "Signature check failed for %s redacted to %s", - encode_canonical_json(pdu.get_pdu_json()), - encode_canonical_json(redacted_pdu_json), - ) - raise FederationError( - "ERROR", - e.code, - e.msg, - affected=event.event_id, - ) - - if not check_event_content_hash(event): - logger.warn( - "Event content has been tampered, redacting %s, %s", - event.event_id, encode_canonical_json(event.get_dict()) - ) - event = redacted_event - logger.debug("Event: %s", event) # FIXME (erikj): Awful hack to make the case where we are not currently @@ -149,41 +119,20 @@ class FederationHandler(BaseHandler): event.room_id, self.server_name ) - if not is_in_room and not event.internal_metadata.outlier: + if not is_in_room and not event.internal_metadata.is_outlier(): logger.debug("Got event for room we're not in.") - - replication = self.replication_layer - - if not state: - state, auth_chain = yield replication.get_state_for_room( - origin, context=event.room_id, event_id=event.event_id, - ) - - if not auth_chain: - auth_chain = yield replication.get_event_auth( - origin, - context=event.room_id, - event_id=event.event_id, - ) - - for e in auth_chain: - e.internal_metadata.outlier = True - try: - yield self._handle_new_event(e, fetch_auth_from=origin) - except: - logger.exception( - "Failed to handle auth event %s", - e.event_id, - ) - current_state = state - if state: + if state and auth_chain is not None: for e in state: - logging.info("A :) %r", e) e.internal_metadata.outlier = True try: - yield self._handle_new_event(e) + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + yield self._handle_new_event(origin, e, auth_events=auth) except: logger.exception( "Failed to handle state event %s", @@ -192,6 +141,7 @@ class FederationHandler(BaseHandler): try: yield self._handle_new_event( + origin, event, state=state, backfilled=backfilled, @@ -393,8 +343,19 @@ class FederationHandler(BaseHandler): for e in auth_chain: e.internal_metadata.outlier = True + + if e.event_id == event.event_id: + continue + try: - yield self._handle_new_event(e) + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + yield self._handle_new_event( + target_host, e, auth_events=auth + ) except: logger.exception( "Failed to handle auth event %s", @@ -402,11 +363,18 @@ class FederationHandler(BaseHandler): ) for e in state: - # FIXME: Auth these. + if e.event_id == event.event_id: + continue + e.internal_metadata.outlier = True try: + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } yield self._handle_new_event( - e, fetch_auth_from=target_host + target_host, e, auth_events=auth ) except: logger.exception( @@ -414,10 +382,18 @@ class FederationHandler(BaseHandler): e.event_id, ) + auth_ids = [e_id for e_id, _ in event.auth_events] + auth_events = { + (e.type, e.state_key): e for e in auth_chain + if e.event_id in auth_ids + } + yield self._handle_new_event( + target_host, new_event, state=state, current_state=state, + auth_events=auth_events, ) yield self.notifier.on_new_room_event( @@ -481,7 +457,7 @@ class FederationHandler(BaseHandler): event.internal_metadata.outlier = False - context = yield self._handle_new_event(event) + context = yield self._handle_new_event(origin, event) logger.debug( "on_send_join_request: After _handle_new_event: %s, sigs: %s", @@ -682,11 +658,12 @@ class FederationHandler(BaseHandler): waiters.pop().callback(None) @defer.inlineCallbacks - def _handle_new_event(self, event, state=None, backfilled=False, - current_state=None, fetch_auth_from=None): + @log_function + def _handle_new_event(self, origin, event, state=None, backfilled=False, + current_state=None, auth_events=None): logger.debug( - "_handle_new_event: Before annotate: %s, sigs: %s", + "_handle_new_event: %s, sigs: %s", event.event_id, event.signatures, ) @@ -694,65 +671,44 @@ class FederationHandler(BaseHandler): event, old_state=state ) + if not auth_events: + auth_events = context.auth_events + logger.debug( - "_handle_new_event: Before auth fetch: %s, sigs: %s", - event.event_id, event.signatures, + "_handle_new_event: %s, auth_events: %s", + event.event_id, auth_events, ) is_new_state = not event.internal_metadata.is_outlier() - known_ids = set( - [s.event_id for s in context.auth_events.values()] - ) - - for e_id, _ in event.auth_events: - if e_id not in known_ids: - e = yield self.store.get_event(e_id, allow_none=True) - - if not e and fetch_auth_from is not None: - # Grab the auth_chain over federation if we are missing - # auth events. - auth_chain = yield self.replication_layer.get_event_auth( - fetch_auth_from, event.event_id, event.room_id - ) - for auth_event in auth_chain: - yield self._handle_new_event(auth_event) - e = yield self.store.get_event(e_id, allow_none=True) - - if not e: - # TODO: Do some conflict res to make sure that we're - # not the ones who are wrong. - logger.info( - "Rejecting %s as %s not in db or %s", - event.event_id, e_id, known_ids, - ) - # FIXME: How does raising AuthError work with federation? - raise AuthError(403, "Cannot find auth event") - - context.auth_events[(e.type, e.state_key)] = e - - logger.debug( - "_handle_new_event: Before hack: %s, sigs: %s", - event.event_id, event.signatures, - ) - + # This is a hack to fix some old rooms where the initial join event + # didn't reference the create event in its auth events. if event.type == EventTypes.Member and not event.auth_events: if len(event.prev_events) == 1: c = yield self.store.get_event(event.prev_events[0][0]) if c.type == EventTypes.Create: - context.auth_events[(c.type, c.state_key)] = c + auth_events[(c.type, c.state_key)] = c - logger.debug( - "_handle_new_event: Before auth check: %s, sigs: %s", - event.event_id, event.signatures, - ) + try: + yield self.do_auth( + origin, event, context, auth_events=auth_events + ) + except AuthError as e: + logger.warn( + "Rejecting %s because %s", + event.event_id, e.msg + ) - self.auth.check(event, auth_events=context.auth_events) + context.rejected = RejectedReason.AUTH_ERROR - logger.debug( - "_handle_new_event: Before persist_event: %s, sigs: %s", - event.event_id, event.signatures, - ) + yield self.store.persist_event( + event, + context=context, + backfilled=backfilled, + is_new_state=False, + current_state=current_state, + ) + raise yield self.store.persist_event( event, @@ -762,9 +718,294 @@ class FederationHandler(BaseHandler): current_state=current_state, ) - logger.debug( - "_handle_new_event: After persist_event: %s, sigs: %s", - event.event_id, event.signatures, + defer.returnValue(context) + + @defer.inlineCallbacks + def on_query_auth(self, origin, event_id, remote_auth_chain, rejects, + missing): + # Just go through and process each event in `remote_auth_chain`. We + # don't want to fall into the trap of `missing` being wrong. + for e in remote_auth_chain: + try: + yield self._handle_new_event(origin, e) + except AuthError: + pass + + # Now get the current auth_chain for the event. + local_auth_chain = yield self.store.get_auth_chain([event_id]) + + # TODO: Check if we would now reject event_id. If so we need to tell + # everyone. + + ret = yield self.construct_auth_difference( + local_auth_chain, remote_auth_chain ) - defer.returnValue(context) + for event in ret["auth_chain"]: + event.signatures.update( + compute_event_signature( + event, + self.hs.hostname, + self.hs.config.signing_key[0] + ) + ) + + logger.debug("on_query_auth reutrning: %s", ret) + + defer.returnValue(ret) + + @defer.inlineCallbacks + @log_function + def do_auth(self, origin, event, context, auth_events): + # Check if we have all the auth events. + res = yield self.store.have_events( + [e_id for e_id, _ in event.auth_events] + ) + + event_auth_events = set(e_id for e_id, _ in event.auth_events) + seen_events = set(res.keys()) + + missing_auth = event_auth_events - seen_events + + if missing_auth: + logger.debug("Missing auth: %s", missing_auth) + # If we don't have all the auth events, we need to get them. + remote_auth_chain = yield self.replication_layer.get_event_auth( + origin, event.room_id, event.event_id + ) + + seen_remotes = yield self.store.have_events( + [e.event_id for e in remote_auth_chain] + ) + + for e in remote_auth_chain: + if e.event_id in seen_remotes.keys(): + continue + + if e.event_id == event.event_id: + continue + + try: + auth_ids = [e_id for e_id, _ in e.auth_events] + auth = { + (e.type, e.state_key): e for e in remote_auth_chain + if e.event_id in auth_ids + } + e.internal_metadata.outlier = True + + logger.debug( + "do_auth %s missing_auth: %s", + event.event_id, e.event_id + ) + yield self._handle_new_event( + origin, e, auth_events=auth + ) + + if e.event_id in event_auth_events: + auth_events[(e.type, e.state_key)] = e + except AuthError: + pass + + # FIXME: Assumes we have and stored all the state for all the + # prev_events + current_state = set(e.event_id for e in auth_events.values()) + different_auth = event_auth_events - current_state + + if different_auth and not event.internal_metadata.is_outlier(): + # Do auth conflict res. + logger.debug("Different auth: %s", different_auth) + + # 1. Get what we think is the auth chain. + auth_ids = self.auth.compute_auth_events(event, context) + local_auth_chain = yield self.store.get_auth_chain(auth_ids) + + # 2. Get remote difference. + result = yield self.replication_layer.query_auth( + origin, + event.room_id, + event.event_id, + local_auth_chain, + ) + + seen_remotes = yield self.store.have_events( + [e.event_id for e in result["auth_chain"]] + ) + + # 3. Process any remote auth chain events we haven't seen. + for ev in result["auth_chain"]: + if ev.event_id in seen_remotes.keys(): + continue + + if ev.event_id == event.event_id: + continue + + try: + auth_ids = [e_id for e_id, _ in ev.auth_events] + auth = { + (e.type, e.state_key): e for e in result["auth_chain"] + if e.event_id in auth_ids + } + ev.internal_metadata.outlier = True + + logger.debug( + "do_auth %s different_auth: %s", + event.event_id, e.event_id + ) + + yield self._handle_new_event( + origin, ev, auth_events=auth + ) + + if ev.event_id in event_auth_events: + auth_events[(ev.type, ev.state_key)] = ev + except AuthError: + pass + + # 4. Look at rejects and their proofs. + # TODO. + + context.current_state.update(auth_events) + context.state_group = None + + try: + self.auth.check(event, auth_events=auth_events) + except AuthError: + raise + + @defer.inlineCallbacks + def construct_auth_difference(self, local_auth, remote_auth): + """ Given a local and remote auth chain, find the differences. This + assumes that we have already processed all events in remote_auth + + Params: + local_auth (list) + remote_auth (list) + + Returns: + dict + """ + + logger.debug("construct_auth_difference Start!") + + # TODO: Make sure we are OK with local_auth or remote_auth having more + # auth events in them than strictly necessary. + + def sort_fun(ev): + return ev.depth, ev.event_id + + logger.debug("construct_auth_difference after sort_fun!") + + # We find the differences by starting at the "bottom" of each list + # and iterating up on both lists. The lists are ordered by depth and + # then event_id, we iterate up both lists until we find the event ids + # don't match. Then we look at depth/event_id to see which side is + # missing that event, and iterate only up that list. Repeat. + + remote_list = list(remote_auth) + remote_list.sort(key=sort_fun) + + local_list = list(local_auth) + local_list.sort(key=sort_fun) + + local_iter = iter(local_list) + remote_iter = iter(remote_list) + + logger.debug("construct_auth_difference before get_next!") + + def get_next(it, opt=None): + try: + return it.next() + except: + return opt + + current_local = get_next(local_iter) + current_remote = get_next(remote_iter) + + logger.debug("construct_auth_difference before while") + + missing_remotes = [] + missing_locals = [] + while current_local or current_remote: + if current_remote is None: + missing_locals.append(current_local) + current_local = get_next(local_iter) + continue + + if current_local is None: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + if current_local.event_id == current_remote.event_id: + current_local = get_next(local_iter) + current_remote = get_next(remote_iter) + continue + + if current_local.depth < current_remote.depth: + missing_locals.append(current_local) + current_local = get_next(local_iter) + continue + + if current_local.depth > current_remote.depth: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + # They have the same depth, so we fall back to the event_id order + if current_local.event_id < current_remote.event_id: + missing_locals.append(current_local) + current_local = get_next(local_iter) + + if current_local.event_id > current_remote.event_id: + missing_remotes.append(current_remote) + current_remote = get_next(remote_iter) + continue + + logger.debug("construct_auth_difference after while") + + # missing locals should be sent to the server + # We should find why we are missing remotes, as they will have been + # rejected. + + # Remove events from missing_remotes if they are referencing a missing + # remote. We only care about the "root" rejected ones. + missing_remote_ids = [e.event_id for e in missing_remotes] + base_remote_rejected = list(missing_remotes) + for e in missing_remotes: + for e_id, _ in e.auth_events: + if e_id in missing_remote_ids: + base_remote_rejected.remove(e) + + reason_map = {} + + for e in base_remote_rejected: + reason = yield self.store.get_rejection_reason(e.event_id) + if reason is None: + # FIXME: ERRR?! + logger.warn("Could not find reason for %s", e.event_id) + raise RuntimeError("") + + reason_map[e.event_id] = reason + + if reason == RejectedReason.AUTH_ERROR: + pass + elif reason == RejectedReason.REPLACED: + # TODO: Get proof + pass + elif reason == RejectedReason.NOT_ANCESTOR: + # TODO: Get proof. + pass + + logger.debug("construct_auth_difference returning") + + defer.returnValue({ + "auth_chain": local_auth, + "rejects": { + e.event_id: { + "reason": reason_map[e.event_id], + "proof": None, + } + for e in base_remote_rejected + }, + "missing": [e.event_id for e in missing_locals], + }) diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9c3271fe88..6fbd2af4ab 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -114,7 +114,8 @@ class MessageHandler(BaseHandler): defer.returnValue(chunk) @defer.inlineCallbacks - def create_and_send_event(self, event_dict, ratelimit=True): + def create_and_send_event(self, event_dict, ratelimit=True, + client=None, txn_id=None): """ Given a dict from a client, create and handle a new event. Creates an FrozenEvent object, filling out auth_events, prev_events, @@ -148,6 +149,15 @@ class MessageHandler(BaseHandler): builder.content ) + if client is not None: + if client.token_id is not None: + builder.internal_metadata.token_id = client.token_id + if client.device_id is not None: + builder.internal_metadata.device_id = client.device_id + + if txn_id is not None: + builder.internal_metadata.txn_id = txn_id + event, context = yield self._create_new_client_event( builder=builder, ) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index d66bfea7b1..cd0798c2b0 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -87,6 +87,10 @@ class PresenceHandler(BaseHandler): "changed_presencelike_data", self.changed_presencelike_data ) + # outbound signal from the presence module to advertise when a user's + # presence has changed + distributor.declare("user_presence_changed") + self.distributor = distributor self.federation = hs.get_replication_layer() @@ -604,6 +608,7 @@ class PresenceHandler(BaseHandler): room_ids=room_ids, statuscache=statuscache, ) + yield self.distributor.fire("user_presence_changed", user, statuscache) @defer.inlineCallbacks def _push_presence_remote(self, user, destination, state=None): diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 732652c228..66a89c10b2 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -163,7 +163,7 @@ class RegistrationHandler(BaseHandler): # each request httpCli = SimpleHttpClient(self.hs) # XXX: make this configurable! - trustedIdServers = ['matrix.org:8090'] + trustedIdServers = ['matrix.org:8090', 'matrix.org'] if not creds['idServer'] in trustedIdServers: logger.warn('%s is not a trusted ID server: rejecting 3pid ' + 'credentials', creds['idServer']) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py new file mode 100644 index 0000000000..962686f4bb --- /dev/null +++ b/synapse/handlers/sync.py @@ -0,0 +1,434 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import BaseHandler + +from synapse.streams.config import PaginationConfig +from synapse.api.constants import Membership, EventTypes + +from twisted.internet import defer + +import collections +import logging + +logger = logging.getLogger(__name__) + + +SyncConfig = collections.namedtuple("SyncConfig", [ + "user", + "client_info", + "limit", + "gap", + "sort", + "backfill", + "filter", +]) + + +class RoomSyncResult(collections.namedtuple("RoomSyncResult", [ + "room_id", + "limited", + "published", + "events", + "state", + "prev_batch", + "ephemeral", +])): + __slots__ = [] + + def __nonzero__(self): + """Make the result appear empty if there are no updates. This is used + to tell if room needs to be part of the sync result. + """ + return bool(self.events or self.state or self.ephemeral) + + +class SyncResult(collections.namedtuple("SyncResult", [ + "next_batch", # Token for the next sync + "private_user_data", # List of private events for the user. + "public_user_data", # List of public events for all users. + "rooms", # RoomSyncResult for each room. +])): + __slots__ = [] + + def __nonzero__(self): + """Make the result appear empty if there are no updates. This is used + to tell if the notifier needs to wait for more events when polling for + events. + """ + return bool( + self.private_user_data or self.public_user_data or self.rooms + ) + + +class SyncHandler(BaseHandler): + + def __init__(self, hs): + super(SyncHandler, self).__init__(hs) + self.event_sources = hs.get_event_sources() + self.clock = hs.get_clock() + + @defer.inlineCallbacks + def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0): + """Get the sync for a client if we have new data for it now. Otherwise + wait for new data to arrive on the server. If the timeout expires, then + return an empty sync result. + Returns: + A Deferred SyncResult. + """ + if timeout == 0 or since_token is None: + result = yield self.current_sync_for_user(sync_config, since_token) + defer.returnValue(result) + else: + def current_sync_callback(): + return self.current_sync_for_user(sync_config, since_token) + + rm_handler = self.hs.get_handlers().room_member_handler + room_ids = yield rm_handler.get_rooms_for_user(sync_config.user) + result = yield self.notifier.wait_for_events( + sync_config.user, room_ids, + sync_config.filter, timeout, current_sync_callback + ) + defer.returnValue(result) + + def current_sync_for_user(self, sync_config, since_token=None): + """Get the sync for client needed to match what the server has now. + Returns: + A Deferred SyncResult. + """ + if since_token is None: + return self.initial_sync(sync_config) + else: + if sync_config.gap: + return self.incremental_sync_with_gap(sync_config, since_token) + else: + #TODO(mjark): Handle gapless sync + raise NotImplementedError() + + @defer.inlineCallbacks + def initial_sync(self, sync_config): + """Get a sync for a client which is starting without any state + Returns: + A Deferred SyncResult. + """ + if sync_config.sort == "timeline,desc": + # TODO(mjark): Handle going through events in reverse order?. + # What does "most recent events" mean when applying the limits mean + # in this case? + raise NotImplementedError() + + now_token = yield self.event_sources.get_current_token() + + presence_stream = self.event_sources.sources["presence"] + # TODO (mjark): This looks wrong, shouldn't we be getting the presence + # UP to the present rather than after the present? + pagination_config = PaginationConfig(from_token=now_token) + presence, _ = yield presence_stream.get_pagination_rows( + user=sync_config.user, + pagination_config=pagination_config.get_source_config("presence"), + key=None + ) + room_list = yield self.store.get_rooms_for_user_where_membership_is( + user_id=sync_config.user.to_string(), + membership_list=[Membership.INVITE, Membership.JOIN] + ) + + # TODO (mjark): Does public mean "published"? + published_rooms = yield self.store.get_rooms(is_public=True) + published_room_ids = set(r["room_id"] for r in published_rooms) + + rooms = [] + for event in room_list: + room_sync = yield self.initial_sync_for_room( + event.room_id, sync_config, now_token, published_room_ids + ) + rooms.append(room_sync) + + defer.returnValue(SyncResult( + public_user_data=presence, + private_user_data=[], + rooms=rooms, + next_batch=now_token, + )) + + @defer.inlineCallbacks + def initial_sync_for_room(self, room_id, sync_config, now_token, + published_room_ids): + """Sync a room for a client which is starting without any state + Returns: + A Deferred RoomSyncResult. + """ + + recents, prev_batch_token, limited = yield self.load_filtered_recents( + room_id, sync_config, now_token, + ) + + current_state_events = yield self.state_handler.get_current_state( + room_id + ) + + defer.returnValue(RoomSyncResult( + room_id=room_id, + published=room_id in published_room_ids, + events=recents, + prev_batch=prev_batch_token, + state=current_state_events, + limited=limited, + ephemeral=[], + )) + + @defer.inlineCallbacks + def incremental_sync_with_gap(self, sync_config, since_token): + """ Get the incremental delta needed to bring the client up to + date with the server. + Returns: + A Deferred SyncResult. + """ + if sync_config.sort == "timeline,desc": + # TODO(mjark): Handle going through events in reverse order?. + # What does "most recent events" mean when applying the limits mean + # in this case? + raise NotImplementedError() + + now_token = yield self.event_sources.get_current_token() + + presence_source = self.event_sources.sources["presence"] + presence, presence_key = yield presence_source.get_new_events_for_user( + user=sync_config.user, + from_key=since_token.presence_key, + limit=sync_config.limit, + ) + now_token = now_token.copy_and_replace("presence_key", presence_key) + + typing_source = self.event_sources.sources["typing"] + typing, typing_key = yield typing_source.get_new_events_for_user( + user=sync_config.user, + from_key=since_token.typing_key, + limit=sync_config.limit, + ) + now_token = now_token.copy_and_replace("typing_key", typing_key) + + typing_by_room = {event["room_id"]: [event] for event in typing} + for event in typing: + event.pop("room_id") + logger.debug("Typing %r", typing_by_room) + + rm_handler = self.hs.get_handlers().room_member_handler + room_ids = yield rm_handler.get_rooms_for_user(sync_config.user) + + # TODO (mjark): Does public mean "published"? + published_rooms = yield self.store.get_rooms(is_public=True) + published_room_ids = set(r["room_id"] for r in published_rooms) + + room_events, _ = yield self.store.get_room_events_stream( + sync_config.user.to_string(), + from_key=since_token.room_key, + to_key=now_token.room_key, + room_id=None, + limit=sync_config.limit + 1, + ) + + rooms = [] + if len(room_events) <= sync_config.limit: + # There is no gap in any of the rooms. Therefore we can just + # partition the new events by room and return them. + events_by_room_id = {} + for event in room_events: + events_by_room_id.setdefault(event.room_id, []).append(event) + + for room_id in room_ids: + recents = events_by_room_id.get(room_id, []) + state = [event for event in recents if event.is_state()] + if recents: + prev_batch = now_token.copy_and_replace( + "room_key", recents[0].internal_metadata.before + ) + else: + prev_batch = now_token + + state = yield self.check_joined_room( + sync_config, room_id, state + ) + + room_sync = RoomSyncResult( + room_id=room_id, + published=room_id in published_room_ids, + events=recents, + prev_batch=prev_batch, + state=state, + limited=False, + ephemeral=typing_by_room.get(room_id, []) + ) + if room_sync: + rooms.append(room_sync) + else: + for room_id in room_ids: + room_sync = yield self.incremental_sync_with_gap_for_room( + room_id, sync_config, since_token, now_token, + published_room_ids, typing_by_room + ) + if room_sync: + rooms.append(room_sync) + + defer.returnValue(SyncResult( + public_user_data=presence, + private_user_data=[], + rooms=rooms, + next_batch=now_token, + )) + + @defer.inlineCallbacks + def load_filtered_recents(self, room_id, sync_config, now_token, + since_token=None): + limited = True + recents = [] + filtering_factor = 2 + load_limit = max(sync_config.limit * filtering_factor, 100) + max_repeat = 3 # Only try a few times per room, otherwise + room_key = now_token.room_key + + while limited and len(recents) < sync_config.limit and max_repeat: + events, keys = yield self.store.get_recent_events_for_room( + room_id, + limit=load_limit + 1, + from_token=since_token.room_key if since_token else None, + end_token=room_key, + ) + (room_key, _) = keys + loaded_recents = sync_config.filter.filter_room_events(events) + loaded_recents.extend(recents) + recents = loaded_recents + if len(events) <= load_limit: + limited = False + max_repeat -= 1 + + if len(recents) > sync_config.limit: + recents = recents[-sync_config.limit:] + room_key = recents[0].internal_metadata.before + + prev_batch_token = now_token.copy_and_replace( + "room_key", room_key + ) + + defer.returnValue((recents, prev_batch_token, limited)) + + @defer.inlineCallbacks + def incremental_sync_with_gap_for_room(self, room_id, sync_config, + since_token, now_token, + published_room_ids, typing_by_room): + """ Get the incremental delta needed to bring the client up to date for + the room. Gives the client the most recent events and the changes to + state. + Returns: + A Deferred RoomSyncResult + """ + + # TODO(mjark): Check for redactions we might have missed. + + recents, prev_batch_token, limited = yield self.load_filtered_recents( + room_id, sync_config, now_token, since_token, + ) + + logging.debug("Recents %r", recents) + + # TODO(mjark): This seems racy since this isn't being passed a + # token to indicate what point in the stream this is + current_state_events = yield self.state_handler.get_current_state( + room_id + ) + + state_at_previous_sync = yield self.get_state_at_previous_sync( + room_id, since_token=since_token + ) + + state_events_delta = yield self.compute_state_delta( + since_token=since_token, + previous_state=state_at_previous_sync, + current_state=current_state_events, + ) + + state_events_delta = yield self.check_joined_room( + sync_config, room_id, state_events_delta + ) + + room_sync = RoomSyncResult( + room_id=room_id, + published=room_id in published_room_ids, + events=recents, + prev_batch=prev_batch_token, + state=state_events_delta, + limited=limited, + ephemeral=typing_by_room.get(room_id, []) + ) + + logging.debug("Room sync: %r", room_sync) + + defer.returnValue(room_sync) + + @defer.inlineCallbacks + def get_state_at_previous_sync(self, room_id, since_token): + """ Get the room state at the previous sync the client made. + Returns: + A Deferred list of Events. + """ + last_events, token = yield self.store.get_recent_events_for_room( + room_id, end_token=since_token.room_key, limit=1, + ) + + if last_events: + last_event = last_events[0] + last_context = yield self.state_handler.compute_event_context( + last_event + ) + if last_event.is_state(): + state = [last_event] + last_context.current_state.values() + else: + state = last_context.current_state.values() + else: + state = () + defer.returnValue(state) + + def compute_state_delta(self, since_token, previous_state, current_state): + """ Works out the differnce in state between the current state and the + state the client got when it last performed a sync. + Returns: + A list of events. + """ + # TODO(mjark) Check if the state events were received by the server + # after the previous sync, since we need to include those state + # updates even if they occured logically before the previous event. + # TODO(mjark) Check for new redactions in the state events. + previous_dict = {event.event_id: event for event in previous_state} + state_delta = [] + for event in current_state: + if event.event_id not in previous_dict: + state_delta.append(event) + return state_delta + + @defer.inlineCallbacks + def check_joined_room(self, sync_config, room_id, state_delta): + joined = False + for event in state_delta: + if ( + event.type == EventTypes.Member + and event.state_key == sync_config.user.to_string() + ): + if event.content["membership"] == Membership.JOIN: + joined = True + + if joined: + state_delta = yield self.state_handler.get_current_state(room_id) + + defer.returnValue(state_delta) diff --git a/synapse/http/client.py b/synapse/http/client.py index 7793bab106..198f575cfa 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -62,6 +62,25 @@ class SimpleHttpClient(object): defer.returnValue(json.loads(body)) + @defer.inlineCallbacks + def post_json_get_json(self, uri, post_json): + json_str = json.dumps(post_json) + + logger.info("HTTP POST %s -> %s", json_str, uri) + + response = yield self.agent.request( + "POST", + uri.encode("ascii"), + headers=Headers({ + "Content-Type": ["application/json"] + }), + bodyProducer=FileBodyProducer(StringIO(json_str)) + ) + + body = yield readBody(response) + + defer.returnValue(json.loads(body)) + @defer.inlineCallbacks def get_json(self, uri, args={}): """ Get's some json from the given host and path diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 1dda3ba2c7..c7bf1b47b8 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -244,6 +244,43 @@ class MatrixFederationHttpClient(object): defer.returnValue((response.code, body)) + @defer.inlineCallbacks + def post_json(self, destination, path, data={}): + """ Sends the specifed json data using POST + + Args: + destination (str): The remote server to send the HTTP request + to. + path (str): The HTTP path. + data (dict): A dict containing the data that will be used as + the request body. This will be encoded as JSON. + + Returns: + Deferred: Succeeds when we get a 2xx HTTP response. The result + will be the decoded JSON body. On a 4xx or 5xx error response a + CodeMessageException is raised. + """ + + def body_callback(method, url_bytes, headers_dict): + self.sign_request( + destination, method, url_bytes, headers_dict, data + ) + return _JsonProducer(data) + + response = yield self._create_request( + destination.encode("ascii"), + "POST", + path.encode("ascii"), + body_callback=body_callback, + headers_dict={"Content-Type": ["application/json"]}, + ) + + logger.debug("Getting resp body") + body = yield readBody(response) + logger.debug("Got resp body") + + defer.returnValue((response.code, body)) + @defer.inlineCallbacks def get_json(self, destination, path, args={}, retry_on_dns_fail=True): """ GETs some json from the given host homeserver and path diff --git a/synapse/http/server.py b/synapse/http/server.py index 8015a22edf..0f6539e1be 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -16,7 +16,7 @@ from synapse.http.agent_name import AGENT_NAME from synapse.api.errors import ( - cs_exception, SynapseError, CodeMessageException + cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError ) from synapse.util.logcontext import LoggingContext @@ -139,11 +139,7 @@ class JsonResource(HttpServer, resource.Resource): return # Huh. No one wanted to handle that? Fiiiiiine. Send 400. - self._send_response( - request, - 400, - {"error": "Unrecognized request"} - ) + raise UnrecognizedRequestError() except CodeMessageException as e: if isinstance(e, SynapseError): logger.info("%s SynapseError: %s - %s", request, e.code, e.msg) diff --git a/synapse/notifier.py b/synapse/notifier.py index 3aec1d4af2..e3b6ead620 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -18,6 +18,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.logcontext import PreserveLoggingContext from synapse.util.async import run_on_reactor +from synapse.types import StreamToken import logging @@ -205,6 +206,53 @@ class Notifier(object): [notify(l).addErrback(eb) for l in listeners] ) + @defer.inlineCallbacks + def wait_for_events(self, user, rooms, filter, timeout, callback): + """Wait until the callback returns a non empty response or the + timeout fires. + """ + + deferred = defer.Deferred() + + from_token = StreamToken("s0", "0", "0") + + listener = [_NotificationListener( + user=user, + rooms=rooms, + from_token=from_token, + limit=1, + timeout=timeout, + deferred=deferred, + )] + + if timeout: + self._register_with_keys(listener[0]) + + result = yield callback() + if timeout: + timed_out = [False] + + def _timeout_listener(): + timed_out[0] = True + listener[0].notify(self, [], from_token, from_token) + + self.clock.call_later(timeout/1000., _timeout_listener) + while not result and not timed_out[0]: + yield deferred + deferred = defer.Deferred() + listener[0] = _NotificationListener( + user=user, + rooms=rooms, + from_token=from_token, + limit=1, + timeout=timeout, + deferred=deferred, + ) + self._register_with_keys(listener[0]) + result = yield callback() + + defer.returnValue(result) + def get_events_for(self, user, rooms, pagination_config, timeout): """ For the given user and rooms, return any new events for them. If there are no new events wait for up to `timeout` milliseconds for any diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py new file mode 100644 index 0000000000..28e5dae81d --- /dev/null +++ b/synapse/push/__init__.py @@ -0,0 +1,410 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.streams.config import PaginationConfig +from synapse.types import StreamToken, UserID + +import synapse.util.async +import baserules + +import logging +import fnmatch +import json +import re + +logger = logging.getLogger(__name__) + + +class Pusher(object): + INITIAL_BACKOFF = 1000 + MAX_BACKOFF = 60 * 60 * 1000 + GIVE_UP_AFTER = 24 * 60 * 60 * 1000 + DEFAULT_ACTIONS = ['notify'] + + INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$") + + def __init__(self, _hs, instance_handle, user_name, app_id, + app_display_name, device_display_name, pushkey, pushkey_ts, + data, last_token, last_success, failing_since): + self.hs = _hs + self.evStreamHandler = self.hs.get_handlers().event_stream_handler + self.store = self.hs.get_datastore() + self.clock = self.hs.get_clock() + self.instance_handle = instance_handle + self.user_name = user_name + self.app_id = app_id + self.app_display_name = app_display_name + self.device_display_name = device_display_name + self.pushkey = pushkey + self.pushkey_ts = pushkey_ts + self.data = data + self.last_token = last_token + self.last_success = last_success # not actually used + self.backoff_delay = Pusher.INITIAL_BACKOFF + self.failing_since = failing_since + self.alive = True + + # The last value of last_active_time that we saw + self.last_last_active_time = 0 + self.has_unread = True + + @defer.inlineCallbacks + def _actions_for_event(self, ev): + """ + This should take into account notification settings that the user + has configured both globally and per-room when we have the ability + to do such things. + """ + if ev['user_id'] == self.user_name: + # let's assume you probably know about messages you sent yourself + defer.returnValue(['dont_notify']) + + if ev['type'] == 'm.room.member': + if ev['state_key'] != self.user_name: + defer.returnValue(['dont_notify']) + + rules = yield self.store.get_push_rules_for_user_name(self.user_name) + + for r in rules: + r['conditions'] = json.loads(r['conditions']) + r['actions'] = json.loads(r['actions']) + + user_name_localpart = UserID.from_string(self.user_name).localpart + + rules.extend(baserules.make_base_rules(user_name_localpart)) + + # get *our* member event for display name matching + member_events_for_room = yield self.store.get_current_state( + room_id=ev['room_id'], + event_type='m.room.member', + state_key=None + ) + my_display_name = None + room_member_count = 0 + for mev in member_events_for_room: + if mev.content['membership'] != 'join': + continue + + # This loop does two things: + # 1) Find our current display name + if mev.state_key == self.user_name and 'displayname' in mev.content: + my_display_name = mev.content['displayname'] + + # and 2) Get the number of people in that room + room_member_count += 1 + + for r in rules: + matches = True + + conditions = r['conditions'] + actions = r['actions'] + + for c in conditions: + matches &= self._event_fulfills_condition( + ev, c, display_name=my_display_name, + room_member_count=room_member_count + ) + # ignore rules with no actions (we have an explict 'dont_notify' + if len(actions) == 0: + logger.warn( + "Ignoring rule id %s with no actions for user %s" % + (r['rule_id'], r['user_name']) + ) + continue + if matches: + defer.returnValue(actions) + + defer.returnValue(Pusher.DEFAULT_ACTIONS) + + def _event_fulfills_condition(self, ev, condition, display_name, room_member_count): + if condition['kind'] == 'event_match': + if 'pattern' not in condition: + logger.warn("event_match condition with no pattern") + return False + pat = condition['pattern'] + + if pat.strip("*?[]") == pat: + # no special glob characters so we assume the user means + # 'contains this string' rather than 'is this string' + pat = "*%s*" % (pat,) + + val = _value_for_dotted_key(condition['key'], ev) + if val is None: + return False + return fnmatch.fnmatch(val.upper(), pat.upper()) + elif condition['kind'] == 'device': + if 'instance_handle' not in condition: + return True + return condition['instance_handle'] == self.instance_handle + elif condition['kind'] == 'contains_display_name': + # This is special because display names can be different + # between rooms and so you can't really hard code it in a rule. + # Optimisation: we should cache these names and update them from + # the event stream. + if 'content' not in ev or 'body' not in ev['content']: + return False + if not display_name: + return False + return fnmatch.fnmatch( + ev['content']['body'].upper(), "*%s*" % (display_name.upper(),) + ) + elif condition['kind'] == 'room_member_count': + if 'is' not in condition: + return False + m = Pusher.INEQUALITY_EXPR.match(condition['is']) + if not m: + return False + ineq = m.group(1) + rhs = m.group(2) + if not rhs.isdigit(): + return False + rhs = int(rhs) + + if ineq == '' or ineq == '==': + return room_member_count == rhs + elif ineq == '<': + return room_member_count < rhs + elif ineq == '>': + return room_member_count > rhs + elif ineq == '>=': + return room_member_count >= rhs + elif ineq == '<=': + return room_member_count <= rhs + else: + return False + else: + return True + + @defer.inlineCallbacks + def get_context_for_event(self, ev): + name_aliases = yield self.store.get_room_name_and_aliases( + ev['room_id'] + ) + + ctx = {'aliases': name_aliases[1]} + if name_aliases[0] is not None: + ctx['name'] = name_aliases[0] + + their_member_events_for_room = yield self.store.get_current_state( + room_id=ev['room_id'], + event_type='m.room.member', + state_key=ev['user_id'] + ) + for mev in their_member_events_for_room: + if mev.content['membership'] == 'join' and 'displayname' in mev.content: + dn = mev.content['displayname'] + if dn is not None: + ctx['sender_display_name'] = dn + + defer.returnValue(ctx) + + @defer.inlineCallbacks + def start(self): + if not self.last_token: + # First-time setup: get a token to start from (we can't + # just start from no token, ie. 'now' + # because we need the result to be reproduceable in case + # we fail to dispatch the push) + config = PaginationConfig(from_token=None, limit='1') + chunk = yield self.evStreamHandler.get_stream( + self.user_name, config, timeout=0) + self.last_token = chunk['end'] + self.store.update_pusher_last_token( + self.user_name, self.pushkey, self.last_token) + logger.info("Pusher %s for user %s starting from token %s", + self.pushkey, self.user_name, self.last_token) + + while self.alive: + from_tok = StreamToken.from_string(self.last_token) + config = PaginationConfig(from_token=from_tok, limit='1') + chunk = yield self.evStreamHandler.get_stream( + self.user_name, config, + timeout=100*365*24*60*60*1000, affect_presence=False + ) + + # limiting to 1 may get 1 event plus 1 presence event, so + # pick out the actual event + single_event = None + for c in chunk['chunk']: + if 'event_id' in c: # Hmmm... + single_event = c + break + if not single_event: + self.last_token = chunk['end'] + continue + + if not self.alive: + continue + + processed = False + actions = yield self._actions_for_event(single_event) + tweaks = _tweaks_for_actions(actions) + + if len(actions) == 0: + logger.warn("Empty actions! Using default action.") + actions = Pusher.DEFAULT_ACTIONS + if 'notify' not in actions and 'dont_notify' not in actions: + logger.warn("Neither notify nor dont_notify in actions: adding default") + actions.extend(Pusher.DEFAULT_ACTIONS) + if 'dont_notify' in actions: + logger.debug( + "%s for %s: dont_notify", + single_event['event_id'], self.user_name + ) + processed = True + else: + rejected = yield self.dispatch_push(single_event, tweaks) + self.has_unread = True + if isinstance(rejected, list) or isinstance(rejected, tuple): + processed = True + for pk in rejected: + if pk != self.pushkey: + # for sanity, we only remove the pushkey if it + # was the one we actually sent... + logger.warn( + ("Ignoring rejected pushkey %s because we" + " didn't send it"), pk + ) + else: + logger.info( + "Pushkey %s was rejected: removing", + pk + ) + yield self.hs.get_pusherpool().remove_pusher( + self.app_id, pk + ) + + if not self.alive: + continue + + if processed: + self.backoff_delay = Pusher.INITIAL_BACKOFF + self.last_token = chunk['end'] + self.store.update_pusher_last_token_and_success( + self.user_name, + self.pushkey, + self.last_token, + self.clock.time_msec() + ) + if self.failing_since: + self.failing_since = None + self.store.update_pusher_failing_since( + self.user_name, + self.pushkey, + self.failing_since) + else: + if not self.failing_since: + self.failing_since = self.clock.time_msec() + self.store.update_pusher_failing_since( + self.user_name, + self.pushkey, + self.failing_since + ) + + if (self.failing_since and + self.failing_since < + self.clock.time_msec() - Pusher.GIVE_UP_AFTER): + # we really only give up so that if the URL gets + # fixed, we don't suddenly deliver a load + # of old notifications. + logger.warn("Giving up on a notification to user %s, " + "pushkey %s", + self.user_name, self.pushkey) + self.backoff_delay = Pusher.INITIAL_BACKOFF + self.last_token = chunk['end'] + self.store.update_pusher_last_token( + self.user_name, + self.pushkey, + self.last_token + ) + + self.failing_since = None + self.store.update_pusher_failing_since( + self.user_name, + self.pushkey, + self.failing_since + ) + else: + logger.warn("Failed to dispatch push for user %s " + "(failing for %dms)." + "Trying again in %dms", + self.user_name, + self.clock.time_msec() - self.failing_since, + self.backoff_delay) + yield synapse.util.async.sleep(self.backoff_delay / 1000.0) + self.backoff_delay *= 2 + if self.backoff_delay > Pusher.MAX_BACKOFF: + self.backoff_delay = Pusher.MAX_BACKOFF + + def stop(self): + self.alive = False + + def dispatch_push(self, p, tweaks): + """ + Overridden by implementing classes to actually deliver the notification + Args: + p: The event to notify for as a single event from the event stream + Returns: If the notification was delivered, an array containing any + pushkeys that were rejected by the push gateway. + False if the notification could not be delivered (ie. + should be retried). + """ + pass + + def reset_badge_count(self): + pass + + def presence_changed(self, state): + """ + We clear badge counts whenever a user's last_active time is bumped + This is by no means perfect but I think it's the best we can do + without read receipts. + """ + if 'last_active' in state.state: + last_active = state.state['last_active'] + if last_active > self.last_last_active_time: + self.last_last_active_time = last_active + if self.has_unread: + logger.info("Resetting badge count for %s", self.user_name) + self.reset_badge_count() + self.has_unread = False + + +def _value_for_dotted_key(dotted_key, event): + parts = dotted_key.split(".") + val = event + while len(parts) > 0: + if parts[0] not in val: + return None + val = val[parts[0]] + parts = parts[1:] + return val + + +def _tweaks_for_actions(actions): + tweaks = {} + for a in actions: + if not isinstance(a, dict): + continue + if 'set_sound' in a: + tweaks['sound'] = a['set_sound'] + return tweaks + + +class PusherConfigException(Exception): + def __init__(self, msg): + super(PusherConfigException, self).__init__(msg) diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py new file mode 100644 index 0000000000..382de118e0 --- /dev/null +++ b/synapse/push/baserules.py @@ -0,0 +1,48 @@ +def make_base_rules(user_name): + rules = [ + { + 'conditions': [ + { + 'kind': 'event_match', + 'key': 'content.body', + 'pattern': '*%s*' % (user_name,), # Matrix ID match + } + ], + 'actions': [ + 'notify', + { + 'set_sound': 'default' + } + ] + }, + { + 'conditions': [ + { + 'kind': 'contains_display_name' + } + ], + 'actions': [ + 'notify', + { + 'set_sound': 'default' + } + ] + }, + { + 'conditions': [ + { + 'kind': 'room_member_count', + 'is': '2' + } + ], + 'actions': [ + 'notify', + { + 'set_sound': 'default' + } + ] + } + ] + for r in rules: + r['priority_class'] = 0 + return rules \ No newline at end of file diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py new file mode 100644 index 0000000000..7c6953c989 --- /dev/null +++ b/synapse/push/httppusher.py @@ -0,0 +1,146 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.push import Pusher, PusherConfigException +from synapse.http.client import SimpleHttpClient + +from twisted.internet import defer + +import logging + +logger = logging.getLogger(__name__) + + +class HttpPusher(Pusher): + def __init__(self, _hs, instance_handle, user_name, app_id, + app_display_name, device_display_name, pushkey, pushkey_ts, + data, last_token, last_success, failing_since): + super(HttpPusher, self).__init__( + _hs, + instance_handle, + user_name, + app_id, + app_display_name, + device_display_name, + pushkey, + pushkey_ts, + data, + last_token, + last_success, + failing_since + ) + if 'url' not in data: + raise PusherConfigException( + "'url' required in data for HTTP pusher" + ) + self.url = data['url'] + self.httpCli = SimpleHttpClient(self.hs) + self.data_minus_url = {} + self.data_minus_url.update(self.data) + del self.data_minus_url['url'] + + @defer.inlineCallbacks + def _build_notification_dict(self, event, tweaks): + # we probably do not want to push for every presence update + # (we may want to be able to set up notifications when specific + # people sign in, but we'd want to only deliver the pertinent ones) + # Actually, presence events will not get this far now because we + # need to filter them out in the main Pusher code. + if 'event_id' not in event: + defer.returnValue(None) + + ctx = yield self.get_context_for_event(event) + + d = { + 'notification': { + 'id': event['event_id'], + 'type': event['type'], + 'sender': event['user_id'], + 'counts': { # -- we don't mark messages as read yet so + # we have no way of knowing + # Just set the badge to 1 until we have read receipts + 'unread': 1, + # 'missed_calls': 2 + }, + 'devices': [ + { + 'app_id': self.app_id, + 'pushkey': self.pushkey, + 'pushkey_ts': long(self.pushkey_ts / 1000), + 'data': self.data_minus_url, + 'tweaks': tweaks + } + ] + } + } + if event['type'] == 'm.room.member': + d['notification']['membership'] = event['content']['membership'] + if 'content' in event: + d['notification']['content'] = event['content'] + + if len(ctx['aliases']): + d['notification']['room_alias'] = ctx['aliases'][0] + if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0: + d['notification']['sender_display_name'] = ctx['sender_display_name'] + if 'name' in ctx and len(ctx['name']) > 0: + d['notification']['room_name'] = ctx['name'] + + defer.returnValue(d) + + @defer.inlineCallbacks + def dispatch_push(self, event, tweaks): + notification_dict = yield self._build_notification_dict(event, tweaks) + if not notification_dict: + defer.returnValue([]) + try: + resp = yield self.httpCli.post_json_get_json(self.url, notification_dict) + except: + logger.exception("Failed to push %s ", self.url) + defer.returnValue(False) + rejected = [] + if 'rejected' in resp: + rejected = resp['rejected'] + defer.returnValue(rejected) + + @defer.inlineCallbacks + def reset_badge_count(self): + d = { + 'notification': { + 'id': '', + 'type': None, + 'sender': '', + 'counts': { + 'unread': 0, + 'missed_calls': 0 + }, + 'devices': [ + { + 'app_id': self.app_id, + 'pushkey': self.pushkey, + 'pushkey_ts': long(self.pushkey_ts / 1000), + 'data': self.data_minus_url, + } + ] + } + } + try: + resp = yield self.httpCli.post_json_get_json(self.url, d) + except: + logger.exception("Failed to push %s ", self.url) + defer.returnValue(False) + rejected = [] + if 'rejected' in resp: + rejected = resp['rejected'] + defer.returnValue(rejected) diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py new file mode 100644 index 0000000000..4892c21e7b --- /dev/null +++ b/synapse/push/pusherpool.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from httppusher import HttpPusher +from synapse.push import PusherConfigException + +import logging +import json + +logger = logging.getLogger(__name__) + + +class PusherPool: + def __init__(self, _hs): + self.hs = _hs + self.store = self.hs.get_datastore() + self.pushers = {} + self.last_pusher_started = -1 + + distributor = self.hs.get_distributor() + distributor.observe( + "user_presence_changed", self.user_presence_changed + ) + + @defer.inlineCallbacks + def user_presence_changed(self, user, state): + user_name = user.to_string() + + # until we have read receipts, pushers use this to reset a user's + # badge counters to zero + for p in self.pushers.values(): + if p.user_name == user_name: + yield p.presence_changed(state) + + @defer.inlineCallbacks + def start(self): + pushers = yield self.store.get_all_pushers() + for p in pushers: + p['data'] = json.loads(p['data']) + self._start_pushers(pushers) + + @defer.inlineCallbacks + def add_pusher(self, user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, pushkey, lang, data): + # we try to create the pusher just to validate the config: it + # will then get pulled out of the database, + # recreated, added and started: this means we have only one + # code path adding pushers. + self._create_pusher({ + "user_name": user_name, + "kind": kind, + "instance_handle": instance_handle, + "app_id": app_id, + "app_display_name": app_display_name, + "device_display_name": device_display_name, + "pushkey": pushkey, + "pushkey_ts": self.hs.get_clock().time_msec(), + "lang": lang, + "data": data, + "last_token": None, + "last_success": None, + "failing_since": None + }) + yield self._add_pusher_to_store( + user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, + pushkey, lang, data + ) + + @defer.inlineCallbacks + def _add_pusher_to_store(self, user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, + pushkey, lang, data): + yield self.store.add_pusher( + user_name=user_name, + instance_handle=instance_handle, + kind=kind, + app_id=app_id, + app_display_name=app_display_name, + device_display_name=device_display_name, + pushkey=pushkey, + pushkey_ts=self.hs.get_clock().time_msec(), + lang=lang, + data=json.dumps(data) + ) + self._refresh_pusher((app_id, pushkey)) + + def _create_pusher(self, pusherdict): + if pusherdict['kind'] == 'http': + return HttpPusher( + self.hs, + instance_handle=pusherdict['instance_handle'], + user_name=pusherdict['user_name'], + app_id=pusherdict['app_id'], + app_display_name=pusherdict['app_display_name'], + device_display_name=pusherdict['device_display_name'], + pushkey=pusherdict['pushkey'], + pushkey_ts=pusherdict['pushkey_ts'], + data=pusherdict['data'], + last_token=pusherdict['last_token'], + last_success=pusherdict['last_success'], + failing_since=pusherdict['failing_since'] + ) + else: + raise PusherConfigException( + "Unknown pusher type '%s' for user %s" % + (pusherdict['kind'], pusherdict['user_name']) + ) + + @defer.inlineCallbacks + def _refresh_pusher(self, app_id_pushkey): + p = yield self.store.get_pushers_by_app_id_and_pushkey( + app_id_pushkey + ) + p['data'] = json.loads(p['data']) + + self._start_pushers([p]) + + def _start_pushers(self, pushers): + logger.info("Starting %d pushers", len(pushers)) + for pusherdict in pushers: + p = self._create_pusher(pusherdict) + if p: + fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey']) + if fullid in self.pushers: + self.pushers[fullid].stop() + self.pushers[fullid] = p + p.start() + + @defer.inlineCallbacks + def remove_pusher(self, app_id, pushkey): + fullid = "%s:%s" % (app_id, pushkey) + if fullid in self.pushers: + logger.info("Stopping pusher %s", fullid) + self.pushers[fullid].stop() + del self.pushers[fullid] + yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey) diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 4182ad990f..826a36f203 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -6,7 +6,7 @@ logger = logging.getLogger(__name__) REQUIREMENTS = { "syutil==0.0.2": ["syutil"], "matrix_angular_sdk==0.6.0": ["syweb>=0.6.0"], - "Twisted>=14.0.0": ["twisted>=14.0.0"], + "Twisted==14.0.2": ["twisted==14.0.2"], "service_identity>=1.0.0": ["service_identity>=1.0.0"], "pyopenssl>=0.14": ["OpenSSL>=0.14"], "pyyaml": ["yaml"], diff --git a/synapse/rest/client/v1/__init__.py b/synapse/rest/client/v1/__init__.py index 8bb89b2f6a..d8d01cdd16 100644 --- a/synapse/rest/client/v1/__init__.py +++ b/synapse/rest/client/v1/__init__.py @@ -13,10 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. - from . import ( room, events, register, login, profile, presence, initial_sync, directory, - voip, admin, + voip, admin, pusher, push_rule ) from synapse.http.server import JsonResource @@ -41,3 +40,5 @@ class ClientV1RestResource(JsonResource): directory.register_servlets(hs, client_resource) voip.register_servlets(hs, client_resource) admin.register_servlets(hs, client_resource) + pusher.register_servlets(hs, client_resource) + push_rule.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v1/admin.py b/synapse/rest/client/v1/admin.py index 1051d96f96..2ce754b028 100644 --- a/synapse/rest/client/v1/admin.py +++ b/synapse/rest/client/v1/admin.py @@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): target_user = UserID.from_string(user_id) - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(auth_user) if not is_admin and target_user != auth_user: diff --git a/synapse/rest/client/v1/directory.py b/synapse/rest/client/v1/directory.py index 15ae8749b8..8f65efec5f 100644 --- a/synapse/rest/client/v1/directory.py +++ b/synapse/rest/client/v1/directory.py @@ -45,7 +45,7 @@ class ClientDirectoryServer(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_alias): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) if not "room_id" in content: @@ -85,7 +85,7 @@ class ClientDirectoryServer(ClientV1RestServlet): @defer.inlineCallbacks def on_DELETE(self, request, room_alias): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) is_admin = yield self.auth.is_server_admin(user) if not is_admin: diff --git a/synapse/rest/client/v1/events.py b/synapse/rest/client/v1/events.py index a0d051227b..77b7c25a03 100644 --- a/synapse/rest/client/v1/events.py +++ b/synapse/rest/client/v1/events.py @@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) try: handler = self.handlers.event_stream_handler pagin_config = PaginationConfig.from_request(request) @@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, event_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) handler = self.handlers.event_handler event = yield handler.get_event(auth_user, event_id) diff --git a/synapse/rest/client/v1/initial_sync.py b/synapse/rest/client/v1/initial_sync.py index 357fa845b4..4a259bba64 100644 --- a/synapse/rest/client/v1/initial_sync.py +++ b/synapse/rest/client/v1/initial_sync.py @@ -25,7 +25,7 @@ class InitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) with_feedback = "feedback" in request.args as_client_event = "raw" not in request.args pagination_config = PaginationConfig.from_request(request) diff --git a/synapse/rest/client/v1/presence.py b/synapse/rest/client/v1/presence.py index b6c207e662..7feb4aadb1 100644 --- a/synapse/rest/client/v1/presence.py +++ b/synapse/rest/client/v1/presence.py @@ -32,7 +32,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = yield self.handlers.presence_handler.get_state( @@ -42,7 +42,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) state = {} @@ -77,7 +77,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): @@ -97,7 +97,7 @@ class PresenceListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) if not self.hs.is_mine(user): diff --git a/synapse/rest/client/v1/profile.py b/synapse/rest/client/v1/profile.py index 24f8d56952..15d6f3fc6c 100644 --- a/synapse/rest/client/v1/profile.py +++ b/synapse/rest/client/v1/profile.py @@ -37,7 +37,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: @@ -70,7 +70,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) user = UserID.from_string(user_id) try: diff --git a/synapse/rest/client/v1/push_rule.py b/synapse/rest/client/v1/push_rule.py new file mode 100644 index 0000000000..faa7919fbb --- /dev/null +++ b/synapse/rest/client/v1/push_rule.py @@ -0,0 +1,406 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError, NotFoundError, \ + StoreError +from .base import ClientV1RestServlet, client_path_pattern +from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException +import synapse.push.baserules as baserules + +import json + + +class PushRuleRestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/pushrules/.*$") + PRIORITY_CLASS_MAP = { + 'default': 0, + 'underride': 1, + 'sender': 2, + 'room': 3, + 'content': 4, + 'override': 5, + } + PRIORITY_CLASS_INVERSE_MAP = {v: k for k, v in PRIORITY_CLASS_MAP.items()} + SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR = ( + "Unrecognised request: You probably wanted a trailing slash") + + def rule_spec_from_path(self, path): + if len(path) < 2: + raise UnrecognizedRequestError() + if path[0] != 'pushrules': + raise UnrecognizedRequestError() + + scope = path[1] + path = path[2:] + if scope not in ['global', 'device']: + raise UnrecognizedRequestError() + + device = None + if scope == 'device': + if len(path) == 0: + raise UnrecognizedRequestError() + device = path[0] + path = path[1:] + + if len(path) == 0: + raise UnrecognizedRequestError() + + template = path[0] + path = path[1:] + + if len(path) == 0: + raise UnrecognizedRequestError() + + rule_id = path[0] + + spec = { + 'scope': scope, + 'template': template, + 'rule_id': rule_id + } + if device: + spec['device'] = device + return spec + + def rule_tuple_from_request_object(self, rule_template, rule_id, req_obj, device=None): + if rule_template in ['override', 'underride']: + if 'conditions' not in req_obj: + raise InvalidRuleException("Missing 'conditions'") + conditions = req_obj['conditions'] + for c in conditions: + if 'kind' not in c: + raise InvalidRuleException("Condition without 'kind'") + elif rule_template == 'room': + conditions = [{ + 'kind': 'event_match', + 'key': 'room_id', + 'pattern': rule_id + }] + elif rule_template == 'sender': + conditions = [{ + 'kind': 'event_match', + 'key': 'user_id', + 'pattern': rule_id + }] + elif rule_template == 'content': + if 'pattern' not in req_obj: + raise InvalidRuleException("Content rule missing 'pattern'") + pat = req_obj['pattern'] + + conditions = [{ + 'kind': 'event_match', + 'key': 'content.body', + 'pattern': pat + }] + else: + raise InvalidRuleException("Unknown rule template: %s" % (rule_template,)) + + if device: + conditions.append({ + 'kind': 'device', + 'instance_handle': device + }) + + if 'actions' not in req_obj: + raise InvalidRuleException("No actions found") + actions = req_obj['actions'] + + for a in actions: + if a in ['notify', 'dont_notify', 'coalesce']: + pass + elif isinstance(a, dict) and 'set_sound' in a: + pass + else: + raise InvalidRuleException("Unrecognised action") + + return conditions, actions + + @defer.inlineCallbacks + def on_PUT(self, request): + spec = self.rule_spec_from_path(request.postpath) + try: + priority_class = _priority_class_from_spec(spec) + except InvalidRuleException as e: + raise SynapseError(400, e.message) + + user, _ = yield self.auth.get_user_by_req(request) + + if spec['template'] == 'default': + raise SynapseError(403, "The default rules are immutable.") + + content = _parse_json(request) + + try: + (conditions, actions) = self.rule_tuple_from_request_object( + spec['template'], + spec['rule_id'], + content, + device=spec['device'] if 'device' in spec else None + ) + except InvalidRuleException as e: + raise SynapseError(400, e.message) + + before = request.args.get("before", None) + if before and len(before): + before = before[0] + after = request.args.get("after", None) + if after and len(after): + after = after[0] + + try: + yield self.hs.get_datastore().add_push_rule( + user_name=user.to_string(), + rule_id=spec['rule_id'], + priority_class=priority_class, + conditions=conditions, + actions=actions, + before=before, + after=after + ) + except InconsistentRuleException as e: + raise SynapseError(400, e.message) + except RuleNotFoundException as e: + raise SynapseError(400, e.message) + + defer.returnValue((200, {})) + + @defer.inlineCallbacks + def on_DELETE(self, request): + spec = self.rule_spec_from_path(request.postpath) + try: + priority_class = _priority_class_from_spec(spec) + except InvalidRuleException as e: + raise SynapseError(400, e.message) + + user, _ = yield self.auth.get_user_by_req(request) + + if 'device' in spec: + rules = yield self.hs.get_datastore().get_push_rules_for_user_name( + user.to_string() + ) + + for r in rules: + conditions = json.loads(r['conditions']) + ih = _instance_handle_from_conditions(conditions) + if ih == spec['device'] and r['priority_class'] == priority_class: + yield self.hs.get_datastore().delete_push_rule( + user.to_string(), spec['rule_id'] + ) + defer.returnValue((200, {})) + raise NotFoundError() + else: + try: + yield self.hs.get_datastore().delete_push_rule( + user.to_string(), spec['rule_id'], + priority_class=priority_class + ) + defer.returnValue((200, {})) + except StoreError as e: + if e.code == 404: + raise NotFoundError() + else: + raise + + @defer.inlineCallbacks + def on_GET(self, request): + user, _ = yield self.auth.get_user_by_req(request) + + # we build up the full structure and then decide which bits of it + # to send which means doing unnecessary work sometimes but is + # is probably not going to make a whole lot of difference + rawrules = yield self.hs.get_datastore().get_push_rules_for_user_name(user.to_string()) + for r in rawrules: + r["conditions"] = json.loads(r["conditions"]) + r["actions"] = json.loads(r["actions"]) + rawrules.extend(baserules.make_base_rules(user.to_string())) + + rules = {'global': {}, 'device': {}} + + rules['global'] = _add_empty_priority_class_arrays(rules['global']) + + for r in rawrules: + rulearray = None + + template_name = _priority_class_to_template_name(r['priority_class']) + + if r['priority_class'] > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']: + # per-device rule + instance_handle = _instance_handle_from_conditions(r["conditions"]) + r = _strip_device_condition(r) + if not instance_handle: + continue + if instance_handle not in rules['device']: + rules['device'][instance_handle] = {} + rules['device'][instance_handle] = ( + _add_empty_priority_class_arrays( + rules['device'][instance_handle] + ) + ) + + rulearray = rules['device'][instance_handle][template_name] + else: + rulearray = rules['global'][template_name] + + template_rule = _rule_to_template(r) + if template_rule: + rulearray.append(template_rule) + + path = request.postpath[1:] + + if path == []: + # we're a reference impl: pedantry is our job. + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + + if path[0] == '': + defer.returnValue((200, rules)) + elif path[0] == 'global': + path = path[1:] + result = _filter_ruleset_with_path(rules['global'], path) + defer.returnValue((200, result)) + elif path[0] == 'device': + path = path[1:] + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + if path[0] == '': + defer.returnValue((200, rules['device'])) + + instance_handle = path[0] + path = path[1:] + if instance_handle not in rules['device']: + ret = {} + ret = _add_empty_priority_class_arrays(ret) + defer.returnValue((200, ret)) + ruleset = rules['device'][instance_handle] + result = _filter_ruleset_with_path(ruleset, path) + defer.returnValue((200, result)) + else: + raise UnrecognizedRequestError() + + def on_OPTIONS(self, _): + return 200, {} + + +def _add_empty_priority_class_arrays(d): + for pc in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys(): + d[pc] = [] + return d + + +def _instance_handle_from_conditions(conditions): + """ + Given a list of conditions, return the instance handle of the + device rule if there is one + """ + for c in conditions: + if c['kind'] == 'device': + return c['instance_handle'] + return None + + +def _filter_ruleset_with_path(ruleset, path): + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + + if path[0] == '': + return ruleset + template_kind = path[0] + if template_kind not in ruleset: + raise UnrecognizedRequestError() + path = path[1:] + if path == []: + raise UnrecognizedRequestError( + PushRuleRestServlet.SLIGHTLY_PEDANTIC_TRAILING_SLASH_ERROR + ) + if path[0] == '': + return ruleset[template_kind] + rule_id = path[0] + for r in ruleset[template_kind]: + if r['rule_id'] == rule_id: + return r + raise NotFoundError + + +def _priority_class_from_spec(spec): + if spec['template'] not in PushRuleRestServlet.PRIORITY_CLASS_MAP.keys(): + raise InvalidRuleException("Unknown template: %s" % (spec['kind'])) + pc = PushRuleRestServlet.PRIORITY_CLASS_MAP[spec['template']] + + if spec['scope'] == 'device': + pc += len(PushRuleRestServlet.PRIORITY_CLASS_MAP) + + return pc + + +def _priority_class_to_template_name(pc): + if pc > PushRuleRestServlet.PRIORITY_CLASS_MAP['override']: + # per-device + prio_class_index = pc - len(PushRuleRestServlet.PRIORITY_CLASS_MAP) + return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[prio_class_index] + else: + return PushRuleRestServlet.PRIORITY_CLASS_INVERSE_MAP[pc] + + +def _rule_to_template(rule): + template_name = _priority_class_to_template_name(rule['priority_class']) + if template_name in ['default']: + return {k: rule[k] for k in ["conditions", "actions"]} + elif template_name in ['override', 'underride']: + return {k: rule[k] for k in ["rule_id", "conditions", "actions"]} + elif template_name in ["sender", "room"]: + return {k: rule[k] for k in ["rule_id", "actions"]} + elif template_name == 'content': + if len(rule["conditions"]) != 1: + return None + thecond = rule["conditions"][0] + if "pattern" not in thecond: + return None + ret = {k: rule[k] for k in ["rule_id", "actions"]} + ret["pattern"] = thecond["pattern"] + return ret + + +def _strip_device_condition(rule): + for i, c in enumerate(rule['conditions']): + if c['kind'] == 'device': + del rule['conditions'][i] + return rule + + +class InvalidRuleException(Exception): + pass + + +# XXX: C+ped from rest/room.py - surely this should be common? +def _parse_json(request): + try: + content = json.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.", + errcode=Codes.NOT_JSON) + return content + except ValueError: + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + +def register_servlets(hs, http_server): + PushRuleRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/pusher.py b/synapse/rest/client/v1/pusher.py new file mode 100644 index 0000000000..353a4a6589 --- /dev/null +++ b/synapse/rest/client/v1/pusher.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import SynapseError, Codes +from synapse.push import PusherConfigException +from .base import ClientV1RestServlet, client_path_pattern + +import json + + +class PusherRestServlet(ClientV1RestServlet): + PATTERN = client_path_pattern("/pushers/set$") + + @defer.inlineCallbacks + def on_POST(self, request): + user, _ = yield self.auth.get_user_by_req(request) + + content = _parse_json(request) + + pusher_pool = self.hs.get_pusherpool() + + if ('pushkey' in content and 'app_id' in content + and 'kind' in content and + content['kind'] is None): + yield pusher_pool.remove_pusher( + content['app_id'], content['pushkey'] + ) + defer.returnValue((200, {})) + + reqd = ['instance_handle', 'kind', 'app_id', 'app_display_name', + 'device_display_name', 'pushkey', 'lang', 'data'] + missing = [] + for i in reqd: + if i not in content: + missing.append(i) + if len(missing): + raise SynapseError(400, "Missing parameters: "+','.join(missing), + errcode=Codes.MISSING_PARAM) + + try: + yield pusher_pool.add_pusher( + user_name=user.to_string(), + instance_handle=content['instance_handle'], + kind=content['kind'], + app_id=content['app_id'], + app_display_name=content['app_display_name'], + device_display_name=content['device_display_name'], + pushkey=content['pushkey'], + lang=content['lang'], + data=content['data'] + ) + except PusherConfigException as pce: + raise SynapseError(400, "Config Error: "+pce.message, + errcode=Codes.MISSING_PARAM) + + defer.returnValue((200, {})) + + def on_OPTIONS(self, _): + return 200, {} + + +# XXX: C+ped from rest/room.py - surely this should be common? +def _parse_json(request): + try: + content = json.loads(request.content.read()) + if type(content) != dict: + raise SynapseError(400, "Content must be a JSON object.", + errcode=Codes.NOT_JSON) + return content + except ValueError: + raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON) + + +def register_servlets(hs, http_server): + PusherRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 58b09b6fc1..410f19ccf6 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -62,7 +62,7 @@ class RoomCreateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_POST(self, request): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) room_config = self.get_room_config(request) info = yield self.make_room(room_config, auth_user, None) @@ -125,7 +125,7 @@ class RoomStateEventRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id, event_type, state_key): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) msg_handler = self.handlers.message_handler data = yield msg_handler.get_room_data( @@ -142,8 +142,8 @@ class RoomStateEventRestServlet(ClientV1RestServlet): defer.returnValue((200, data.get_dict()["content"])) @defer.inlineCallbacks - def on_PUT(self, request, room_id, event_type, state_key): - user = yield self.auth.get_user_by_req(request) + def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -158,7 +158,9 @@ class RoomStateEventRestServlet(ClientV1RestServlet): event_dict["state_key"] = state_key msg_handler = self.handlers.message_handler - yield msg_handler.create_and_send_event(event_dict) + yield msg_handler.create_and_send_event( + event_dict, client=client, txn_id=txn_id, + ) defer.returnValue((200, {})) @@ -172,8 +174,8 @@ class RoomSendEventRestServlet(ClientV1RestServlet): register_txn_path(self, PATTERN, http_server, with_get=True) @defer.inlineCallbacks - def on_POST(self, request, room_id, event_type): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_id, event_type, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -183,7 +185,9 @@ class RoomSendEventRestServlet(ClientV1RestServlet): "content": content, "room_id": room_id, "sender": user.to_string(), - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {"event_id": event.event_id})) @@ -200,7 +204,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_id, event_type) + response = yield self.on_POST(request, room_id, event_type, txn_id) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) @@ -215,8 +219,8 @@ class JoinRoomAliasServlet(ClientV1RestServlet): register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks - def on_POST(self, request, room_identifier): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_identifier, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) # the identifier could be a room alias or a room id. Try one then the # other if it fails to parse, without swallowing other valid @@ -245,7 +249,9 @@ class JoinRoomAliasServlet(ClientV1RestServlet): "room_id": identifier.to_string(), "sender": user.to_string(), "state_key": user.to_string(), - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {"room_id": identifier.to_string()})) @@ -259,7 +265,7 @@ class JoinRoomAliasServlet(ClientV1RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_identifier) + response = yield self.on_POST(request, room_identifier, txn_id) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) @@ -283,7 +289,7 @@ class RoomMemberListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) handler = self.handlers.room_member_handler members = yield handler.get_room_members_as_pagination_chunk( room_id=room_id, @@ -311,7 +317,7 @@ class RoomMessageListRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request( request, default_limit=10, ) @@ -335,7 +341,7 @@ class RoomStateRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) handler = self.handlers.message_handler # Get all the current state for this room events = yield handler.get_state_events( @@ -351,7 +357,7 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request, room_id): - user = yield self.auth.get_user_by_req(request) + user, client = yield self.auth.get_user_by_req(request) pagination_config = PaginationConfig.from_request(request) content = yield self.handlers.message_handler.room_initial_sync( room_id=room_id, @@ -395,8 +401,8 @@ class RoomMembershipRestServlet(ClientV1RestServlet): register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks - def on_POST(self, request, room_id, membership_action): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_id, membership_action, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) @@ -418,7 +424,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet): "room_id": room_id, "sender": user.to_string(), "state_key": state_key, - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {})) @@ -432,7 +440,9 @@ class RoomMembershipRestServlet(ClientV1RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_id, membership_action) + response = yield self.on_POST( + request, room_id, membership_action, txn_id + ) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) @@ -444,8 +454,8 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): register_txn_path(self, PATTERN, http_server) @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id): - user = yield self.auth.get_user_by_req(request) + def on_POST(self, request, room_id, event_id, txn_id=None): + user, client = yield self.auth.get_user_by_req(request) content = _parse_json(request) msg_handler = self.handlers.message_handler @@ -456,7 +466,9 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): "room_id": room_id, "sender": user.to_string(), "redacts": event_id, - } + }, + client=client, + txn_id=txn_id, ) defer.returnValue((200, {"event_id": event.event_id})) @@ -470,7 +482,7 @@ class RoomRedactEventRestServlet(ClientV1RestServlet): except KeyError: pass - response = yield self.on_POST(request, room_id, event_id) + response = yield self.on_POST(request, room_id, event_id, txn_id) self.txns.store_client_transaction(request, txn_id, response) defer.returnValue(response) @@ -483,7 +495,7 @@ class RoomTypingRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_PUT(self, request, room_id, user_id): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) room_id = urllib.unquote(room_id) target_user = UserID.from_string(urllib.unquote(user_id)) diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py index 822d863ce6..11d08fbced 100644 --- a/synapse/rest/client/v1/voip.py +++ b/synapse/rest/client/v1/voip.py @@ -28,7 +28,7 @@ class VoipRestServlet(ClientV1RestServlet): @defer.inlineCallbacks def on_GET(self, request): - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) turnUris = self.hs.config.turn_uris turnSecret = self.hs.config.turn_shared_secret diff --git a/synapse/rest/client/v2_alpha/__init__.py b/synapse/rest/client/v2_alpha/__init__.py index bb740e2803..8f611de3a8 100644 --- a/synapse/rest/client/v2_alpha/__init__.py +++ b/synapse/rest/client/v2_alpha/__init__.py @@ -13,6 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from . import ( + sync, + filter +) from synapse.http.server import JsonResource @@ -26,4 +30,5 @@ class ClientV2AlphaRestResource(JsonResource): @staticmethod def register_servlets(client_resource, hs): - pass + sync.register_servlets(hs, client_resource) + filter.register_servlets(hs, client_resource) diff --git a/synapse/rest/client/v2_alpha/filter.py b/synapse/rest/client/v2_alpha/filter.py new file mode 100644 index 0000000000..6ddc495d23 --- /dev/null +++ b/synapse/rest/client/v2_alpha/filter.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.api.errors import AuthError, SynapseError +from synapse.http.servlet import RestServlet +from synapse.types import UserID + +from ._base import client_v2_pattern + +import json +import logging + + +logger = logging.getLogger(__name__) + + +class GetFilterRestServlet(RestServlet): + PATTERN = client_v2_pattern("/user/(?P[^/]*)/filter/(?P[^/]*)") + + def __init__(self, hs): + super(GetFilterRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.filtering = hs.get_filtering() + + @defer.inlineCallbacks + def on_GET(self, request, user_id, filter_id): + target_user = UserID.from_string(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) + + if target_user != auth_user: + raise AuthError(403, "Cannot get filters for other users") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only get filters for local users") + + try: + filter_id = int(filter_id) + except: + raise SynapseError(400, "Invalid filter_id") + + try: + filter = yield self.filtering.get_user_filter( + user_localpart=target_user.localpart, + filter_id=filter_id, + ) + + defer.returnValue((200, filter.filter_json)) + except KeyError: + raise SynapseError(400, "No such filter") + + +class CreateFilterRestServlet(RestServlet): + PATTERN = client_v2_pattern("/user/(?P[^/]*)/filter") + + def __init__(self, hs): + super(CreateFilterRestServlet, self).__init__() + self.hs = hs + self.auth = hs.get_auth() + self.filtering = hs.get_filtering() + + @defer.inlineCallbacks + def on_POST(self, request, user_id): + target_user = UserID.from_string(user_id) + auth_user, client = yield self.auth.get_user_by_req(request) + + if target_user != auth_user: + raise AuthError(403, "Cannot create filters for other users") + + if not self.hs.is_mine(target_user): + raise SynapseError(400, "Can only create filters for local users") + + try: + content = json.loads(request.content.read()) + + # TODO(paul): check for required keys and invalid keys + except: + raise SynapseError(400, "Invalid filter definition") + + filter_id = yield self.filtering.add_user_filter( + user_localpart=target_user.localpart, + user_filter=content, + ) + + defer.returnValue((200, {"filter_id": str(filter_id)})) + + +def register_servlets(hs, http_server): + GetFilterRestServlet(hs).register(http_server) + CreateFilterRestServlet(hs).register(http_server) diff --git a/synapse/rest/client/v2_alpha/sync.py b/synapse/rest/client/v2_alpha/sync.py new file mode 100644 index 0000000000..81d5cf8ead --- /dev/null +++ b/synapse/rest/client/v2_alpha/sync.py @@ -0,0 +1,207 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from synapse.http.servlet import RestServlet +from synapse.handlers.sync import SyncConfig +from synapse.types import StreamToken +from synapse.events.utils import ( + serialize_event, format_event_for_client_v2_without_event_id, +) +from synapse.api.filtering import Filter +from ._base import client_v2_pattern + +import logging + +logger = logging.getLogger(__name__) + + +class SyncRestServlet(RestServlet): + """ + + GET parameters:: + timeout(int): How long to wait for new events in milliseconds. + limit(int): Maxiumum number of events per room to return. + gap(bool): Create gaps the message history if limit is exceeded to + ensure that the client has the most recent messages. Defaults to + "true". + sort(str,str): tuple of sort key (e.g. "timeline") and direction + (e.g. "asc", "desc"). Defaults to "timeline,asc". + since(batch_token): Batch token when asking for incremental deltas. + set_presence(str): What state the device presence should be set to. + default is "online". + backfill(bool): Should the HS request message history from other + servers. This may take a long time making it unsuitable for clients + expecting a prompt response. Defaults to "true". + filter(filter_id): A filter to apply to the events returned. + filter_*: Filter override parameters. + + Response JSON:: + { + "next_batch": // batch token for the next /sync + "private_user_data": // private events for this user. + "public_user_data": // public events for all users including the + // public events for this user. + "rooms": [{ // List of rooms with updates. + "room_id": // Id of the room being updated + "limited": // Was the per-room event limit exceeded? + "published": // Is the room published by our HS? + "event_map": // Map of EventID -> event JSON. + "events": { // The recent events in the room if gap is "true" + // otherwise the next events in the room. + "batch": [] // list of EventIDs in the "event_map". + "prev_batch": // back token for getting previous events. + } + "state": [] // list of EventIDs updating the current state to + // be what it should be at the end of the batch. + "ephemeral": [] + }] + } + """ + + PATTERN = client_v2_pattern("/sync$") + ALLOWED_SORT = set(["timeline,asc", "timeline,desc"]) + ALLOWED_PRESENCE = set(["online", "offline", "idle"]) + + def __init__(self, hs): + super(SyncRestServlet, self).__init__() + self.auth = hs.get_auth() + self.sync_handler = hs.get_handlers().sync_handler + self.clock = hs.get_clock() + self.filtering = hs.get_filtering() + + @defer.inlineCallbacks + def on_GET(self, request): + user, client = yield self.auth.get_user_by_req(request) + + timeout = self.parse_integer(request, "timeout", default=0) + limit = self.parse_integer(request, "limit", required=True) + gap = self.parse_boolean(request, "gap", default=True) + sort = self.parse_string( + request, "sort", default="timeline,asc", + allowed_values=self.ALLOWED_SORT + ) + since = self.parse_string(request, "since") + set_presence = self.parse_string( + request, "set_presence", default="online", + allowed_values=self.ALLOWED_PRESENCE + ) + backfill = self.parse_boolean(request, "backfill", default=False) + filter_id = self.parse_string(request, "filter", default=None) + + logger.info( + "/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r," + " set_presence=%r, backfill=%r, filter_id=%r" % ( + user, timeout, limit, gap, sort, since, set_presence, + backfill, filter_id + ) + ) + + # TODO(mjark): Load filter and apply overrides. + try: + filter = yield self.filtering.get_user_filter( + user.localpart, filter_id + ) + except: + filter = Filter({}) + # filter = filter.apply_overrides(http_request) + #if filter.matches(event): + # # stuff + + sync_config = SyncConfig( + user=user, + client_info=client, + gap=gap, + limit=limit, + sort=sort, + backfill=backfill, + filter=filter, + ) + + if since is not None: + since_token = StreamToken.from_string(since) + else: + since_token = None + + sync_result = yield self.sync_handler.wait_for_sync_for_user( + sync_config, since_token=since_token, timeout=timeout + ) + + time_now = self.clock.time_msec() + + response_content = { + "public_user_data": self.encode_user_data( + sync_result.public_user_data, filter, time_now + ), + "private_user_data": self.encode_user_data( + sync_result.private_user_data, filter, time_now + ), + "rooms": self.encode_rooms( + sync_result.rooms, filter, time_now, client.token_id + ), + "next_batch": sync_result.next_batch.to_string(), + } + + defer.returnValue((200, response_content)) + + def encode_user_data(self, events, filter, time_now): + return events + + def encode_rooms(self, rooms, filter, time_now, token_id): + return [ + self.encode_room(room, filter, time_now, token_id) + for room in rooms + ] + + @staticmethod + def encode_room(room, filter, time_now, token_id): + event_map = {} + state_events = filter.filter_room_state(room.state) + recent_events = filter.filter_room_events(room.events) + state_event_ids = [] + recent_event_ids = [] + for event in state_events: + # TODO(mjark): Respect formatting requirements in the filter. + event_map[event.event_id] = serialize_event( + event, time_now, token_id=token_id, + event_format=format_event_for_client_v2_without_event_id, + ) + state_event_ids.append(event.event_id) + + for event in recent_events: + # TODO(mjark): Respect formatting requirements in the filter. + event_map[event.event_id] = serialize_event( + event, time_now, token_id=token_id, + event_format=format_event_for_client_v2_without_event_id, + ) + recent_event_ids.append(event.event_id) + result = { + "room_id": room.room_id, + "event_map": event_map, + "events": { + "batch": recent_event_ids, + "prev_batch": room.prev_batch.to_string(), + }, + "state": state_event_ids, + "limited": room.limited, + "published": room.published, + "ephemeral": room.ephemeral, + } + return result + + +def register_servlets(hs, http_server): + SyncRestServlet(hs).register(http_server) diff --git a/synapse/rest/media/v0/content_repository.py b/synapse/rest/media/v0/content_repository.py index 79ae0e3d74..22e26e3cd5 100644 --- a/synapse/rest/media/v0/content_repository.py +++ b/synapse/rest/media/v0/content_repository.py @@ -66,7 +66,7 @@ class ContentRepoResource(resource.Resource): @defer.inlineCallbacks def map_request_to_name(self, request): # auth the user - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) # namespace all file uploads on the user prefix = base64.urlsafe_b64encode( diff --git a/synapse/rest/media/v1/upload_resource.py b/synapse/rest/media/v1/upload_resource.py index b1718a630b..b939a30e19 100644 --- a/synapse/rest/media/v1/upload_resource.py +++ b/synapse/rest/media/v1/upload_resource.py @@ -42,7 +42,7 @@ class UploadResource(BaseMediaResource): @defer.inlineCallbacks def _async_render_POST(self, request): try: - auth_user = yield self.auth.get_user_by_req(request) + auth_user, client = yield self.auth.get_user_by_req(request) # TODO: The checks here are a bit late. The content will have # already been uploaded to a tmp file at this point content_length = request.getHeader("Content-Length") diff --git a/synapse/server.py b/synapse/server.py index 891c5aa13d..ba2b2593f1 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -31,7 +31,9 @@ from synapse.util.lockutils import LockManager from synapse.streams.events import EventSources from synapse.api.ratelimiting import Ratelimiter from synapse.crypto.keyring import Keyring +from synapse.push.pusherpool import PusherPool from synapse.events.builder import EventBuilderFactory +from synapse.api.filtering import Filtering class BaseHomeServer(object): @@ -79,7 +81,9 @@ class BaseHomeServer(object): 'event_sources', 'ratelimiter', 'keyring', + 'pusherpool', 'event_builder_factory', + 'filtering', ] def __init__(self, hostname, **kwargs): @@ -198,3 +202,9 @@ class HomeServer(BaseHomeServer): clock=self.get_clock(), hostname=self.hostname, ) + + def build_filtering(self): + return Filtering(self) + + def build_pusherpool(self): + return PusherPool(self) diff --git a/synapse/state.py b/synapse/state.py index 8144fa02b4..8a056ee955 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -19,6 +19,7 @@ from twisted.internet import defer from synapse.util.logutils import log_function from synapse.util.async import run_on_reactor from synapse.api.constants import EventTypes +from synapse.api.errors import AuthError from synapse.events.snapshot import EventContext from collections import namedtuple @@ -36,12 +37,16 @@ def _get_state_key_from_event(event): KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key")) +AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,) + + class StateHandler(object): """ Responsible for doing state conflict resolution. """ def __init__(self, hs): self.store = hs.get_datastore() + self.hs = hs @defer.inlineCallbacks def get_current_state(self, room_id, event_type=None, state_key=""): @@ -163,10 +168,17 @@ class StateHandler(object): first is the name of a state group if one and only one is involved, otherwise `None`. """ + logger.debug("resolve_state_groups event_ids %s", event_ids) + state_groups = yield self.store.get_state_groups( event_ids ) + logger.debug( + "resolve_state_groups state_groups %s", + state_groups.keys() + ) + group_names = set(state_groups.keys()) if len(group_names) == 1: name, state_list = state_groups.items().pop() @@ -210,64 +222,93 @@ class StateHandler(object): else: prev_states = [] + auth_events = { + k: e for k, e in unconflicted_state.items() + if k[0] in AuthEventTypes + } + try: - new_state = {} - new_state.update(unconflicted_state) - for key, events in conflicted_state.items(): - new_state[key] = self._resolve_state_events(events) + resolved_state = self._resolve_state_events( + conflicted_state, auth_events + ) except: logger.exception("Failed to resolve state") raise + new_state = unconflicted_state + new_state.update(resolved_state) + defer.returnValue((None, new_state, prev_states)) - def _get_power_level_from_event_state(self, event, user_id): - if hasattr(event, "old_state_events") and event.old_state_events: - key = (EventTypes.PowerLevels, "", ) - power_level_event = event.old_state_events.get(key) - level = None - if power_level_event: - level = power_level_event.content.get("users", {}).get( - user_id - ) - if not level: - level = power_level_event.content.get("users_default", 0) - - return level - else: - return 0 - @log_function - def _resolve_state_events(self, events): - curr_events = events + def _resolve_state_events(self, conflicted_state, auth_events): + """ This is where we actually decide which of the conflicted state to + use. - new_powers = [ - self._get_power_level_from_event_state(e, e.user_id) - for e in curr_events - ] + We resolve conflicts in the following order: + 1. power levels + 2. memberships + 3. other events. + """ + resolved_state = {} + power_key = (EventTypes.PowerLevels, "") + if power_key in conflicted_state.items(): + power_levels = conflicted_state[power_key] + resolved_state[power_key] = self._resolve_auth_events(power_levels) - new_powers = [ - int(p) if p else 0 for p in new_powers - ] + auth_events.update(resolved_state) - max_power = max(new_powers) + for key, events in conflicted_state.items(): + if key[0] == EventTypes.Member: + resolved_state[key] = self._resolve_auth_events( + events, + auth_events + ) - curr_events = [ - z[0] for z in zip(curr_events, new_powers) - if z[1] == max_power - ] + auth_events.update(resolved_state) - if not curr_events: - raise RuntimeError("Max didn't get a max?") - elif len(curr_events) == 1: - return curr_events[0] + for key, events in conflicted_state.items(): + if key not in resolved_state: + resolved_state[key] = self._resolve_normal_events( + events, auth_events + ) - # TODO: For now, just choose the one with the largest event_id. - return ( - sorted( - curr_events, - key=lambda e: hashlib.sha1( - e.event_id + e.user_id + e.room_id + e.type - ).hexdigest() - )[0] - ) + return resolved_state + + def _resolve_auth_events(self, events, auth_events): + reverse = [i for i in reversed(self._ordered_events(events))] + + auth_events = dict(auth_events) + + prev_event = reverse[0] + for event in reverse[1:]: + auth_events[(prev_event.type, prev_event.state_key)] = prev_event + try: + # FIXME: hs.get_auth() is bad style, but we need to do it to + # get around circular deps. + self.hs.get_auth().check(event, auth_events) + prev_event = event + except AuthError: + return prev_event + + return event + + def _resolve_normal_events(self, events, auth_events): + for event in self._ordered_events(events): + try: + # FIXME: hs.get_auth() is bad style, but we need to do it to + # get around circular deps. + self.hs.get_auth().check(event, auth_events) + return event + except AuthError: + pass + + # Use the last event (the one with the least depth) if they all fail + # the auth check. + return event + + def _ordered_events(self, events): + def key_func(e): + return -int(e.depth), hashlib.sha1(e.event_id).hexdigest() + + return sorted(events, key=key_func) diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py index e86b981b47..9bbd553dfc 100644 --- a/synapse/storage/__init__.py +++ b/synapse/storage/__init__.py @@ -30,10 +30,14 @@ from .stream import StreamStore from .transactions import TransactionStore from .keys import KeyStore from .event_federation import EventFederationStore +from .pusher import PusherStore +from .push_rule import PushRuleStore from .media_repository import MediaRepositoryStore +from .rejections import RejectionsStore from .state import StateStore from .signatures import SignatureStore +from .filtering import FilteringStore from syutil.base64util import decode_base64 from syutil.jsonutil import encode_canonical_json @@ -61,14 +65,20 @@ SCHEMAS = [ "state", "event_edges", "event_signatures", + "pusher", "media_repository", +<<<<<<< HEAD "application_services" +======= + "filtering", + "rejections", +>>>>>>> develop ] # Remember to update this number every time an incompatible change is made to # database schema files, so the users will be informed on server restarts. -SCHEMA_VERSION = 11 +SCHEMA_VERSION = 12 class _RollbackButIsFineException(Exception): @@ -82,8 +92,17 @@ class DataStore(RoomMemberStore, RoomStore, RegistrationStore, StreamStore, ProfileStore, FeedbackStore, PresenceStore, TransactionStore, DirectoryStore, KeyStore, StateStore, SignatureStore, +<<<<<<< HEAD EventFederationStore, MediaRepositoryStore, ApplicationServiceStore +======= + EventFederationStore, + MediaRepositoryStore, + RejectionsStore, + FilteringStore, + PusherStore, + PushRuleStore +>>>>>>> develop ): def __init__(self, hs): @@ -226,6 +245,9 @@ class DataStore(RoomMemberStore, RoomStore, if not outlier: self._store_state_groups_txn(txn, event, context) + if context.rejected: + self._store_rejections_txn(txn, event.event_id, context.rejected) + if current_state: txn.execute( "DELETE FROM current_state_events WHERE room_id = ?", @@ -264,7 +286,7 @@ class DataStore(RoomMemberStore, RoomStore, or_replace=True, ) - if is_new_state: + if is_new_state and not context.rejected: self._simple_insert_txn( txn, "current_state_events", @@ -290,7 +312,7 @@ class DataStore(RoomMemberStore, RoomStore, or_ignore=True, ) - if not backfilled: + if not backfilled and not context.rejected: self._simple_insert_txn( txn, table="state_forward_extremities", @@ -372,9 +394,12 @@ class DataStore(RoomMemberStore, RoomStore, "redacted": del_sql, } - if event_type: + if event_type and state_key is not None: sql += " AND s.type = ? AND s.state_key = ? " args = (room_id, event_type, state_key) + elif event_type: + sql += " AND s.type = ?" + args = (room_id, event_type) else: args = (room_id, ) @@ -383,6 +408,41 @@ class DataStore(RoomMemberStore, RoomStore, events = yield self._parse_events(results) defer.returnValue(events) + @defer.inlineCallbacks + def get_room_name_and_aliases(self, room_id): + del_sql = ( + "SELECT event_id FROM redactions WHERE redacts = e.event_id " + "LIMIT 1" + ) + + sql = ( + "SELECT e.*, (%(redacted)s) AS redacted FROM events as e " + "INNER JOIN current_state_events as c ON e.event_id = c.event_id " + "INNER JOIN state_events as s ON e.event_id = s.event_id " + "WHERE c.room_id = ? " + ) % { + "redacted": del_sql, + } + + sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')" + sql += " OR s.type = 'm.room.aliases')" + args = (room_id,) + + results = yield self._execute_and_decode(sql, *args) + + events = yield self._parse_events(results) + + name = None + aliases = [] + + for e in events: + if e.type == 'm.room.name': + name = e.content['name'] + elif e.type == 'm.room.aliases': + aliases.extend(e.content['aliases']) + + defer.returnValue((name, aliases)) + @defer.inlineCallbacks def _get_min_token(self): row = yield self._execute( @@ -419,6 +479,35 @@ class DataStore(RoomMemberStore, RoomStore, ], ) + def have_events(self, event_ids): + """Given a list of event ids, check if we have already processed them. + + Returns: + dict: Has an entry for each event id we already have seen. Maps to + the rejected reason string if we rejected the event, else maps to + None. + """ + def f(txn): + sql = ( + "SELECT e.event_id, reason FROM events as e " + "LEFT JOIN rejections as r ON e.event_id = r.event_id " + "WHERE e.event_id = ?" + ) + + res = {} + for event_id in event_ids: + txn.execute(sql, (event_id,)) + row = txn.fetchone() + if row: + _, rejected = row + res[event_id] = rejected + + return res + + return self.runInteraction( + "have_events", f, + ) + def schema_path(schema): """ Get a filesystem path for the named database schema diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f660fc6eaf..b350fd61f1 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -193,6 +193,50 @@ class SQLBaseStore(object): txn.execute(sql, values.values()) return txn.lastrowid + def _simple_upsert(self, table, keyvalues, values): + """ + Args: + table (str): The table to upsert into + keyvalues (dict): The unique key tables and their new values + values (dict): The nonunique columns and their new values + Returns: A deferred + """ + return self.runInteraction( + "_simple_upsert", + self._simple_upsert_txn, table, keyvalues, values + ) + + def _simple_upsert_txn(self, txn, table, keyvalues, values): + # Try to update + sql = "UPDATE %s SET %s WHERE %s" % ( + table, + ", ".join("%s = ?" % (k,) for k in values), + " AND ".join("%s = ?" % (k,) for k in keyvalues) + ) + sqlargs = values.values() + keyvalues.values() + logger.debug( + "[SQL] %s Args=%s", + sql, sqlargs, + ) + + txn.execute(sql, sqlargs) + if txn.rowcount == 0: + # We didn't update and rows so insert a new one + allvalues = {} + allvalues.update(keyvalues) + allvalues.update(values) + + sql = "INSERT INTO %s (%s) VALUES (%s)" % ( + table, + ", ".join(k for k in allvalues), + ", ".join("?" for _ in allvalues) + ) + logger.debug( + "[SQL] %s Args=%s", + sql, keyvalues.values(), + ) + txn.execute(sql, allvalues.values()) + def _simple_select_one(self, table, keyvalues, retcols, allow_none=False): """Executes a SELECT query on the named table, which is expected to @@ -344,8 +388,8 @@ class SQLBaseStore(object): if updatevalues: update_sql = "UPDATE %s SET %s WHERE %s" % ( table, - ", ".join("%s = ?" % (k) for k in updatevalues), - " AND ".join("%s = ?" % (k) for k in keyvalues) + ", ".join("%s = ?" % (k,) for k in updatevalues), + " AND ".join("%s = ?" % (k,) for k in keyvalues) ) def func(txn): @@ -458,10 +502,12 @@ class SQLBaseStore(object): return [e for e in events if e] def _get_event_txn(self, txn, event_id, check_redacted=True, - get_prev_content=False): + get_prev_content=False, allow_rejected=False): sql = ( - "SELECT internal_metadata, json, r.event_id FROM event_json as e " + "SELECT e.internal_metadata, e.json, r.event_id, rej.reason " + "FROM event_json as e " "LEFT JOIN redactions as r ON e.event_id = r.redacts " + "LEFT JOIN rejections as rej on rej.event_id = e.event_id " "WHERE e.event_id = ? " "LIMIT 1 " ) @@ -473,13 +519,16 @@ class SQLBaseStore(object): if not res: return None - internal_metadata, js, redacted = res + internal_metadata, js, redacted, rejected_reason = res - return self._get_event_from_row_txn( - txn, internal_metadata, js, redacted, - check_redacted=check_redacted, - get_prev_content=get_prev_content, - ) + if allow_rejected or not rejected_reason: + return self._get_event_from_row_txn( + txn, internal_metadata, js, redacted, + check_redacted=check_redacted, + get_prev_content=get_prev_content, + ) + else: + return None def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted, check_redacted=True, get_prev_content=False): diff --git a/synapse/storage/filtering.py b/synapse/storage/filtering.py new file mode 100644 index 0000000000..e86eeced45 --- /dev/null +++ b/synapse/storage/filtering.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from ._base import SQLBaseStore + +import json + + +class FilteringStore(SQLBaseStore): + @defer.inlineCallbacks + def get_user_filter(self, user_localpart, filter_id): + def_json = yield self._simple_select_one_onecol( + table="user_filters", + keyvalues={ + "user_id": user_localpart, + "filter_id": filter_id, + }, + retcol="filter_json", + allow_none=False, + ) + + defer.returnValue(json.loads(def_json)) + + def add_user_filter(self, user_localpart, user_filter): + def_json = json.dumps(user_filter) + + # Need an atomic transaction to SELECT the maximal ID so far then + # INSERT a new one + def _do_txn(txn): + sql = ( + "SELECT MAX(filter_id) FROM user_filters " + "WHERE user_id = ?" + ) + txn.execute(sql, (user_localpart,)) + max_id = txn.fetchone()[0] + if max_id is None: + filter_id = 0 + else: + filter_id = max_id + 1 + + sql = ( + "INSERT INTO user_filters (user_id, filter_id, filter_json)" + "VALUES(?, ?, ?)" + ) + txn.execute(sql, (user_localpart, filter_id, def_json)) + + return filter_id + + return self.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py new file mode 100644 index 0000000000..27502d2399 --- /dev/null +++ b/synapse/storage/push_rule.py @@ -0,0 +1,213 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +from ._base import SQLBaseStore, Table +from twisted.internet import defer + +import logging +import copy +import json + +logger = logging.getLogger(__name__) + + +class PushRuleStore(SQLBaseStore): + @defer.inlineCallbacks + def get_push_rules_for_user_name(self, user_name): + sql = ( + "SELECT "+",".join(PushRuleTable.fields)+" " + "FROM "+PushRuleTable.table_name+" " + "WHERE user_name = ? " + "ORDER BY priority_class DESC, priority DESC" + ) + rows = yield self._execute(None, sql, user_name) + + dicts = [] + for r in rows: + d = {} + for i, f in enumerate(PushRuleTable.fields): + d[f] = r[i] + dicts.append(d) + + defer.returnValue(dicts) + + @defer.inlineCallbacks + def add_push_rule(self, before, after, **kwargs): + vals = copy.copy(kwargs) + if 'conditions' in vals: + vals['conditions'] = json.dumps(vals['conditions']) + if 'actions' in vals: + vals['actions'] = json.dumps(vals['actions']) + # we could check the rest of the keys are valid column names + # but sqlite will do that anyway so I think it's just pointless. + if 'id' in vals: + del vals['id'] + + if before or after: + ret = yield self.runInteraction( + "_add_push_rule_relative_txn", + self._add_push_rule_relative_txn, + before=before, + after=after, + **vals + ) + defer.returnValue(ret) + else: + ret = yield self.runInteraction( + "_add_push_rule_highest_priority_txn", + self._add_push_rule_highest_priority_txn, + **vals + ) + defer.returnValue(ret) + + def _add_push_rule_relative_txn(self, txn, user_name, **kwargs): + after = None + relative_to_rule = None + if 'after' in kwargs and kwargs['after']: + after = kwargs['after'] + relative_to_rule = after + if 'before' in kwargs and kwargs['before']: + relative_to_rule = kwargs['before'] + + # get the priority of the rule we're inserting after/before + sql = ( + "SELECT priority_class, priority FROM ? " + "WHERE user_name = ? and rule_id = ?" % (PushRuleTable.table_name,) + ) + txn.execute(sql, (user_name, relative_to_rule)) + res = txn.fetchall() + if not res: + raise RuleNotFoundException("before/after rule not found: %s" % (relative_to_rule)) + priority_class, base_rule_priority = res[0] + + if 'priority_class' in kwargs and kwargs['priority_class'] != priority_class: + raise InconsistentRuleException( + "Given priority class does not match class of relative rule" + ) + + new_rule = copy.copy(kwargs) + if 'before' in new_rule: + del new_rule['before'] + if 'after' in new_rule: + del new_rule['after'] + new_rule['priority_class'] = priority_class + new_rule['user_name'] = user_name + + # check if the priority before/after is free + new_rule_priority = base_rule_priority + if after: + new_rule_priority -= 1 + else: + new_rule_priority += 1 + + new_rule['priority'] = new_rule_priority + + sql = ( + "SELECT COUNT(*) FROM " + PushRuleTable.table_name + + " WHERE user_name = ? AND priority_class = ? AND priority = ?" + ) + txn.execute(sql, (user_name, priority_class, new_rule_priority)) + res = txn.fetchall() + num_conflicting = res[0][0] + + # if there are conflicting rules, bump everything + if num_conflicting: + sql = "UPDATE "+PushRuleTable.table_name+" SET priority = priority " + if after: + sql += "-1" + else: + sql += "+1" + sql += " WHERE user_name = ? AND priority_class = ? AND priority " + if after: + sql += "<= ?" + else: + sql += ">= ?" + + txn.execute(sql, (user_name, priority_class, new_rule_priority)) + + # now insert the new rule + sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" + sql += ",".join(new_rule.keys())+") VALUES (" + sql += ", ".join(["?" for _ in new_rule.keys()])+")" + + txn.execute(sql, new_rule.values()) + + def _add_push_rule_highest_priority_txn(self, txn, user_name, + priority_class, **kwargs): + # find the highest priority rule in that class + sql = ( + "SELECT COUNT(*), MAX(priority) FROM " + PushRuleTable.table_name + + " WHERE user_name = ? and priority_class = ?" + ) + txn.execute(sql, (user_name, priority_class)) + res = txn.fetchall() + (how_many, highest_prio) = res[0] + + new_prio = 0 + if how_many > 0: + new_prio = highest_prio + 1 + + # and insert the new rule + new_rule = copy.copy(kwargs) + if 'id' in new_rule: + del new_rule['id'] + new_rule['user_name'] = user_name + new_rule['priority_class'] = priority_class + new_rule['priority'] = new_prio + + sql = "INSERT OR REPLACE INTO "+PushRuleTable.table_name+" (" + sql += ",".join(new_rule.keys())+") VALUES (" + sql += ", ".join(["?" for _ in new_rule.keys()])+")" + + txn.execute(sql, new_rule.values()) + + @defer.inlineCallbacks + def delete_push_rule(self, user_name, rule_id, **kwargs): + """ + Delete a push rule. Args specify the row to be deleted and can be + any of the columns in the push_rule table, but below are the + standard ones + + Args: + user_name (str): The matrix ID of the push rule owner + rule_id (str): The rule_id of the rule to be deleted + """ + yield self._simple_delete_one(PushRuleTable.table_name, kwargs) + + +class RuleNotFoundException(Exception): + pass + + +class InconsistentRuleException(Exception): + pass + + +class PushRuleTable(Table): + table_name = "push_rules" + + fields = [ + "id", + "user_name", + "rule_id", + "priority_class", + "priority", + "conditions", + "actions", + ] + + EntryType = collections.namedtuple("PushRuleEntry", fields) diff --git a/synapse/storage/pusher.py b/synapse/storage/pusher.py new file mode 100644 index 0000000000..f253c9e2c3 --- /dev/null +++ b/synapse/storage/pusher.py @@ -0,0 +1,173 @@ +# -*- coding: utf-8 -*- +# Copyright 2014 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + +from ._base import SQLBaseStore, Table +from twisted.internet import defer + +from synapse.api.errors import StoreError + +import logging + +logger = logging.getLogger(__name__) + + +class PusherStore(SQLBaseStore): + @defer.inlineCallbacks + def get_pushers_by_app_id_and_pushkey(self, app_id_and_pushkey): + sql = ( + "SELECT id, user_name, kind, instance_handle, app_id," + "app_display_name, device_display_name, pushkey, ts, data, " + "last_token, last_success, failing_since " + "FROM pushers " + "WHERE app_id = ? AND pushkey = ?" + ) + + rows = yield self._execute( + None, sql, app_id_and_pushkey[0], app_id_and_pushkey[1] + ) + + ret = [ + { + "id": r[0], + "user_name": r[1], + "kind": r[2], + "instance_handle": r[3], + "app_id": r[4], + "app_display_name": r[5], + "device_display_name": r[6], + "pushkey": r[7], + "pushkey_ts": r[8], + "data": r[9], + "last_token": r[10], + "last_success": r[11], + "failing_since": r[12] + } + for r in rows + ] + + defer.returnValue(ret[0]) + + @defer.inlineCallbacks + def get_all_pushers(self): + sql = ( + "SELECT id, user_name, kind, instance_handle, app_id," + "app_display_name, device_display_name, pushkey, ts, data, " + "last_token, last_success, failing_since " + "FROM pushers" + ) + + rows = yield self._execute(None, sql) + + ret = [ + { + "id": r[0], + "user_name": r[1], + "kind": r[2], + "instance_handle": r[3], + "app_id": r[4], + "app_display_name": r[5], + "device_display_name": r[6], + "pushkey": r[7], + "pushkey_ts": r[8], + "data": r[9], + "last_token": r[10], + "last_success": r[11], + "failing_since": r[12] + } + for r in rows + ] + + defer.returnValue(ret) + + @defer.inlineCallbacks + def add_pusher(self, user_name, instance_handle, kind, app_id, + app_display_name, device_display_name, + pushkey, pushkey_ts, lang, data): + try: + yield self._simple_upsert( + PushersTable.table_name, + dict( + app_id=app_id, + pushkey=pushkey, + ), + dict( + user_name=user_name, + kind=kind, + instance_handle=instance_handle, + app_display_name=app_display_name, + device_display_name=device_display_name, + ts=pushkey_ts, + lang=lang, + data=data + )) + except Exception as e: + logger.error("create_pusher with failed: %s", e) + raise StoreError(500, "Problem creating pusher.") + + @defer.inlineCallbacks + def delete_pusher_by_app_id_pushkey(self, app_id, pushkey): + yield self._simple_delete_one( + PushersTable.table_name, + dict(app_id=app_id, pushkey=pushkey) + ) + + @defer.inlineCallbacks + def update_pusher_last_token(self, user_name, pushkey, last_token): + yield self._simple_update_one( + PushersTable.table_name, + {'user_name': user_name, 'pushkey': pushkey}, + {'last_token': last_token} + ) + + @defer.inlineCallbacks + def update_pusher_last_token_and_success(self, user_name, pushkey, + last_token, last_success): + yield self._simple_update_one( + PushersTable.table_name, + {'user_name': user_name, 'pushkey': pushkey}, + {'last_token': last_token, 'last_success': last_success} + ) + + @defer.inlineCallbacks + def update_pusher_failing_since(self, user_name, pushkey, failing_since): + yield self._simple_update_one( + PushersTable.table_name, + {'user_name': user_name, 'pushkey': pushkey}, + {'failing_since': failing_since} + ) + + +class PushersTable(Table): + table_name = "pushers" + + fields = [ + "id", + "user_name", + "kind", + "instance_handle", + "app_id", + "app_display_name", + "device_display_name", + "pushkey", + "pushkey_ts", + "data", + "last_token", + "last_success", + "failing_since" + ] + + EntryType = collections.namedtuple("PusherEntry", fields) diff --git a/synapse/storage/registration.py b/synapse/storage/registration.py index 75dffa4db2..029b07cc66 100644 --- a/synapse/storage/registration.py +++ b/synapse/storage/registration.py @@ -122,7 +122,8 @@ class RegistrationStore(SQLBaseStore): def _query_for_auth(self, txn, token): sql = ( - "SELECT users.name, users.admin, access_tokens.device_id" + "SELECT users.name, users.admin," + " access_tokens.device_id, access_tokens.id as token_id" " FROM users" " INNER JOIN access_tokens on users.id = access_tokens.user_id" " WHERE token = ?" diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py new file mode 100644 index 0000000000..4e1a9a2783 --- /dev/null +++ b/synapse/storage/rejections.py @@ -0,0 +1,43 @@ +# -*- coding: utf-8 -*- +# Copyright 2014, 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ._base import SQLBaseStore + +import logging + +logger = logging.getLogger(__name__) + + +class RejectionsStore(SQLBaseStore): + def _store_rejections_txn(self, txn, event_id, reason): + self._simple_insert_txn( + txn, + table="rejections", + values={ + "event_id": event_id, + "reason": reason, + "last_check": self._clock.time_msec(), + } + ) + + def get_rejection_reason(self, event_id): + return self._simple_select_one_onecol( + table="rejections", + retcol="reason", + keyvalues={ + "event_id": event_id, + }, + allow_none=True, + ) diff --git a/synapse/storage/schema/delta/v12.sql b/synapse/storage/schema/delta/v12.sql new file mode 100644 index 0000000000..a6867cba62 --- /dev/null +++ b/synapse/storage/schema/delta/v12.sql @@ -0,0 +1,54 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE +); + +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + instance_handle varchar(32) NOT NULL, + kind varchar(8) NOT NULL, + app_id varchar(64) NOT NULL, + app_display_name varchar(64) NOT NULL, + device_display_name varchar(128) NOT NULL, + pushkey blob NOT NULL, + ts BIGINT NOT NULL, + lang varchar(8), + data blob, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + FOREIGN KEY(user_name) REFERENCES users(name), + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class TINYINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); diff --git a/synapse/storage/schema/delta/v13.sql b/synapse/storage/schema/delta/v13.sql new file mode 100644 index 0000000000..beb39ca201 --- /dev/null +++ b/synapse/storage/schema/delta/v13.sql @@ -0,0 +1,24 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS user_filters( + user_id TEXT, + filter_id INTEGER, + filter_json TEXT, + FOREIGN KEY(user_id) REFERENCES users(id) +); + +CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( + user_id, filter_id +); diff --git a/synapse/storage/schema/filtering.sql b/synapse/storage/schema/filtering.sql new file mode 100644 index 0000000000..beb39ca201 --- /dev/null +++ b/synapse/storage/schema/filtering.sql @@ -0,0 +1,24 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +CREATE TABLE IF NOT EXISTS user_filters( + user_id TEXT, + filter_id INTEGER, + filter_json TEXT, + FOREIGN KEY(user_id) REFERENCES users(id) +); + +CREATE INDEX IF NOT EXISTS user_filters_by_user_id_filter_id ON user_filters( + user_id, filter_id +); diff --git a/synapse/storage/schema/pusher.sql b/synapse/storage/schema/pusher.sql new file mode 100644 index 0000000000..8c4dfd5c1b --- /dev/null +++ b/synapse/storage/schema/pusher.sql @@ -0,0 +1,46 @@ +/* Copyright 2014 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +-- Push notification endpoints that users have configured +CREATE TABLE IF NOT EXISTS pushers ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + instance_handle varchar(32) NOT NULL, + kind varchar(8) NOT NULL, + app_id varchar(64) NOT NULL, + app_display_name varchar(64) NOT NULL, + device_display_name varchar(128) NOT NULL, + pushkey blob NOT NULL, + ts BIGINT NOT NULL, + lang varchar(8), + data blob, + last_token TEXT, + last_success BIGINT, + failing_since BIGINT, + FOREIGN KEY(user_name) REFERENCES users(name), + UNIQUE (app_id, pushkey) +); + +CREATE TABLE IF NOT EXISTS push_rules ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + user_name TEXT NOT NULL, + rule_id TEXT NOT NULL, + priority_class TINYINT NOT NULL, + priority INTEGER NOT NULL DEFAULT 0, + conditions TEXT NOT NULL, + actions TEXT NOT NULL, + UNIQUE(user_name, rule_id) +); + +CREATE INDEX IF NOT EXISTS push_rules_user_name on push_rules (user_name); diff --git a/synapse/storage/schema/rejections.sql b/synapse/storage/schema/rejections.sql new file mode 100644 index 0000000000..bd2a8b1bb5 --- /dev/null +++ b/synapse/storage/schema/rejections.sql @@ -0,0 +1,21 @@ +/* Copyright 2015 OpenMarket Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +CREATE TABLE IF NOT EXISTS rejections( + event_id TEXT NOT NULL, + reason TEXT NOT NULL, + last_check TEXT NOT NULL, + CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE +); diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py index 8ac2adab05..3ccb6f8a61 100644 --- a/synapse/storage/stream.py +++ b/synapse/storage/stream.py @@ -82,10 +82,10 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")): def parse(cls, string): try: if string[0] == 's': - return cls(None, int(string[1:])) + return cls(topological=None, stream=int(string[1:])) if string[0] == 't': parts = string[1:].split('-', 1) - return cls(int(parts[1]), int(parts[0])) + return cls(topological=int(parts[0]), stream=int(parts[1])) except: pass raise SynapseError(400, "Invalid token %r" % (string,)) @@ -94,7 +94,7 @@ class _StreamToken(namedtuple("_StreamToken", "topological stream")): def parse_stream_token(cls, string): try: if string[0] == 's': - return cls(None, int(string[1:])) + return cls(topological=None, stream=int(string[1:])) except: pass raise SynapseError(400, "Invalid token %r" % (string,)) @@ -181,8 +181,11 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) + self._set_before_and_after(ret, rows) + if rows: key = "s%d" % max([r["stream_ordering"] for r in rows]) + else: # Assume we didn't get anything because there was nothing to # get. @@ -260,22 +263,44 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) + self._set_before_and_after(events, rows) + return events, next_token, return self.runInteraction("paginate_room_events", f) def get_recent_events_for_room(self, room_id, limit, end_token, - with_feedback=False): + with_feedback=False, from_token=None): # TODO (erikj): Handle compressed feedback - sql = ( - "SELECT stream_ordering, topological_ordering, event_id FROM events " - "WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0 " - "ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ? " - ) + end_token = _StreamToken.parse_stream_token(end_token) - def f(txn): - txn.execute(sql, (room_id, end_token, limit,)) + if from_token is None: + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering <= ? AND outlier = 0" + " ORDER BY topological_ordering DESC, stream_ordering DESC" + " LIMIT ?" + ) + else: + from_token = _StreamToken.parse_stream_token(from_token) + sql = ( + "SELECT stream_ordering, topological_ordering, event_id" + " FROM events" + " WHERE room_id = ? AND stream_ordering > ?" + " AND stream_ordering <= ? AND outlier = 0" + " ORDER BY topological_ordering DESC, stream_ordering DESC" + " LIMIT ?" + ) + + def get_recent_events_for_room_txn(txn): + if from_token is None: + txn.execute(sql, (room_id, end_token.stream, limit,)) + else: + txn.execute(sql, ( + room_id, from_token.stream, end_token.stream, limit + )) rows = self.cursor_to_dict(txn) @@ -291,9 +316,9 @@ class StreamStore(SQLBaseStore): toke = rows[0]["stream_ordering"] - 1 start_token = str(_StreamToken(topo, toke)) - token = (start_token, end_token) + token = (start_token, str(end_token)) else: - token = (end_token, end_token) + token = (str(end_token), str(end_token)) events = self._get_events_txn( txn, @@ -301,9 +326,13 @@ class StreamStore(SQLBaseStore): get_prev_content=True ) + self._set_before_and_after(events, rows) + return events, token - return self.runInteraction("get_recent_events_for_room", f) + return self.runInteraction( + "get_recent_events_for_room", get_recent_events_for_room_txn + ) def get_room_events_max_id(self): return self.runInteraction( @@ -325,3 +354,12 @@ class StreamStore(SQLBaseStore): key = res[0]["m"] return "s%d" % (key,) + + @staticmethod + def _set_before_and_after(events, rows): + for event, row in zip(events, rows): + stream = row["stream_ordering"] + topo = event.depth + internal = event.internal_metadata + internal.before = str(_StreamToken(topo, stream - 1)) + internal.after = str(_StreamToken(topo, stream)) diff --git a/synapse/types.py b/synapse/types.py index faac729ff2..f6a1b0bbcf 100644 --- a/synapse/types.py +++ b/synapse/types.py @@ -119,3 +119,6 @@ class StreamToken( d = self._asdict() d[key] = new_value return StreamToken(**d) + + +ClientInfo = namedtuple("ClientInfo", ("device_id", "token_id")) diff --git a/tests/api/test_filtering.py b/tests/api/test_filtering.py new file mode 100644 index 0000000000..babf4c37f1 --- /dev/null +++ b/tests/api/test_filtering.py @@ -0,0 +1,512 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import namedtuple +from tests import unittest +from twisted.internet import defer + +from mock import Mock, NonCallableMock +from tests.utils import ( + MockHttpResource, MockClock, DeferredMockCallable, SQLiteMemoryDbPool, + MockKey +) + +from synapse.server import HomeServer +from synapse.types import UserID +from synapse.api.filtering import Filter + +user_localpart = "test_user" +MockEvent = namedtuple("MockEvent", "sender type room_id") + +class FilteringTestCase(unittest.TestCase): + + @defer.inlineCallbacks + def setUp(self): + db_pool = SQLiteMemoryDbPool() + yield db_pool.prepare() + + self.mock_config = NonCallableMock() + self.mock_config.signing_key = [MockKey()] + + self.mock_federation_resource = MockHttpResource() + + self.mock_http_client = Mock(spec=[]) + self.mock_http_client.put_json = DeferredMockCallable() + + hs = HomeServer("test", + db_pool=db_pool, + handlers=None, + http_client=self.mock_http_client, + config=self.mock_config, + keyring=Mock(), + ) + + self.filtering = hs.get_filtering() + self.filter = Filter({}) + + self.datastore = hs.get_datastore() + + def test_definition_types_works_with_literals(self): + definition = { + "types": ["m.room.message", "org.matrix.foo.bar"] + } + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!foo:bar" + ) + self.assertTrue( + self.filter._passes_definition(definition, event) + ) + + def test_definition_types_works_with_wildcards(self): + definition = { + "types": ["m.*", "org.matrix.foo.bar"] + } + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!foo:bar" + ) + self.assertTrue( + self.filter._passes_definition(definition, event) + ) + + def test_definition_types_works_with_unknowns(self): + definition = { + "types": ["m.room.message", "org.matrix.foo.bar"] + } + event = MockEvent( + sender="@foo:bar", + type="now.for.something.completely.different", + room_id="!foo:bar" + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_not_types_works_with_literals(self): + definition = { + "not_types": ["m.room.message", "org.matrix.foo.bar"] + } + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!foo:bar" + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_not_types_works_with_wildcards(self): + definition = { + "not_types": ["m.room.message", "org.matrix.*"] + } + event = MockEvent( + sender="@foo:bar", + type="org.matrix.custom.event", + room_id="!foo:bar" + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_not_types_works_with_unknowns(self): + definition = { + "not_types": ["m.*", "org.*"] + } + event = MockEvent( + sender="@foo:bar", + type="com.nom.nom.nom", + room_id="!foo:bar" + ) + self.assertTrue( + self.filter._passes_definition(definition, event) + ) + + def test_definition_not_types_takes_priority_over_types(self): + definition = { + "not_types": ["m.*", "org.*"], + "types": ["m.room.message", "m.room.topic"] + } + event = MockEvent( + sender="@foo:bar", + type="m.room.topic", + room_id="!foo:bar" + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_senders_works_with_literals(self): + definition = { + "senders": ["@flibble:wibble"] + } + event = MockEvent( + sender="@flibble:wibble", + type="com.nom.nom.nom", + room_id="!foo:bar" + ) + self.assertTrue( + self.filter._passes_definition(definition, event) + ) + + def test_definition_senders_works_with_unknowns(self): + definition = { + "senders": ["@flibble:wibble"] + } + event = MockEvent( + sender="@challenger:appears", + type="com.nom.nom.nom", + room_id="!foo:bar" + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_not_senders_works_with_literals(self): + definition = { + "not_senders": ["@flibble:wibble"] + } + event = MockEvent( + sender="@flibble:wibble", + type="com.nom.nom.nom", + room_id="!foo:bar" + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_not_senders_works_with_unknowns(self): + definition = { + "not_senders": ["@flibble:wibble"] + } + event = MockEvent( + sender="@challenger:appears", + type="com.nom.nom.nom", + room_id="!foo:bar" + ) + self.assertTrue( + self.filter._passes_definition(definition, event) + ) + + def test_definition_not_senders_takes_priority_over_senders(self): + definition = { + "not_senders": ["@misspiggy:muppets"], + "senders": ["@kermit:muppets", "@misspiggy:muppets"] + } + event = MockEvent( + sender="@misspiggy:muppets", + type="m.room.topic", + room_id="!foo:bar" + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_rooms_works_with_literals(self): + definition = { + "rooms": ["!secretbase:unknown"] + } + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown" + ) + self.assertTrue( + self.filter._passes_definition(definition, event) + ) + + def test_definition_rooms_works_with_unknowns(self): + definition = { + "rooms": ["!secretbase:unknown"] + } + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!anothersecretbase:unknown" + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_not_rooms_works_with_literals(self): + definition = { + "not_rooms": ["!anothersecretbase:unknown"] + } + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!anothersecretbase:unknown" + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_not_rooms_works_with_unknowns(self): + definition = { + "not_rooms": ["!secretbase:unknown"] + } + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!anothersecretbase:unknown" + ) + self.assertTrue( + self.filter._passes_definition(definition, event) + ) + + def test_definition_not_rooms_takes_priority_over_rooms(self): + definition = { + "not_rooms": ["!secretbase:unknown"], + "rooms": ["!secretbase:unknown"] + } + event = MockEvent( + sender="@foo:bar", + type="m.room.message", + room_id="!secretbase:unknown" + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_combined_event(self): + definition = { + "not_senders": ["@misspiggy:muppets"], + "senders": ["@kermit:muppets"], + "rooms": ["!stage:unknown"], + "not_rooms": ["!piggyshouse:muppets"], + "types": ["m.room.message", "muppets.kermit.*"], + "not_types": ["muppets.misspiggy.*"] + } + event = MockEvent( + sender="@kermit:muppets", # yup + type="m.room.message", # yup + room_id="!stage:unknown" # yup + ) + self.assertTrue( + self.filter._passes_definition(definition, event) + ) + + def test_definition_combined_event_bad_sender(self): + definition = { + "not_senders": ["@misspiggy:muppets"], + "senders": ["@kermit:muppets"], + "rooms": ["!stage:unknown"], + "not_rooms": ["!piggyshouse:muppets"], + "types": ["m.room.message", "muppets.kermit.*"], + "not_types": ["muppets.misspiggy.*"] + } + event = MockEvent( + sender="@misspiggy:muppets", # nope + type="m.room.message", # yup + room_id="!stage:unknown" # yup + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_combined_event_bad_room(self): + definition = { + "not_senders": ["@misspiggy:muppets"], + "senders": ["@kermit:muppets"], + "rooms": ["!stage:unknown"], + "not_rooms": ["!piggyshouse:muppets"], + "types": ["m.room.message", "muppets.kermit.*"], + "not_types": ["muppets.misspiggy.*"] + } + event = MockEvent( + sender="@kermit:muppets", # yup + type="m.room.message", # yup + room_id="!piggyshouse:muppets" # nope + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + def test_definition_combined_event_bad_type(self): + definition = { + "not_senders": ["@misspiggy:muppets"], + "senders": ["@kermit:muppets"], + "rooms": ["!stage:unknown"], + "not_rooms": ["!piggyshouse:muppets"], + "types": ["m.room.message", "muppets.kermit.*"], + "not_types": ["muppets.misspiggy.*"] + } + event = MockEvent( + sender="@kermit:muppets", # yup + type="muppets.misspiggy.kisses", # nope + room_id="!stage:unknown" # yup + ) + self.assertFalse( + self.filter._passes_definition(definition, event) + ) + + @defer.inlineCallbacks + def test_filter_public_user_data_match(self): + user_filter_json = { + "public_user_data": { + "types": ["m.*"] + } + } + user = UserID.from_string("@" + user_localpart + ":test") + filter_id = yield self.datastore.add_user_filter( + user_localpart=user_localpart, + user_filter=user_filter_json, + ) + event = MockEvent( + sender="@foo:bar", + type="m.profile", + room_id="!foo:bar" + ) + events = [event] + + user_filter = yield self.filtering.get_user_filter( + user_localpart=user_localpart, + filter_id=filter_id, + ) + + results = user_filter.filter_public_user_data(events=events) + self.assertEquals(events, results) + + @defer.inlineCallbacks + def test_filter_public_user_data_no_match(self): + user_filter_json = { + "public_user_data": { + "types": ["m.*"] + } + } + user = UserID.from_string("@" + user_localpart + ":test") + filter_id = yield self.datastore.add_user_filter( + user_localpart=user_localpart, + user_filter=user_filter_json, + ) + event = MockEvent( + sender="@foo:bar", + type="custom.avatar.3d.crazy", + room_id="!foo:bar" + ) + events = [event] + + user_filter = yield self.filtering.get_user_filter( + user_localpart=user_localpart, + filter_id=filter_id, + ) + + results = user_filter.filter_public_user_data(events=events) + self.assertEquals([], results) + + @defer.inlineCallbacks + def test_filter_room_state_match(self): + user_filter_json = { + "room": { + "state": { + "types": ["m.*"] + } + } + } + user = UserID.from_string("@" + user_localpart + ":test") + filter_id = yield self.datastore.add_user_filter( + user_localpart=user_localpart, + user_filter=user_filter_json, + ) + event = MockEvent( + sender="@foo:bar", + type="m.room.topic", + room_id="!foo:bar" + ) + events = [event] + + user_filter = yield self.filtering.get_user_filter( + user_localpart=user_localpart, + filter_id=filter_id, + ) + + results = user_filter.filter_room_state(events=events) + self.assertEquals(events, results) + + @defer.inlineCallbacks + def test_filter_room_state_no_match(self): + user_filter_json = { + "room": { + "state": { + "types": ["m.*"] + } + } + } + user = UserID.from_string("@" + user_localpart + ":test") + filter_id = yield self.datastore.add_user_filter( + user_localpart=user_localpart, + user_filter=user_filter_json, + ) + event = MockEvent( + sender="@foo:bar", + type="org.matrix.custom.event", + room_id="!foo:bar" + ) + events = [event] + + user_filter = yield self.filtering.get_user_filter( + user_localpart=user_localpart, + filter_id=filter_id, + ) + + results = user_filter.filter_room_state(events) + self.assertEquals([], results) + + @defer.inlineCallbacks + def test_add_filter(self): + user_filter_json = { + "room": { + "state": { + "types": ["m.*"] + } + } + } + + filter_id = yield self.filtering.add_user_filter( + user_localpart=user_localpart, + user_filter=user_filter_json, + ) + + self.assertEquals(filter_id, 0) + self.assertEquals(user_filter_json, + (yield self.datastore.get_user_filter( + user_localpart=user_localpart, + filter_id=0, + )) + ) + + @defer.inlineCallbacks + def test_get_filter(self): + user_filter_json = { + "room": { + "state": { + "types": ["m.*"] + } + } + } + + filter_id = yield self.datastore.add_user_filter( + user_localpart=user_localpart, + user_filter=user_filter_json, + ) + + filter = yield self.filtering.get_user_filter( + user_localpart=user_localpart, + filter_id=filter_id, + ) + + self.assertEquals(filter.filter_json, user_filter_json) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index ed21defd13..44dbce6bea 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase): "get_room", "get_destination_retry_timings", "set_destination_retry_timings", + "have_events", ]), resource_for_federation=NonCallableMock(), http_client=NonCallableMock(spec_set=[]), @@ -90,6 +91,7 @@ class FederationTestCase(unittest.TestCase): self.datastore.persist_event.return_value = defer.succeed(None) self.datastore.get_room.return_value = defer.succeed(True) self.auth.check_host_in_room.return_value = defer.succeed(True) + self.datastore.have_events.return_value = defer.succeed({}) def annotate(ev, old_state=None): context = Mock() diff --git a/tests/rest/client/v1/test_presence.py b/tests/rest/client/v1/test_presence.py index 65d5cc4916..f849120a3e 100644 --- a/tests/rest/client/v1/test_presence.py +++ b/tests/rest/client/v1/test_presence.py @@ -75,6 +75,7 @@ class PresenceStateTestCase(unittest.TestCase): "user": UserID.from_string(myid), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -165,6 +166,7 @@ class PresenceListTestCase(unittest.TestCase): "user": UserID.from_string(myid), "admin": False, "device_id": None, + "token_id": 1, } hs.handlers.room_member_handler = Mock( @@ -282,7 +284,7 @@ class PresenceEventStreamTestCase(unittest.TestCase): hs.get_clock().time_msec.return_value = 1000000 def _get_user_by_req(req=None): - return UserID.from_string(myid) + return (UserID.from_string(myid), "") hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_profile.py b/tests/rest/client/v1/test_profile.py index 39cd68d829..6a2085276a 100644 --- a/tests/rest/client/v1/test_profile.py +++ b/tests/rest/client/v1/test_profile.py @@ -58,7 +58,7 @@ class ProfileTestCase(unittest.TestCase): ) def _get_user_by_req(request=None): - return UserID.from_string(myid) + return (UserID.from_string(myid), "") hs.get_auth().get_user_by_req = _get_user_by_req diff --git a/tests/rest/client/v1/test_rooms.py b/tests/rest/client/v1/test_rooms.py index 76ed550b75..81ead10e76 100644 --- a/tests/rest/client/v1/test_rooms.py +++ b/tests/rest/client/v1/test_rooms.py @@ -70,6 +70,7 @@ class RoomPermissionsTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -466,6 +467,7 @@ class RoomsMemberListTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -555,6 +557,7 @@ class RoomsCreateTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -657,6 +660,7 @@ class RoomTopicTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -773,6 +777,7 @@ class RoomMemberStateTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -909,6 +914,7 @@ class RoomMessagesTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token @@ -1013,6 +1019,7 @@ class RoomInitialSyncTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token diff --git a/tests/rest/client/v1/test_typing.py b/tests/rest/client/v1/test_typing.py index c89b37d004..c5d5b06da3 100644 --- a/tests/rest/client/v1/test_typing.py +++ b/tests/rest/client/v1/test_typing.py @@ -73,6 +73,7 @@ class RoomTypingTestCase(RestTestCase): "user": UserID.from_string(self.auth_user_id), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token diff --git a/tests/rest/client/v2_alpha/__init__.py b/tests/rest/client/v2_alpha/__init__.py index f59745e13c..fa70575c57 100644 --- a/tests/rest/client/v2_alpha/__init__.py +++ b/tests/rest/client/v2_alpha/__init__.py @@ -39,9 +39,7 @@ class V2AlphaRestTestCase(unittest.TestCase): hs = HomeServer("test", db_pool=None, - datastore=Mock(spec=[ - "insert_client_ip", - ]), + datastore=self.make_datastore_mock(), http_client=None, resource_for_client=self.mock_resource, resource_for_federation=self.mock_resource, @@ -53,8 +51,14 @@ class V2AlphaRestTestCase(unittest.TestCase): "user": UserID.from_string(self.USER_ID), "admin": False, "device_id": None, + "token_id": 1, } hs.get_auth().get_user_by_token = _get_user_by_token for r in self.TO_REGISTER: r.register_servlets(hs, self.mock_resource) + + def make_datastore_mock(self): + return Mock(spec=[ + "insert_client_ip", + ]) diff --git a/tests/rest/client/v2_alpha/test_filter.py b/tests/rest/client/v2_alpha/test_filter.py new file mode 100644 index 0000000000..80ddabf818 --- /dev/null +++ b/tests/rest/client/v2_alpha/test_filter.py @@ -0,0 +1,95 @@ +# -*- coding: utf-8 -*- +# Copyright 2015 OpenMarket Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.internet import defer + +from mock import Mock + +from . import V2AlphaRestTestCase + +from synapse.rest.client.v2_alpha import filter + +from synapse.api.errors import StoreError + + +class FilterTestCase(V2AlphaRestTestCase): + USER_ID = "@apple:test" + TO_REGISTER = [filter] + + def make_datastore_mock(self): + datastore = super(FilterTestCase, self).make_datastore_mock() + + self._user_filters = {} + + def add_user_filter(user_localpart, definition): + filters = self._user_filters.setdefault(user_localpart, []) + filter_id = len(filters) + filters.append(definition) + return defer.succeed(filter_id) + datastore.add_user_filter = add_user_filter + + def get_user_filter(user_localpart, filter_id): + if user_localpart not in self._user_filters: + raise StoreError(404, "No user") + filters = self._user_filters[user_localpart] + if filter_id >= len(filters): + raise StoreError(404, "No filter") + return defer.succeed(filters[filter_id]) + datastore.get_user_filter = get_user_filter + + return datastore + + @defer.inlineCallbacks + def test_add_filter(self): + (code, response) = yield self.mock_resource.trigger("POST", + "/user/%s/filter" % (self.USER_ID), + '{"type": ["m.*"]}' + ) + self.assertEquals(200, code) + self.assertEquals({"filter_id": "0"}, response) + + self.assertIn("apple", self._user_filters) + self.assertEquals(len(self._user_filters["apple"]), 1) + self.assertEquals({"type": ["m.*"]}, self._user_filters["apple"][0]) + + @defer.inlineCallbacks + def test_get_filter(self): + self._user_filters["apple"] = [ + {"type": ["m.*"]} + ] + + (code, response) = yield self.mock_resource.trigger("GET", + "/user/%s/filter/0" % (self.USER_ID), None + ) + self.assertEquals(200, code) + self.assertEquals({"type": ["m.*"]}, response) + + @defer.inlineCallbacks + def test_get_filter_no_id(self): + self._user_filters["apple"] = [ + {"type": ["m.*"]} + ] + + (code, response) = yield self.mock_resource.trigger("GET", + "/user/%s/filter/2" % (self.USER_ID), None + ) + self.assertEquals(404, code) + + @defer.inlineCallbacks + def test_get_filter_no_user(self): + (code, response) = yield self.mock_resource.trigger("GET", + "/user/%s/filter/0" % (self.USER_ID), None + ) + self.assertEquals(404, code) diff --git a/tests/storage/test_registration.py b/tests/storage/test_registration.py index 84bfde7568..6f8bea2f61 100644 --- a/tests/storage/test_registration.py +++ b/tests/storage/test_registration.py @@ -53,7 +53,10 @@ class RegistrationStoreTestCase(unittest.TestCase): ) self.assertEquals( - {"admin": 0, "device_id": None, "name": self.user_id}, + {"admin": 0, + "device_id": None, + "name": self.user_id, + "token_id": 1}, (yield self.store.get_user_by_token(self.tokens[0])) ) @@ -63,7 +66,10 @@ class RegistrationStoreTestCase(unittest.TestCase): yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) self.assertEquals( - {"admin": 0, "device_id": None, "name": self.user_id}, + {"admin": 0, + "device_id": None, + "name": self.user_id, + "token_id": 2}, (yield self.store.get_user_by_token(self.tokens[1])) ) diff --git a/tests/test_state.py b/tests/test_state.py index 98ad9e54cd..019e794aa2 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -16,11 +16,120 @@ from tests import unittest from twisted.internet import defer +from synapse.events import FrozenEvent +from synapse.api.auth import Auth +from synapse.api.constants import EventTypes, Membership from synapse.state import StateHandler from mock import Mock +_next_event_id = 1000 + + +def create_event(name=None, type=None, state_key=None, depth=2, event_id=None, + prev_events=[], **kwargs): + global _next_event_id + + if not event_id: + _next_event_id += 1 + event_id = str(_next_event_id) + + if not name: + if state_key is not None: + name = "<%s-%s, %s>" % (type, state_key, event_id,) + else: + name = "<%s, %s>" % (type, event_id,) + + d = { + "event_id": event_id, + "type": type, + "sender": "@user_id:example.com", + "room_id": "!room_id:example.com", + "depth": depth, + "prev_events": prev_events, + } + + if state_key is not None: + d["state_key"] = state_key + + d.update(kwargs) + + event = FrozenEvent(d) + + return event + + +class StateGroupStore(object): + def __init__(self): + self._event_to_state_group = {} + self._group_to_state = {} + + self._next_group = 1 + + def get_state_groups(self, event_ids): + groups = {} + for event_id in event_ids: + group = self._event_to_state_group.get(event_id) + if group: + groups[group] = self._group_to_state[group] + + return defer.succeed(groups) + + def store_state_groups(self, event, context): + if context.current_state is None: + return + + state_events = context.current_state + + if event.is_state(): + state_events[(event.type, event.state_key)] = event + + state_group = context.state_group + if not state_group: + state_group = self._next_group + self._next_group += 1 + + self._group_to_state[state_group] = state_events.values() + + self._event_to_state_group[event.event_id] = state_group + + +class DictObj(dict): + def __init__(self, **kwargs): + super(DictObj, self).__init__(kwargs) + self.__dict__ = self + + +class Graph(object): + def __init__(self, nodes, edges): + events = {} + clobbered = set(events.keys()) + + for event_id, fields in nodes.items(): + refs = edges.get(event_id) + if refs: + clobbered.difference_update(refs) + prev_events = [(r, {}) for r in refs] + else: + prev_events = [] + + events[event_id] = create_event( + event_id=event_id, + prev_events=prev_events, + **fields + ) + + self._leaves = clobbered + self._events = sorted(events.values(), key=lambda e: e.depth) + + def walk(self): + return iter(self._events) + + def get_leaves(self): + return (self._events[i] for i in self._leaves) + + class StateTestCase(unittest.TestCase): def setUp(self): self.store = Mock( @@ -29,20 +138,188 @@ class StateTestCase(unittest.TestCase): "add_event_hashes", ] ) - hs = Mock(spec=["get_datastore"]) + hs = Mock(spec=["get_datastore", "get_auth", "get_state_handler"]) hs.get_datastore.return_value = self.store + hs.get_state_handler.return_value = None + hs.get_auth.return_value = Auth(hs) self.state = StateHandler(hs) self.event_id = 0 + @defer.inlineCallbacks + def test_branch_no_conflict(self): + graph = Graph( + nodes={ + "START": DictObj( + type=EventTypes.Create, + state_key="", + depth=1, + ), + "A": DictObj( + type=EventTypes.Message, + depth=2, + ), + "B": DictObj( + type=EventTypes.Message, + depth=3, + ), + "C": DictObj( + type=EventTypes.Name, + state_key="", + depth=3, + ), + "D": DictObj( + type=EventTypes.Message, + depth=4, + ), + }, + edges={ + "A": ["START"], + "B": ["A"], + "C": ["A"], + "D": ["B", "C"] + } + ) + + store = StateGroupStore() + self.store.get_state_groups.side_effect = store.get_state_groups + + context_store = {} + + for event in graph.walk(): + context = yield self.state.compute_event_context(event) + store.store_state_groups(event, context) + context_store[event.event_id] = context + + self.assertEqual(2, len(context_store["D"].current_state)) + + @defer.inlineCallbacks + def test_branch_basic_conflict(self): + graph = Graph( + nodes={ + "START": DictObj( + type=EventTypes.Create, + state_key="creator", + content={"membership": "@user_id:example.com"}, + depth=1, + ), + "A": DictObj( + type=EventTypes.Member, + state_key="@user_id:example.com", + content={"membership": Membership.JOIN}, + membership=Membership.JOIN, + depth=2, + ), + "B": DictObj( + type=EventTypes.Name, + state_key="", + depth=3, + ), + "C": DictObj( + type=EventTypes.Name, + state_key="", + depth=4, + ), + "D": DictObj( + type=EventTypes.Message, + depth=5, + ), + }, + edges={ + "A": ["START"], + "B": ["A"], + "C": ["A"], + "D": ["B", "C"] + } + ) + + store = StateGroupStore() + self.store.get_state_groups.side_effect = store.get_state_groups + + context_store = {} + + for event in graph.walk(): + context = yield self.state.compute_event_context(event) + store.store_state_groups(event, context) + context_store[event.event_id] = context + + self.assertSetEqual( + {"START", "A", "C"}, + {e.event_id for e in context_store["D"].current_state.values()} + ) + + @defer.inlineCallbacks + def test_branch_have_banned_conflict(self): + graph = Graph( + nodes={ + "START": DictObj( + type=EventTypes.Create, + state_key="creator", + content={"membership": "@user_id:example.com"}, + depth=1, + ), + "A": DictObj( + type=EventTypes.Member, + state_key="@user_id:example.com", + content={"membership": Membership.JOIN}, + membership=Membership.JOIN, + depth=2, + ), + "B": DictObj( + type=EventTypes.Name, + state_key="", + depth=3, + ), + "C": DictObj( + type=EventTypes.Member, + state_key="@user_id_2:example.com", + content={"membership": Membership.BAN}, + membership=Membership.BAN, + depth=4, + ), + "D": DictObj( + type=EventTypes.Name, + state_key="", + depth=4, + sender="@user_id_2:example.com", + ), + "E": DictObj( + type=EventTypes.Message, + depth=5, + ), + }, + edges={ + "A": ["START"], + "B": ["A"], + "C": ["B"], + "D": ["B"], + "E": ["C", "D"] + } + ) + + store = StateGroupStore() + self.store.get_state_groups.side_effect = store.get_state_groups + + context_store = {} + + for event in graph.walk(): + context = yield self.state.compute_event_context(event) + store.store_state_groups(event, context) + context_store[event.event_id] = context + + self.assertSetEqual( + {"START", "A", "B", "C"}, + {e.event_id for e in context_store["E"].current_state.values()} + ) + @defer.inlineCallbacks def test_annotate_with_old_message(self): - event = self.create_event(type="test_message", name="event") + event = create_event(type="test_message", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] context = yield self.state.compute_event_context( @@ -62,12 +339,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_annotate_with_old_state(self): - event = self.create_event(type="state", state_key="", name="event") + event = create_event(type="state", state_key="", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] context = yield self.state.compute_event_context( @@ -88,13 +365,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_trivial_annotate_message(self): - event = self.create_event(type="test_message", name="event") - event.prev_events = [] + event = create_event(type="test_message", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] group_name = "group_name_1" @@ -119,13 +395,12 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_trivial_annotate_state(self): - event = self.create_event(type="state", state_key="", name="event") - event.prev_events = [] + event = create_event(type="state", state_key="", name="event") old_state = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] group_name = "group_name_1" @@ -150,30 +425,21 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve_message_conflict(self): - event = self.create_event(type="test_message", name="event") - event.prev_events = [] + event = create_event(type="test_message", name="event") old_state_1 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] old_state_2 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test3", state_key="2"), - self.create_event(type="test4", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test3", state_key="2"), + create_event(type="test4", state_key=""), ] - group_name_1 = "group_name_1" - group_name_2 = "group_name_2" - - self.store.get_state_groups.return_value = { - group_name_1: old_state_1, - group_name_2: old_state_2, - } - - context = yield self.state.compute_event_context(event) + context = yield self._get_context(event, old_state_1, old_state_2) self.assertEqual(len(context.current_state), 5) @@ -181,21 +447,70 @@ class StateTestCase(unittest.TestCase): @defer.inlineCallbacks def test_resolve_state_conflict(self): - event = self.create_event(type="test4", state_key="", name="event") - event.prev_events = [] + event = create_event(type="test4", state_key="", name="event") old_state_1 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test1", state_key="2"), - self.create_event(type="test2", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test1", state_key="2"), + create_event(type="test2", state_key=""), ] old_state_2 = [ - self.create_event(type="test1", state_key="1"), - self.create_event(type="test3", state_key="2"), - self.create_event(type="test4", state_key=""), + create_event(type="test1", state_key="1"), + create_event(type="test3", state_key="2"), + create_event(type="test4", state_key=""), ] + context = yield self._get_context(event, old_state_1, old_state_2) + + self.assertEqual(len(context.current_state), 5) + + self.assertIsNone(context.state_group) + + @defer.inlineCallbacks + def test_standard_depth_conflict(self): + event = create_event(type="test4", name="event") + + member_event = create_event( + type=EventTypes.Member, + state_key="@user_id:example.com", + content={ + "membership": Membership.JOIN, + } + ) + + old_state_1 = [ + member_event, + create_event(type="test1", state_key="1", depth=1), + ] + + old_state_2 = [ + member_event, + create_event(type="test1", state_key="1", depth=2), + ] + + context = yield self._get_context(event, old_state_1, old_state_2) + + self.assertEqual(old_state_2[1], context.current_state[("test1", "1")]) + + # Reverse the depth to make sure we are actually using the depths + # during state resolution. + + old_state_1 = [ + member_event, + create_event(type="test1", state_key="1", depth=2), + ] + + old_state_2 = [ + member_event, + create_event(type="test1", state_key="1", depth=1), + ] + + context = yield self._get_context(event, old_state_1, old_state_2) + + self.assertEqual(old_state_1[1], context.current_state[("test1", "1")]) + + def _get_context(self, event, old_state_1, old_state_2): group_name_1 = "group_name_1" group_name_2 = "group_name_2" @@ -204,33 +519,4 @@ class StateTestCase(unittest.TestCase): group_name_2: old_state_2, } - context = yield self.state.compute_event_context(event) - - self.assertEqual(len(context.current_state), 5) - - self.assertIsNone(context.state_group) - - def create_event(self, name=None, type=None, state_key=None): - self.event_id += 1 - event_id = str(self.event_id) - - if not name: - if state_key is not None: - name = "<%s-%s>" % (type, state_key) - else: - name = "<%s>" % (type, ) - - event = Mock(name=name, spec=[]) - event.type = type - - if state_key is not None: - event.state_key = state_key - event.event_id = event_id - - event.is_state = lambda: (state_key is not None) - event.unsigned = {} - - event.user_id = "@user_id:example.com" - event.room_id = "!room_id:example.com" - - return event + return self.state.compute_event_context(event)