Avoid so much copypasta between 3PU and 3PL query by unifying around a ThirdPartyEntityKind enumeration

This commit is contained in:
Paul "LeoNerd" Evans 2016-08-18 17:19:55 +01:00
parent 2a91799fcc
commit b515f844ee
4 changed files with 35 additions and 43 deletions

View file

@ -17,6 +17,7 @@ from twisted.internet import defer
from synapse.api.errors import CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.events.utils import serialize_event
from synapse.types import ThirdPartyEntityKind
import logging
import urllib
@ -72,25 +73,21 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False)
@defer.inlineCallbacks
def query_3pu(self, service, protocol, fields):
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
response = None
try:
response = yield self.get_json(uri, fields)
defer.returnValue(response)
except Exception as ex:
logger.warning("query_3pu to %s threw exception %s", uri, ex)
defer.returnValue([])
def query_3pe(self, service, kind, protocol, fields):
if kind == ThirdPartyEntityKind.USER:
uri = "%s/3pu/%s" % (service.url, urllib.quote(protocol))
elif kind == ThirdPartyEntityKind.LOCATION:
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
else:
raise ValueError(
"Unrecognised 'kind' argument %r to query_3pe()", kind
)
@defer.inlineCallbacks
def query_3pl(self, service, protocol, fields):
uri = "%s/3pl/%s" % (service.url, urllib.quote(protocol))
response = None
try:
response = yield self.get_json(uri, fields)
defer.returnValue(response)
except Exception as ex:
logger.warning("query_3pl to %s threw exception %s", uri, ex)
logger.warning("query_3pe to %s threw exception %s", uri, ex)
defer.returnValue([])
@defer.inlineCallbacks

View file

@ -18,6 +18,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.util.metrics import Measure
from synapse.util.logcontext import preserve_fn
from synapse.types import ThirdPartyEntityKind
import logging
@ -169,14 +170,20 @@ class ApplicationServicesHandler(object):
defer.returnValue(result)
@defer.inlineCallbacks
def query_3pu(self, protocol, fields):
def query_3pe(self, kind, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield defer.DeferredList([
self.appservice_api.query_3pu(service, protocol, fields)
self.appservice_api.query_3pe(service, kind, protocol, fields)
for service in services
], consumeErrors=True)
required_field = (
"userid" if kind == ThirdPartyEntityKind.USER else
"alias" if kind == ThirdPartyEntityKind.LOCATION else
None
)
ret = []
for (success, result) in results:
if not success:
@ -184,31 +191,7 @@ class ApplicationServicesHandler(object):
if not isinstance(result, list):
continue
for r in result:
if _is_valid_3pentity_result(r, field="userid"):
ret.append(r)
else:
logger.warn("Application service returned an " +
"invalid result %r", r)
defer.returnValue(ret)
@defer.inlineCallbacks
def query_3pl(self, protocol, fields):
services = yield self._get_services_for_3pn(protocol)
results = yield defer.DeferredList([
self.appservice_api.query_3pl(service, protocol, fields)
for service in services
], consumeErrors=True)
ret = []
for (success, result) in results:
if not success:
continue
if not isinstance(result, list):
continue
for r in result:
if _is_valid_3pentity_result(r, field="alias"):
if _is_valid_3pentity_result(r, field=required_field):
ret.append(r)
else:
logger.warn("Application service returned an " +

View file

@ -19,6 +19,7 @@ import logging
from twisted.internet import defer
from synapse.http.servlet import RestServlet
from synapse.types import ThirdPartyEntityKind
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@ -41,7 +42,9 @@ class ThirdPartyUserServlet(RestServlet):
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pu(protocol, fields)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.USER, protocol, fields
)
defer.returnValue((200, results))
@ -63,7 +66,9 @@ class ThirdPartyLocationServlet(RestServlet):
fields = request.args
del fields["access_token"]
results = yield self.appservice_handler.query_3pl(protocol, fields)
results = yield self.appservice_handler.query_3pe(
ThirdPartyEntityKind.LOCATION, protocol, fields
)
defer.returnValue((200, results))

View file

@ -269,3 +269,10 @@ class RoomStreamToken(namedtuple("_StreamToken", "topological stream")):
return "t%d-%d" % (self.topological, self.stream)
else:
return "s%d" % (self.stream,)
# Some arbitrary constants used for internal API enumerations. Don't rely on
# exact values; always pass or compare symbolically
class ThirdPartyEntityKind(object):
USER = 'user'
LOCATION = 'location'